Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CUDACore/src/library_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function Base.convert(::Type{cudaDataType}, T::DataType)
end
end

function Base.convert(::Type{Type}, T::cudaDataType)
function Base.Type(T::cudaDataType)
if T == R_16F
return Float16
elseif T == C_16F
Expand Down
2 changes: 1 addition & 1 deletion lib/cusparse/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Base.convert(::Type{cusparseIndexType_t}, T::DataType)
end
end

function Base.convert(::Type{Type}, T::cusparseIndexType_t)
function Base.Type(T::cusparseIndexType_t)
if T == CUSPARSE_INDEX_32I
return Int32
elseif T == CUSPARSE_INDEX_64I
Expand Down
4 changes: 2 additions & 2 deletions lib/cusparse/test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ end
@test convert(cuSPARSE.cusparseIndexType_t, Int32) == cuSPARSE.CUSPARSE_INDEX_32I
@test convert(cuSPARSE.cusparseIndexType_t, Int64) == cuSPARSE.CUSPARSE_INDEX_64I
@test_throws ArgumentError("CUSPARSE type equivalent for index type Int8 does not exist!") convert(cuSPARSE.cusparseIndexType_t, Int8)
@test convert(Type, cuSPARSE.CUSPARSE_INDEX_32I) == Int32
@test convert(Type, cuSPARSE.CUSPARSE_INDEX_64I) == Int64
@test Type(cuSPARSE.CUSPARSE_INDEX_32I) == Int32
@test Type(cuSPARSE.CUSPARSE_INDEX_64I) == Int64

@test convert(cuSPARSE.cusparseIndexBase_t, 0) == cuSPARSE.CUSPARSE_INDEX_BASE_ZERO
@test convert(cuSPARSE.cusparseIndexBase_t, 1) == cuSPARSE.CUSPARSE_INDEX_BASE_ONE
Expand Down
2 changes: 1 addition & 1 deletion lib/custatevec/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ function Base.convert(::Type{custatevecComputeType_t}, T::DataType)
end
end

function Base.convert(::Type{Type}, T::custatevecComputeType_t)
function Base.Type(T::custatevecComputeType_t)
if T == CUSTATEVEC_COMPUTE_32F || T == CUSTATEVEC_COMPUTE_TF32
return Float32
elseif T == CUSTATEVEC_COMPUTE_64F
Expand Down
2 changes: 1 addition & 1 deletion lib/cutensor/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ mutable struct CuTensorPlan
cutensorPlanGetAttribute(handle(), plan_ref[], CUTENSOR_PLAN_REQUIRED_WORKSPACE, actualWorkspaceSize, sizeof(actualWorkspaceSize))
workspace = CuArray{UInt8}(undef, actualWorkspaceSize[])

obj = new(context(), plan_ref[], workspace, convert(Type, required_scalar_type[]))
obj = new(context(), plan_ref[], workspace, Type(required_scalar_type[]))
finalizer(CUDACore.unsafe_free!, obj)
return obj
end
Expand Down
2 changes: 1 addition & 1 deletion lib/cutensornet/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ function Base.convert(::Type{cutensornetComputeType_t}, T::DataType)
end
end

function Base.convert(::Type{Type}, T::cutensornetComputeType_t)
function Base.Type(T::cutensornetComputeType_t)
if T == CUTENSORNET_COMPUTE_16F
return Float16
elseif T == CUTENSORNET_COMPUTE_32F
Expand Down
16 changes: 8 additions & 8 deletions lib/cutensornet/test/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
@test convert(cuTensorNet.cutensornetComputeType_t, Int32) == cuTensorNet.CUTENSORNET_COMPUTE_32I
@test convert(cuTensorNet.cutensornetComputeType_t, UInt32) == cuTensorNet.CUTENSORNET_COMPUTE_32U
@test_throws ArgumentError("cuTensorNet type equivalent for compute type ComplexF64 does not exist!") convert(cuTensorNet.cutensornetComputeType_t, ComplexF64)
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_8I) == Int8
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_8U) == UInt8
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_16F) == Float16
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32F) == Float32
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32U) == UInt32
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32I) == Int32
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_64F) == Float64
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_8I) == Int8
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_8U) == UInt8
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_16F) == Float16
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_32F) == Float32
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_32U) == UInt32
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_32I) == Int32
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_64F) == Float64
@test convert(cuTensorNet.cutensornetComputeType_t, CUDACore.BFloat16) == cuTensorNet.CUTENSORNET_COMPUTE_16BF
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_16BF) == CUDACore.BFloat16
@test Type(cuTensorNet.CUTENSORNET_COMPUTE_16BF) == CUDACore.BFloat16


modesA = ['m', 'h', 'k', 'n']
Expand Down
4 changes: 2 additions & 2 deletions test/core/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ end
(Int32, CUDACore.R_32I), (Complex{Int32}, CUDACore.C_32I), (UInt32, CUDACore.R_32U), (Complex{UInt32}, CUDACore.C_32U),
(Int64, CUDACore.R_64I), (Complex{Int64}, CUDACore.C_64I), (UInt64, CUDACore.R_64U), (Complex{UInt64}, CUDACore.C_64U))
@test convert(CUDACore.cudaDataType, j_type) == c_type
@test convert(Type, c_type) == j_type
@test Type(c_type) == j_type
end
@test_throws ArgumentError convert(CUDACore.cudaDataType, BigFloat)
@test_throws ArgumentError convert(Type, CUDACore.R_4I) # adjust once we support 4-bit Ints
@test_throws ArgumentError Type(CUDACore.R_4I) # adjust once we support 4-bit Ints
end