Skip to content

Commit 22b2689

Browse files
kmp5VTkshyattlkdvos
authored
Wrapper for Blocksparse CuTensor code (#3057)
* Working on implementing the wrapper for the new blocksparse cutensor code * Revert to cutensor_jll.libcutensor as this has the blocksparse cutensor support now * Remove redudant convert function * Make blocksparse code more generic (generic case). Would it be better to make it a union type of CuTensorBS and AbstractArray? * Working on simplyfying and making accessors * Fix problem with stride * Small comment reminder * Add a contraction test for the blocksparse system (not comprehensive but the C++ code is still in flux) * Closer to clang.jl construction * Update cutensor.toml for block sparse contraction * Apply suggestion from @lkdvos Remove left over code. Will need to make something like this to define mul! in the future Co-authored-by: Lukas Devos <ldevos98@gmail.com> * Document C_NULL cutensorBSDescriptor * Remove comment * Fix issues with new CUDA organization * Add type restrictions to CuTensorBS type to make downstream easier * I believe this is the "generic" stride (i.e. all blocks are packed into a contigous memory block) * Skip blocksparse tests for failing versions. * More broken versions. Will send Mathias a message * Remove synchronize * Add destroy descriptor * Removing skipped versions for now. --------- Co-authored-by: Katharine Hyatt <khyatt@flatironinstitute.org> Co-authored-by: Lukas Devos <ldevos98@gmail.com>
1 parent 5a141fe commit 22b2689

7 files changed

Lines changed: 324 additions & 5 deletions

File tree

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
2+
3+
## LinearAlgebra
4+
5+
using LinearAlgebra
6+
7+
function LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)
8+
contract!(α,
9+
A, A.inds, CUTENSOR_OP_IDENTITY,
10+
B, B.inds, CUTENSOR_OP_IDENTITY,
11+
β,
12+
C, C.inds, CUTENSOR_OP_IDENTITY,
13+
CUTENSOR_OP_IDENTITY; jit=CUTENSOR_JIT_MODE_DEFAULT)
14+
return C
15+
end
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
function nonzero_blocks(A::CuTensorBS)
2+
return A.nonzero_data
3+
end
4+
5+
function contract!(
6+
@nospecialize(alpha::Number),
7+
@nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
8+
@nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
9+
@nospecialize(beta::Number),
10+
@nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
11+
opOut::cutensorOperator_t;
12+
jit::cutensorJitMode_t=JIT_MODE_NONE,
13+
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
14+
algo::cutensorAlgo_t=ALGO_DEFAULT,
15+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
16+
plan::Union{CuTensorPlan, Nothing}=nothing)
17+
18+
actual_plan = if plan === nothing
19+
plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
20+
jit, workspace, algo, compute_type)
21+
else
22+
plan
23+
end
24+
25+
contractBS!(actual_plan, alpha, nonzero_blocks(A), nonzero_blocks(B), beta, nonzero_blocks(C))
26+
27+
if plan === nothing
28+
CUDACore.unsafe_free!(actual_plan)
29+
end
30+
31+
return C
32+
end
33+
34+
## This function assumes A, B, and C are Arrays of pointers to CuArrays.
35+
## Please overwrite the `nonzero_blocks` function for your datatype to access this function from contract!
36+
function contractBS!(plan::CuTensorPlan,
37+
@nospecialize(alpha::Number),
38+
@nospecialize(A::AbstractArray),
39+
@nospecialize(B::AbstractArray),
40+
@nospecialize(beta::Number),
41+
@nospecialize(C::AbstractArray))
42+
scalar_type = plan.scalar_type
43+
44+
# Extract GPU pointers from each CuArray block
45+
# cuTENSOR expects a host-accessible array of GPU pointers
46+
A_ptrs = CuPtr{Cvoid}[pointer(block) for block in A]
47+
B_ptrs = CuPtr{Cvoid}[pointer(block) for block in B]
48+
C_ptrs = CuPtr{Cvoid}[pointer(block) for block in C]
49+
50+
cutensorBlockSparseContract(handle(), plan,
51+
Ref{scalar_type}(alpha), A_ptrs, B_ptrs,
52+
Ref{scalar_type}(beta), C_ptrs, C_ptrs,
53+
plan.workspace, sizeof(plan.workspace), stream())
54+
return C
55+
end
56+
57+
function plan_contraction(
58+
@nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
59+
@nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
60+
@nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
61+
opOut::cutensorOperator_t;
62+
jit::cutensorJitMode_t=JIT_MODE_NONE,
63+
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
64+
algo::cutensorAlgo_t=ALGO_DEFAULT,
65+
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)
66+
67+
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
68+
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
69+
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
70+
!is_unary(opOut) && throw(ArgumentError("opOut must be a unary op!"))
71+
72+
descA = CuTensorBSDescriptor(A)
73+
descB = CuTensorBSDescriptor(B)
74+
descC = CuTensorBSDescriptor(C)
75+
# for now, D must be identical to C (and thus, descD must be identical to descC)
76+
77+
modeA = collect(Cint, Ainds)
78+
modeB = collect(Cint, Binds)
79+
modeC = collect(Cint, Cinds)
80+
81+
actual_compute_type = if compute_type === nothing
82+
contraction_compute_types[(eltype(A), eltype(B), eltype(C))]
83+
else
84+
compute_type
85+
end
86+
87+
88+
desc = Ref{cutensorOperationDescriptor_t}()
89+
cutensorCreateBlockSparseContraction(handle(),
90+
desc,
91+
descA, modeA, opA,
92+
descB, modeB, opB,
93+
descC, modeC, opC,
94+
descC, modeC, actual_compute_type)
95+
96+
plan_pref = Ref{cutensorPlanPreference_t}()
97+
cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)
98+
99+
plan = CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
100+
cutensorDestroyOperationDescriptor(desc[])
101+
cutensorDestroyPlanPreference(plan_pref[])
102+
return plan
103+
end
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
## tensor
2+
3+
export CuTensorBS
4+
5+
## TODO add checks to see if size of data matches expected block size
6+
mutable struct CuTensorBS{T, N}
7+
nonzero_data::Vector{<:CuArray}
8+
inds::Vector{Int}
9+
blocks_per_mode::Vector{Int32}
10+
## This expects a Vector{Tuple(Int)} right now
11+
block_extents::Vector{<:Tuple}
12+
## This expects a Vector{Tuple(Int)} right now
13+
nonzero_block_coords::Vector{NTuple{N,Int32}}
14+
15+
function CuTensorBS{T, N}(nonzero_data,
16+
blocks_per_mode,
17+
block_extents,
18+
nonzero_block_coords,
19+
inds) where {T<:Number, N}
20+
CuArrayT = eltype(nonzero_data)
21+
@assert eltype(CuArrayT) == T
22+
# @assert ndims(CuArrayT) == N
23+
@assert length(block_extents) == N
24+
new(nonzero_data, inds, blocks_per_mode, block_extents, nonzero_block_coords)
25+
end
26+
end
27+
28+
function CuTensorBS(nonzero_data::Vector{<:CuArray{T}},
29+
blocks_per_mode, block_extents, nonzero_block_coords, inds) where {T<:Number}
30+
CuTensorBS{T,length(block_extents)}(nonzero_data,
31+
blocks_per_mode, block_extents, nonzero_block_coords, inds)
32+
end
33+
# array interface
34+
function Base.size(T::CuTensorBS)
35+
return tuple(sum.(T.block_extents)...)
36+
end
37+
Base.length(T::CuTensorBS) = prod(size(T))
38+
nonzero_length(T::CuTensorBS) = sum(length.(T.nonzero_data))
39+
Base.ndims(T::CuTensorBS) = Int32(length(T.inds))
40+
41+
## This tells how far away each block is from the other block in memory.
42+
Base.strides(T::CuTensorBS) = strides(T.nonzero_data)
43+
Base.eltype(T::CuTensorBS) = eltype(eltype(T.nonzero_data))
44+
45+
function block_extents(T::CuTensorBS)
46+
extents = Vector{Int64}()
47+
48+
for ex in T.block_extents
49+
extents = vcat(extents, ex...)
50+
end
51+
return extents
52+
end
53+
54+
nblocks_per_mode(T::CuTensorBS) = T.blocks_per_mode
55+
56+
num_nonzero_blocks(T::CuTensorBS) = length(T.nonzero_block_coords)
57+
58+
## This function turns the tuple of the block coordinates into a single
59+
## list of blocks
60+
function list_nonzero_block_coords(T::CuTensorBS)
61+
block_list = Vector{Int64}()
62+
for block in T.nonzero_block_coords
63+
block_list = vcat(block_list, block...)
64+
end
65+
return block_list
66+
end
67+
68+
# ## descriptor
69+
mutable struct CuTensorBSDescriptor
70+
handle::cutensorBlockSparseTensorDescriptor_t
71+
# inner constructor handles creation and finalizer of the descriptor
72+
function CuTensorBSDescriptor(
73+
numModes::Int32,
74+
numNonZeroBlocks::Int64,
75+
numSectionsPerMode::Vector{Int32},
76+
extent::Vector{Int64},
77+
nonZeroCoordinates::Vector{Int32},
78+
stride, ## Union{Vector{Int64}, C_NULL},
79+
eltype::Type)
80+
81+
desc = Ref{cuTENSOR.cutensorBlockSparseTensorDescriptor_t}()
82+
cutensorCreateBlockSparseTensorDescriptor(handle(), desc,
83+
numModes, numNonZeroBlocks, numSectionsPerMode, extent, nonZeroCoordinates,
84+
stride, eltype)
85+
86+
obj = new(desc[])
87+
finalizer(unsafe_destroy!, obj)
88+
return obj
89+
end
90+
end
91+
92+
## This function assumes that strides are C_NULL, i.e. canonical stride
93+
function CuTensorBSDescriptor(
94+
numModes::Int32,
95+
numNonZeroBlocks::Int64,
96+
numSectionsPerMode::Vector{Int32},
97+
extent::Vector{Int64},
98+
nonZeroCoordinates::Vector{Int32},
99+
# strides = C_NULL,
100+
eltype::Type)
101+
102+
return CuTensorBSDescriptor(numModes, numNonZeroBlocks, numSectionsPerMode, extent, nonZeroCoordinates, C_NULL, eltype)
103+
end
104+
105+
Base.show(io::IO, desc::CuTensorBSDescriptor) = @printf(io, "CuTensorBSDescriptor(%p)", desc.handle)
106+
107+
Base.unsafe_convert(::Type{cutensorBlockSparseTensorDescriptor_t}, obj::CuTensorBSDescriptor) = obj.handle
108+
109+
function unsafe_destroy!(obj::CuTensorBSDescriptor)
110+
cutensorDestroyBlockSparseTensorDescriptor(obj)
111+
end
112+
113+
## Descriptor function for CuTensorBS type. Please overwrite for custom objects
114+
function CuTensorBSDescriptor(A::CuTensorBS)
115+
numModes = ndims(A)
116+
numNonZeroBlocks = length(A.nonzero_block_coords)
117+
numSectionsPerMode = collect(nblocks_per_mode(A))
118+
extent = block_extents(A)
119+
nonZeroCoordinates = collect(Base.Iterators.flatten(A.nonzero_block_coords)) .- Int32(1)
120+
st = strides(A)
121+
@assert all(st .== 1)
122+
123+
dataType = eltype(A)
124+
125+
## Right now assume stride is NULL. I am not sure if stride works, need to discuss with cuTENSOR team.
126+
CuTensorBSDescriptor(numModes, numNonZeroBlocks,
127+
numSectionsPerMode, extent, nonZeroCoordinates, dataType)
128+
end

lib/cutensor/src/cuTENSOR.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using CUDACore
44
using CUDACore: CUstream, cudaDataType, @gcsafe_ccall, @checked, @enum_without_prefix
55
using CUDACore: retry_reclaim, initialize_context, isdebug
66

7+
using CUDACore.GPUToolbox
8+
79
using CEnum: @cenum
810

911
using Printf: @printf
@@ -32,8 +34,14 @@ include("utils.jl")
3234
include("types.jl")
3335
include("operations.jl")
3436

37+
38+
# Block sparse wrappers
39+
include("blocksparse/types.jl")
40+
include("blocksparse/operations.jl")
41+
3542
# high-level integrations
3643
include("interfaces.jl")
44+
include("blocksparse/interfaces.jl")
3745

3846

3947
## handles

lib/cutensor/src/libcutensor.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -545,12 +545,12 @@ end
545545
@gcsafe_ccall libcutensor.cutensorBlockSparseContract(handle::cutensorHandle_t,
546546
plan::cutensorPlan_t,
547547
alpha::Ptr{Cvoid},
548-
A::Ptr{Ptr{Cvoid}},
549-
B::Ptr{Ptr{Cvoid}},
548+
A::Ptr{CuPtr{Cvoid}},
549+
B::Ptr{CuPtr{Cvoid}},
550550
beta::Ptr{Cvoid},
551-
C::Ptr{Ptr{Cvoid}},
552-
D::Ptr{Ptr{Cvoid}},
553-
workspace::Ptr{Cvoid},
551+
C::Ptr{CuPtr{Cvoid}},
552+
D::Ptr{CuPtr{Cvoid}},
553+
workspace::CuPtr{Cvoid},
554554
workspaceSize::UInt64,
555555
stream::cudaStream_t)::cutensorStatus_t
556556
end

lib/cutensor/test/contractions.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,4 +188,62 @@ end
188188
end
189189
end
190190

191+
eltypes_compact = [
192+
(Float32, Float32, Float32, Float32),
193+
(ComplexF32, ComplexF32, ComplexF32, Float32),
194+
(Float64, Float64, Float64, Float64),
195+
(ComplexF64, ComplexF64, ComplexF64, Float64)
196+
]
197+
@testset "Blocksparse Contraction" begin
198+
## There are many unsupported types because this is a new functionality
199+
## So I will test with Float32 and ComplexF32 only
200+
@testset for (eltyA, eltyB, eltyC, eltyCompute) in eltypes_compact
201+
## i = [20,20,25]
202+
## k = [10,10,15]
203+
## l = [30,30,35]
204+
## A = Tensor(k,i,l)
205+
## Nonzero blocks are
206+
## [1,1,1], [1,1,3], [1,3,1], [1,3,3], [3,1,1], [3,1,3], [3,3,1], [3,3,3]
207+
A = Vector{CuArray{eltyA, 3}}()
208+
for k in [10,15]
209+
for i in [20,25]
210+
for l in [30,35]
211+
push!(A, CuArray(ones(eltyA, k,i,l)))
212+
end
213+
end
214+
end
215+
216+
## B = Tensor(k,l)
217+
## Nonzero blocks are
218+
## [1,1], [2,3]
219+
B = Array{CuArray{eltyB, 2}}(
220+
[CuArray(randn(eltyB, 10, 30)),
221+
CuArray(randn(eltyB, 10, 35))])
222+
223+
## C = Tensor(i)
224+
## Nonzero blocks are
225+
## [1,], [3,]
226+
C = Vector{CuArray{eltyC, 1}}(
227+
[CuArray(zeros(eltyC, 20)),
228+
CuArray(zeros(eltyC, 25))]
229+
)
230+
231+
cuTenA = cuTENSOR.CuTensorBS(A, [3,3,3],
232+
[(10,10,15), (20,20,25), (30,30,35)],
233+
[(1,1,1), (1,1,3), (1,3,1), (1,3,3), (3,1,1), (3,1,3), (3,3,1), (3,3,3)],
234+
[1,3,2])
235+
cuTenB = cuTENSOR.CuTensorBS(B, [3,3],
236+
[(10,10,15), (30,30,35)],
237+
[(1,1),(2,3)], [1,2], )
238+
cuTenC = cuTENSOR.CuTensorBS(C, [3],
239+
[(20,20,25)],[(1,),(3,)], [3])
240+
241+
mul!(cuTenC, cuTenA, cuTenB, 1, 0)
242+
## C[1] = A[1,1,1] * B[1,1]
243+
@test C[1] reshape(permutedims(A[1], (2,1,3)), (20, 10 * 30)) * reshape(B[1], (10 * 30))
244+
## C[3] = A[1,3,1] * B[1,1]
245+
@test C[2] reshape(permutedims(A[3], (2,1,3)), (25, 10 * 30)) * reshape(B[1], (10 * 30))
246+
end
247+
end
248+
191249
end

res/wrap/cutensor.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,10 @@ needs_context = false
5757
6 = "CuPtr{Cvoid}"
5858
7 = "CuPtr{Cvoid}"
5959
8 = "CuPtr{Cvoid}"
60+
61+
[api.cutensorBlockSparseContract.argtypes]
62+
4 = "Ptr{CuPtr{Cvoid}}"
63+
5 = "Ptr{CuPtr{Cvoid}}"
64+
7 = "Ptr{CuPtr{Cvoid}}"
65+
8 = "Ptr{CuPtr{Cvoid}}"
66+
9 = "Ptr{CuPtr{Cvoid}}"

0 commit comments

Comments
 (0)