@@ -51,7 +51,7 @@ mutable struct CuTensorDescriptor{T}
5151 T, desc_ref)
5252 obj = new {T} (desc_ref[])
5353 finalizer (cutensornetDestroyTensorDescriptor, obj)
54- obj
54+ return obj
5555 end
5656end
5757CuTensorDescriptor (T:: DataType , extents, strides, modes) = CuTensorDescriptor {T} (extents, strides, modes)
@@ -65,12 +65,14 @@ function Base.ndims(desc::CuTensorDescriptor)
6565end
6666
6767function Base. size (desc:: CuTensorDescriptor )
68+ numModes = Ref {Int32} (C_NULL )
6869 extents = Vector {Int64} (undef, ndims (desc))
6970 cutensornetGetTensorDetails (handle (), desc, numModes, C_NULL , C_NULL , extents, C_NULL )
7071 return tuple (extents... )
7172end
7273
7374function Base. strides (desc:: CuTensorDescriptor )
75+ numModes = Ref {Int32} (C_NULL )
7476 strides = Vector {Int64} (undef, ndims (desc))
7577 cutensornetGetTensorDetails (handle (), desc, numModes, C_NULL , C_NULL , C_NULL , strides)
7678 return tuple (strides... )
@@ -89,16 +91,16 @@ mutable struct CuTensorNetworkDescriptor
8991 extentsOut, stridesOut, modesOut, dataType, computeType, desc_ref)
9092 obj = new (desc_ref[])
9193 finalizer (cutensornetDestroyNetworkDescriptor, obj)
92- obj
94+ return obj
9395 end
9496end
9597Base. unsafe_convert (:: Type{cutensornetNetworkDescriptor_t} , desc:: CuTensorNetworkDescriptor ) = desc. handle
9698
9799function compute_type (T:: DataType )
98100 if T == Float16
99- return Float32
100- elseif T == Float32
101101 return Float16
102+ elseif T == Float32
103+ return Float32
102104 elseif T == Float64
103105 return Float64
104106 end
@@ -133,7 +135,7 @@ mutable struct CuTensorSVDInfo
133135 cutensornetCreateTensorSVDInfo (handle (), info_ref)
134136 obj = new (info_ref[])
135137 finalizer (cutensornetDestroyTensorSVDInfo, obj)
136- obj
138+ return obj
137139 end
138140end
139141Base. unsafe_convert (:: Type{cutensornetTensorSVDInfo_t} , info:: CuTensorSVDInfo ) = info. handle
@@ -162,7 +164,7 @@ mutable struct CuTensorNetworkContractionOptimizerInfo
162164 cutensornetCreateContractionOptimizerInfo (handle (), net_desc, desc_ref)
163165 obj = new (desc_ref[])
164166 finalizer (cutensornetDestroyContractionOptimizerInfo, obj)
165- obj
167+ return obj
166168 end
167169end
168170
@@ -175,7 +177,7 @@ mutable struct CuTensorNetworkWorkspaceDescriptor
175177 cutensornetCreateWorkspaceDescriptor (handle (), desc_ref)
176178 obj = new (desc_ref[])
177179 finalizer (cutensornetDestroyWorkspaceDescriptor, obj)
178- obj
180+ return obj
179181 end
180182end
181183
@@ -188,7 +190,7 @@ mutable struct CuTensorNetworkContractionPlan
188190 cutensornetCreateContractionPlan (handle (), net_desc, info, ws_desc, desc_ref)
189191 obj = new (desc_ref[])
190192 finalizer (cutensornetDestroyContractionPlan, obj)
191- obj
193+ return obj
192194 end
193195end
194196
@@ -250,7 +252,7 @@ mutable struct CuTensorNetworkContractionOptimizerConfig
250252 attr_buf = Ref (Base. getproperty (prefs, attr[1 ]))
251253 cutensornetContractionOptimizerConfigSetAttribute (handle (), desc_ref[], attr[2 ], attr_buf, sizeof (attr_buf))
252254 end
253- obj
255+ return obj
254256 end
255257end
256258
@@ -287,7 +289,7 @@ mutable struct CuTensorSVDConfig
287289 attr_buf = Ref (Base. getproperty (prefs, attr[1 ]))
288290 cutensornetTensorSVDConfigSetAttribute (handle (), desc_ref[], attr[2 ], attr_buf, sizeof (attr_buf))
289291 end
290- obj
292+ return obj
291293 end
292294end
293295function abs_cutoff (conf:: CuTensorSVDConfig )
@@ -323,7 +325,7 @@ mutable struct CuTensorNetworkAutotunePreference
323325 attr_buf = Ref (Base. getproperty (prefs, attr[1 ]))
324326 cutensornetContractionAutotunePreferenceSetAttribute (handle (), pref_ref[], attr[2 ], attr_buf, sizeof (attr_buf))
325327 end
326- obj
328+ return obj
327329 end
328330end
329331Base. unsafe_convert (:: Type{cutensornetContractionAutotunePreference_t} , prefs:: CuTensorNetworkAutotunePreference ) = prefs. handle
@@ -336,14 +338,14 @@ mutable struct CuTensorNetworkSliceGroup
336338 cutensornetCreateSliceGroupFromIDRange (handle (), sliceStart, sliceStop, sliceStep, group_ref)
337339 obj = new (group_ref[])
338340 finalizer (cutensornetDestroySliceGroup, obj)
339- obj
341+ return obj
340342 end
341343 function CuTensorNetworkSliceGroup (slices:: Vector{Int64} )
342344 group_ref = Ref {cutensornetSliceGroup_t} ()
343345 cutensornetCreateSliceGroupFromIDs (handle (), pointer (slices), pointer (slices, length (slices)), group_ref)
344346 obj = new (group_ref[])
345347 finalizer (cutensornetDestroySliceGroup, obj)
346- obj
348+ return obj
347349 end
348350end
349351Base. unsafe_convert (:: Type{cutensornetSliceGroup_t} , prefs:: CuTensorNetworkSliceGroup ) = prefs. handle
0 commit comments