Skip to content

Commit 08b2fee

Browse files
Merge branch 'master' into csr-dispatch
2 parents 8111af4 + 591f61d commit 08b2fee

12 files changed

Lines changed: 111 additions & 124 deletions

File tree

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "CUDA"
22
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
3-
version = "5.7.1"
3+
version = "5.7.2"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -64,7 +64,7 @@ EnzymeCore = "0.8.2"
6464
ExprTools = "0.1"
6565
GPUArrays = "11.2.1"
6666
GPUCompiler = "0.24, 0.25, 0.26, 0.27, 1"
67-
GPUToolbox = "0.1, 0.2"
67+
GPUToolbox = "0.2"
6868
KernelAbstractions = "0.9.2"
6969
LLVM = "9.1"
7070
LLVMLoopInfo = "1"

lib/cusparse/broadcast.jl

Lines changed: 40 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -300,153 +300,87 @@ _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
300300

301301
## sparse broadcast implementation
302302

303-
# TODO: unify CSC/CSR kernels
303+
iter_type(::Type{<:CuSparseMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
304+
iter_type(::Type{<:CuSparseMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
305+
iter_type(::Type{<:CuSparseDeviceMatrixCSC}, ::Type{Ti}) where {Ti} = CSCIterator{Ti}
306+
iter_type(::Type{<:CuSparseDeviceMatrixCSR}, ::Type{Ti}) where {Ti} = CSRIterator{Ti}
307+
304308
# kernel to count the number of non-zeros in a row, to determine the row offsets
305-
function compute_offsets_kernel(::Type{<:CuSparseMatrixCSR}, offsets::AbstractVector{Ti},
309+
function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
306310
args...) where Ti
307311
# every thread processes an entire row
308-
row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
309-
row > length(offsets)-1 && return
310-
iter = @inbounds CSRIterator{Ti}(row, args...)
311-
312-
# count the nonzero columns of all inputs
313-
accum = zero(Ti)
314-
for (col, vals) in iter
315-
accum += one(Ti)
316-
end
317-
318-
# the way we write the nnz counts is a bit strange, but done so that the result
319-
# after accumulation can be directly used as the rowPtr array of a CSR matrix.
320-
@inbounds begin
321-
if row == 1
322-
offsets[1] = 1
323-
end
324-
offsets[row+1] = accum
325-
end
312+
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
313+
leading_dim > length(offsets)-1 && return
314+
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
326315

327-
return
328-
end
329-
function compute_offsets_kernel(::Type{<:CuSparseMatrixCSC}, offsets::AbstractVector{Ti},
330-
args...) where Ti
331-
# every thread processes an entire columm
332-
col = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
333-
col > length(offsets)-1 && return
334-
iter = @inbounds CSCIterator{Ti}(col, args...)
335-
336-
# count the nonzero columns of all inputs
316+
# count the nonzero leading_dims of all inputs
337317
accum = zero(Ti)
338-
for (col, vals) in iter
318+
for (leading_dim, vals) in iter
339319
accum += one(Ti)
340320
end
341321

342322
# the way we write the nnz counts is a bit strange, but done so that the result
343-
# after accumulation can be directly used as the colPtr array of a CSC matrix.
323+
# after accumulation can be directly used as the rowPtr/colPtr array of a CSR/CSC matrix.
344324
@inbounds begin
345-
if col == 1
325+
if leading_dim == 1
346326
offsets[1] = 1
347327
end
348-
offsets[col+1] = accum
328+
offsets[leading_dim+1] = accum
349329
end
350330

351331
return
352332
end
353333

354334
# broadcast kernels that iterate the elements of sparse arrays
355-
function sparse_to_sparse_broadcast_kernel(f, output::CuSparseDeviceMatrixCSR{<:Any,Ti},
356-
offsets::Union{AbstractVector,Nothing},
357-
args...) where {Ti}
335+
function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
358336
# every thread processes an entire row
359-
row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
360-
row > size(output, 1) && return
361-
iter = @inbounds CSRIterator{Ti}(row, args...)
337+
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
338+
leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
339+
leading_dim > leading_dim_size && return
340+
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
341+
362342

343+
output_ptrs = output isa CuSparseDeviceMatrixCSR ? output.rowPtr : output.colPtr
344+
output_ivals = output isa CuSparseDeviceMatrixCSR ? output.colVal : output.rowVal
363345
# fetch the row offset, and write it to the output
364346
@inbounds begin
365-
output_ptr = output.rowPtr[row] = offsets[row]
366-
if row == size(output, 1)
367-
output.rowPtr[row+1i32] = offsets[row+1i32]
347+
output_ptr = output_ptrs[leading_dim] = offsets[leading_dim]
348+
if leading_dim == leading_dim_size
349+
output_ptrs[leading_dim+1i32] = offsets[leading_dim+1i32]
368350
end
369351
end
370352

371353
# set the values for this row
372-
for (col, ptrs) in iter
373-
I = CartesianIndex(row, col)
354+
for (sub_leading_dim, ptrs) in iter
355+
index_first = output isa CuSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
356+
index_second = output isa CuSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
357+
I = CartesianIndex(index_first, index_second)
374358
vals = ntuple(Val(length(args))) do i
375359
arg = @inbounds args[i]
376360
ptr = @inbounds ptrs[i]
377361
_getindex(arg, I, ptr)
378362
end
379363

380-
@inbounds output.colVal[output_ptr] = col
364+
@inbounds output_ivals[output_ptr] = sub_leading_dim
381365
@inbounds output.nzVal[output_ptr] = f(vals...)
382366
output_ptr += one(Ti)
383367
end
384368

385369
return
386370
end
387-
function sparse_to_sparse_broadcast_kernel(f, output::CuSparseDeviceMatrixCSC{<:Any,Ti},
388-
offsets::Union{AbstractVector,Nothing},
389-
args...) where {Ti}
390-
# every thread processes an entire column
391-
col = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
392-
col > size(output, 2) && return
393-
iter = @inbounds CSCIterator{Ti}(col, args...)
394-
395-
# fetch the column offset, and write it to the output
396-
@inbounds begin
397-
output_ptr = output.colPtr[col] = offsets[col]
398-
if col == size(output, 2)
399-
output.colPtr[col+1i32] = offsets[col+1i32]
400-
end
401-
end
402-
403-
# set the values for this col
404-
for (row, ptrs) in iter
405-
I = CartesianIndex(col, row)
406-
vals = ntuple(Val(length(args))) do i
407-
arg = @inbounds args[i]
408-
ptr = @inbounds ptrs[i]
409-
_getindex(arg, I, ptr)
410-
end
411-
412-
@inbounds output.rowVal[output_ptr] = row
413-
@inbounds output.nzVal[output_ptr] = f(vals...)
414-
output_ptr += one(Ti)
415-
end
416-
417-
return
418-
end
419-
function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseMatrixCSR}, f,
420-
output::CuDeviceArray, args...)
371+
function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
372+
output::CuDeviceArray, args...) where {Tv, Ti}
421373
# every thread processes an entire row
422-
row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
423-
row > size(output, 1) && return
424-
iter = @inbounds CSRIterator{Int}(row, args...)
374+
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
375+
leading_dim_size = T <: CuSparseMatrixCSR ? size(output, 1) : size(output, 2)
376+
leading_dim > leading_dim_size && return
377+
iter = @inbounds iter_type(T, Ti)(leading_dim, args...)
425378

426379
# set the values for this row
427-
for (col, ptrs) in iter
428-
I = CartesianIndex(row, col)
429-
vals = ntuple(Val(length(args))) do i
430-
arg = @inbounds args[i]
431-
ptr = @inbounds ptrs[i]
432-
_getindex(arg, I, ptr)
433-
end
434-
435-
@inbounds output[I] = f(vals...)
436-
end
437-
438-
return
439-
end
440-
function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseMatrixCSC}, f,
441-
output::CuDeviceArray, args...)
442-
# every thread processes an entire column
443-
col = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
444-
col > size(output, 2) && return
445-
iter = @inbounds CSCIterator{Int}(col, args...)
446-
447-
# set the values for this col
448-
for (row, ptrs) in iter
449-
I = CartesianIndex(row, col)
380+
for (sub_leading_dim, ptrs) in iter
381+
index_first = T <: CuSparseMatrixCSR ? leading_dim : sub_leading_dim
382+
index_second = T <: CuSparseMatrixCSR ? sub_leading_dim : leading_dim
383+
I = CartesianIndex(index_first, index_second)
450384
vals = ntuple(Val(length(args))) do i
451385
arg = @inbounds args[i]
452386
ptr = @inbounds ptrs[i]

lib/cusparse/generic.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ function densetosparse(A::CuMatrix{T}, fmt::Symbol, index::SparseChar, algo::cus
5353
colPtr = CuVector{Cint}(undef, n+1)
5454
desc_sparse = CuSparseMatrixDescriptor(CuSparseMatrixCSC, colPtr, T, Cint, m, n, index)
5555
else
56-
error("Format :$fmt not available, use :csc, :csr or :coo.")
56+
throw(ArgumentError("Format :$fmt not available, use :csc, :csr or :coo."))
5757
end
5858
desc_dense = CuDenseMatrixDescriptor(A)
5959

@@ -82,8 +82,6 @@ function densetosparse(A::CuMatrix{T}, fmt::Symbol, index::SparseChar, algo::cus
8282
nzVal = CuVector{T}(undef, nnzB[])
8383
B = CuSparseMatrixCSC{T, Cint}(colPtr, rowVal, nzVal, (m,n))
8484
cusparseCscSetPointers(desc_sparse, B.colPtr, B.rowVal, B.nzVal)
85-
else
86-
error("Format :$fmt not available, use :csc, :csr or :coo.")
8785
end
8886
cusparseDenseToSparse_convert(handle(), desc_dense, desc_sparse, algo, buffer)
8987
end

lib/cutensor/src/interfaces.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function Base.:(+)(A::CuTensor, B::CuTensor)
99
elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY,
1010
γ, B.data, B.inds, CUTENSOR_OP_IDENTITY,
1111
C.data, C.inds, CUTENSOR_OP_ADD)
12-
C
12+
return C
1313
end
1414

1515
function Base.:(-)(A::CuTensor, B::CuTensor)
@@ -19,7 +19,7 @@ function Base.:(-)(A::CuTensor, B::CuTensor)
1919
elementwise_binary_execute!(α, A.data, A.inds, CUTENSOR_OP_IDENTITY,
2020
γ, B.data, B.inds, CUTENSOR_OP_IDENTITY,
2121
C.data, C.inds, CUTENSOR_OP_ADD)
22-
C
22+
return C
2323
end
2424

2525
function Base.:(*)(A::CuTensor, B::CuTensor)

src/CUDA.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using GPUCompiler
44

55
using GPUArrays
66

7-
using GPUToolbox: SimpleVersion, @sv_str
7+
using GPUToolbox
88

99
using LLVM
1010
using LLVM.Interop

src/array.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,35 @@ function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArr
472472
end
473473

474474

475-
## memory copying
475+
## synchronization
476476

477477
synchronize(x::CuArray) = synchronize(x.data[])
478478

479+
"""
480+
enable_synchronization!(arr::CuArray, enable::Bool)
481+
482+
By default `CuArray`s are implicitly synchronized when they are accessed on different CUDA
483+
devices or streams. This may be unwanted when e.g. using disjoint slices of memory across
484+
different tasks. This function allows to enable or disable this behavior.
485+
486+
!!! warning
487+
488+
Disabling implicit synchronization affects _all_ `CuArray`s that are referring to the
489+
same underlying memory. Unsafe use of this API _will_ result in data corruption.
490+
491+
This API is only provided as an escape hatch, and should not be used without careful
492+
consideration. If automatic synchronization is generally problematic for your use case,
493+
it is recommended to figure out a better model instead and file an issue or pull request.
494+
For more details see [this discussion](https://github.com/JuliaGPU/CUDA.jl/issues/2617).
495+
"""
496+
function enable_synchronization!(arr::CuArray, enable::Bool=true)
497+
arr.data[].synchronizing = enable
498+
return arr
499+
end
500+
501+
502+
## memory copying
503+
479504
if VERSION >= v"1.11.0-DEV.753"
480505
function typetagdata(a::Array, i=1)
481506
ptr_or_offset = Int(a.ref.ptr_or_offset)

src/device/utils.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,5 @@
11
# helpers for writing device functionality
22

3-
# helper type for writing Int32 literals
4-
# TODO: upstream this
5-
struct Literal{T} end
6-
Base.:(*)(x::Number, ::Type{Literal{T}}) where {T} = T(x)
7-
const i32 = Literal{Int32}
8-
93
# local method table for device functions
104
@static if isdefined(Base.Experimental, Symbol("@overlay"))
115
Base.Experimental.@MethodTable(method_table)

src/memory.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,16 +503,20 @@ mutable struct Managed{M}
503503
# which stream is currently using the memory.
504504
stream::CuStream
505505

506+
# whether accessing this memory can cause implicit synchronization
507+
synchronizing::Bool
508+
506509
# whether there are outstanding operations that haven't been synchronized
507510
dirty::Bool
508511

509512
# whether the memory has been captured in a way that would make the dirty bit unreliable
510513
captured::Bool
511514

512-
function Managed(mem::AbstractMemory; stream=CUDA.stream(), dirty=true, captured=false)
515+
function Managed(mem::AbstractMemory; stream = CUDA.stream(), synchronizing = true,
516+
dirty = true, captured = false)
513517
# NOTE: memory starts as dirty, because stream-ordered allocations are only
514518
# guaranteed to be physically allocated at a synchronization event.
515-
new{typeof(mem)}(mem, stream, dirty, captured)
519+
new{typeof(mem)}(mem, stream, synchronizing, dirty, captured)
516520
end
517521
end
518522

@@ -524,7 +528,7 @@ function synchronize(managed::Managed)
524528
managed.dirty = false
525529
end
526530
function maybe_synchronize(managed::Managed)
527-
if managed.dirty || managed.captured
531+
if managed.synchronizing && (managed.dirty || managed.captured)
528532
synchronize(managed)
529533
end
530534
end

test/base/array.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ using ChainRulesCore: add!!, is_inplaceable_destination
5151
end
5252
end
5353

54+
@testset "synchronization" begin
55+
a = CUDA.zeros(2, 2)
56+
synchronize(a)
57+
CUDA.enable_synchronization!(a, false)
58+
CUDA.enable_synchronization!(a)
59+
end
60+
5461
@testset "unsafe_wrap" begin
5562
# managed memory -> CuArray
5663
for a in [cu([1]; device=true), cu([1]; unified=true)]

test/core/cudadrv.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ exclusive = attribute(dev, CUDA.DEVICE_ATTRIBUTE_COMPUTE_MODE) == CUDA.CU_COMPUT
1515

1616
synchronize(ctx)
1717

18+
@test startswith(sprint(show, MIME"text/plain"(), ctx), "CuContext")
19+
@test CUDA.api_version(ctx) isa Cuint
20+
1821
if !exclusive
1922
let ctx2 = CuContext(dev)
2023
@test ctx2 == current_context() # ctor implicitly pushes

0 commit comments

Comments
 (0)