Skip to content

Commit d0e0676

Browse files
authored
Small fixes and missed tests for CUTENSORNET (#2713)
* Small fixes and missed tests for CUTENSORNET * More fixes and tests
1 parent 7b28f2d commit d0e0676

2 files changed

Lines changed: 50 additions & 13 deletions

File tree

lib/cutensornet/src/types.jl

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ mutable struct CuTensorDescriptor{T}
5151
T, desc_ref)
5252
obj = new{T}(desc_ref[])
5353
finalizer(cutensornetDestroyTensorDescriptor, obj)
54-
obj
54+
return obj
5555
end
5656
end
5757
CuTensorDescriptor(T::DataType, extents, strides, modes) = CuTensorDescriptor{T}(extents, strides, modes)
@@ -65,12 +65,14 @@ function Base.ndims(desc::CuTensorDescriptor)
6565
end
6666

6767
function Base.size(desc::CuTensorDescriptor)
68+
numModes = Ref{Int32}(C_NULL)
6869
extents = Vector{Int64}(undef, ndims(desc))
6970
cutensornetGetTensorDetails(handle(), desc, numModes, C_NULL, C_NULL, extents, C_NULL)
7071
return tuple(extents...)
7172
end
7273

7374
function Base.strides(desc::CuTensorDescriptor)
75+
numModes = Ref{Int32}(C_NULL)
7476
strides = Vector{Int64}(undef, ndims(desc))
7577
cutensornetGetTensorDetails(handle(), desc, numModes, C_NULL, C_NULL, C_NULL, strides)
7678
return tuple(strides...)
@@ -89,16 +91,16 @@ mutable struct CuTensorNetworkDescriptor
8991
extentsOut, stridesOut, modesOut, dataType, computeType, desc_ref)
9092
obj = new(desc_ref[])
9193
finalizer(cutensornetDestroyNetworkDescriptor, obj)
92-
obj
94+
return obj
9395
end
9496
end
9597
Base.unsafe_convert(::Type{cutensornetNetworkDescriptor_t}, desc::CuTensorNetworkDescriptor) = desc.handle
9698

9799
function compute_type(T::DataType)
98100
if T == Float16
99-
return Float32
100-
elseif T == Float32
101101
return Float16
102+
elseif T == Float32
103+
return Float32
102104
elseif T == Float64
103105
return Float64
104106
end
@@ -133,7 +135,7 @@ mutable struct CuTensorSVDInfo
133135
cutensornetCreateTensorSVDInfo(handle(), info_ref)
134136
obj = new(info_ref[])
135137
finalizer(cutensornetDestroyTensorSVDInfo, obj)
136-
obj
138+
return obj
137139
end
138140
end
139141
Base.unsafe_convert(::Type{cutensornetTensorSVDInfo_t}, info::CuTensorSVDInfo) = info.handle
@@ -162,7 +164,7 @@ mutable struct CuTensorNetworkContractionOptimizerInfo
162164
cutensornetCreateContractionOptimizerInfo(handle(), net_desc, desc_ref)
163165
obj = new(desc_ref[])
164166
finalizer(cutensornetDestroyContractionOptimizerInfo, obj)
165-
obj
167+
return obj
166168
end
167169
end
168170

@@ -175,7 +177,7 @@ mutable struct CuTensorNetworkWorkspaceDescriptor
175177
cutensornetCreateWorkspaceDescriptor(handle(), desc_ref)
176178
obj = new(desc_ref[])
177179
finalizer(cutensornetDestroyWorkspaceDescriptor, obj)
178-
obj
180+
return obj
179181
end
180182
end
181183

@@ -188,7 +190,7 @@ mutable struct CuTensorNetworkContractionPlan
188190
cutensornetCreateContractionPlan(handle(), net_desc, info, ws_desc, desc_ref)
189191
obj = new(desc_ref[])
190192
finalizer(cutensornetDestroyContractionPlan, obj)
191-
obj
193+
return obj
192194
end
193195
end
194196

@@ -250,7 +252,7 @@ mutable struct CuTensorNetworkContractionOptimizerConfig
250252
attr_buf = Ref(Base.getproperty(prefs, attr[1]))
251253
cutensornetContractionOptimizerConfigSetAttribute(handle(), desc_ref[], attr[2], attr_buf, sizeof(attr_buf))
252254
end
253-
obj
255+
return obj
254256
end
255257
end
256258

@@ -287,7 +289,7 @@ mutable struct CuTensorSVDConfig
287289
attr_buf = Ref(Base.getproperty(prefs, attr[1]))
288290
cutensornetTensorSVDConfigSetAttribute(handle(), desc_ref[], attr[2], attr_buf, sizeof(attr_buf))
289291
end
290-
obj
292+
return obj
291293
end
292294
end
293295
function abs_cutoff(conf::CuTensorSVDConfig)
@@ -323,7 +325,7 @@ mutable struct CuTensorNetworkAutotunePreference
323325
attr_buf = Ref(Base.getproperty(prefs, attr[1]))
324326
cutensornetContractionAutotunePreferenceSetAttribute(handle(), pref_ref[], attr[2], attr_buf, sizeof(attr_buf))
325327
end
326-
obj
328+
return obj
327329
end
328330
end
329331
Base.unsafe_convert(::Type{cutensornetContractionAutotunePreference_t}, prefs::CuTensorNetworkAutotunePreference) = prefs.handle
@@ -336,14 +338,14 @@ mutable struct CuTensorNetworkSliceGroup
336338
cutensornetCreateSliceGroupFromIDRange(handle(), sliceStart, sliceStop, sliceStep, group_ref)
337339
obj = new(group_ref[])
338340
finalizer(cutensornetDestroySliceGroup, obj)
339-
obj
341+
return obj
340342
end
341343
function CuTensorNetworkSliceGroup(slices::Vector{Int64})
342344
group_ref = Ref{cutensornetSliceGroup_t}()
343345
cutensornetCreateSliceGroupFromIDs(handle(), pointer(slices), pointer(slices, length(slices)), group_ref)
344346
obj = new(group_ref[])
345347
finalizer(cutensornetDestroySliceGroup, obj)
346-
obj
348+
return obj
347349
end
348350
end
349351
Base.unsafe_convert(::Type{cutensornetSliceGroup_t}, prefs::CuTensorNetworkSliceGroup) = prefs.handle

lib/cutensornet/test/runtests.jl

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,37 @@ import cuTensorNet: CuTensorNetwork, rehearse_contraction, perform_contraction!,
1515
using TensorOperations
1616

1717
@testset "cuTensorNet" begin
18+
@testset "Helpers and types" begin
19+
@test convert(cuTensorNet.cutensornetComputeType_t, Int8) == cuTensorNet.CUTENSORNET_COMPUTE_8I
20+
@test convert(cuTensorNet.cutensornetComputeType_t, UInt8) == cuTensorNet.CUTENSORNET_COMPUTE_8U
21+
@test convert(cuTensorNet.cutensornetComputeType_t, Float16) == cuTensorNet.CUTENSORNET_COMPUTE_16F
22+
@test convert(cuTensorNet.cutensornetComputeType_t, Int32) == cuTensorNet.CUTENSORNET_COMPUTE_32I
23+
@test convert(cuTensorNet.cutensornetComputeType_t, UInt32) == cuTensorNet.CUTENSORNET_COMPUTE_32U
24+
@test_throws ArgumentError("cuTensorNet type equivalent for compute type ComplexF64 does not exist!") convert(cuTensorNet.cutensornetComputeType_t, ComplexF64)
25+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_8I) == Int8
26+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_8U) == UInt8
27+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_16F) == Float16
28+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32F) == Float32
29+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32U) == UInt32
30+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_32I) == Int32
31+
@test convert(Type, cuTensorNet.CUTENSORNET_COMPUTE_64F) == Float64
32+
33+
34+
modesA = ['m', 'h', 'k', 'n']
35+
extent = Dict{Char, Int}()
36+
extent['m'] = 96;
37+
extent['n'] = 96;
38+
extent['h'] = 64;
39+
extent['k'] = 64;
40+
extentsA = [extent[mode] for mode in modesA]
41+
@testset for elty in [Float32, Float64, ComplexF32, ComplexF64]
42+
A = CUDA.rand(elty, extentsA...)
43+
descA = cuTensorNet.CuTensorDescriptor(A, modesA)
44+
@test ndims(descA) == ndims(A)
45+
@test size(descA) == size(A)
46+
@test strides(descA) == strides(A)
47+
end
48+
end
1849
n = 8
1950
m = 16
2051
k = 32
@@ -84,6 +115,10 @@ using TensorOperations
84115
@test cuTensorNet.reduced_extent(info) == n
85116
@test cuTensorNet.discarded_weight(info) 0.0
86117
@test collect(U)*diagm(collect(S))*collect(V) collect(A)
118+
config = cuTensorNet.CuTensorSVDConfig()
119+
@test cuTensorNet.abs_cutoff(config) == 0.0
120+
@test cuTensorNet.rel_cutoff(config) == 0.0
121+
@test cuTensorNet.normalization(config) == cuTensorNet.CUTENSORNET_TENSOR_SVD_NORMALIZATION_NONE
87122
end
88123
@testset "GateSplit" begin
89124
a = 16

0 commit comments

Comments
 (0)