Skip to content

Commit 43dc26f

Browse files
committed
Clean-ups.
1 parent e2f3d4b commit 43dc26f

1 file changed

Lines changed: 34 additions & 22 deletions

File tree

lib/cusparse/broadcast.jl

Lines changed: 34 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ end
8484
end
8585
end
8686
end
87-
@inline function _capturescalars(arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
87+
@inline function _capturescalars(arg)
88+
# this definition is just an optimization (to bottom out the recursion slightly sooner)
8889
if scalararg(arg)
8990
return (), () -> (arg,) # add scalararg
9091
elseif scalarwrappedarg(arg)
@@ -287,7 +288,9 @@ end
287288
end
288289

289290
# helpers to index a sparse or dense array
290-
@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
291+
@inline function _getindex(arg::Union{CuSparseDeviceMatrixCSR{Tv},
292+
CuSparseDeviceMatrixCSC{Tv},
293+
CuSparseDeviceVector{Tv}}, I, ptr)::Tv where {Tv}
291294
if ptr == 0
292295
return zero(Tv)
293296
else
@@ -323,7 +326,9 @@ function _get_my_row(first_row)::Int32
323326
return row_ix + first_row - 1i32
324327
end
325328

326-
function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti, fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Ti, N}
329+
function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_row::Ti,
330+
fpreszeros::Bool, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
331+
args...) where {Ti, N}
327332
row = _get_my_row(first_row)
328333
row > last_row && return
329334

@@ -343,7 +348,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro
343348
end
344349

345350
# kernel to count the number of non-zeros in a row, to determine the row offsets
346-
function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}}, offsets::AbstractVector{Ti},
351+
function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}},
352+
offsets::AbstractVector{Ti},
347353
args...) where Ti
348354
# every thread processes an entire row
349355
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -368,7 +374,9 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
368374
return
369375
end
370376

371-
function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, Ti, N, F}
377+
function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv,Ti},
378+
offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
379+
args...) where {Tv, Ti, N, F}
372380
row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
373381
row_ix > output.nnz && return
374382
row_and_ptrs = @inbounds offsets[row_ix]
@@ -382,12 +390,14 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
382390
_getindex(arg, row, ptr)::Tv
383391
end
384392
output_val = f(vals...)
385-
@inbounds output.iPtr[row_ix] = row
393+
@inbounds output.iPtr[row_ix] = row
386394
@inbounds output.nzVal[row_ix] = output_val
387395
return
388396
end
389397

390-
function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing}, args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}}}
398+
function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{AbstractVector,Nothing},
399+
args...) where {Ti, T<:Union{CuSparseDeviceMatrixCSR{<:Any,Ti},
400+
CuSparseDeviceMatrixCSC{<:Any,Ti}}}
391401
# every thread processes an entire row
392402
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
393403
leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size(output, 1) : size(output, 2)
@@ -423,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
423433

424434
return
425435
end
426-
function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}}, f,
436+
function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv, Ti},
437+
CuSparseMatrixCSC{Tv, Ti}}}, f,
427438
output::CuDeviceArray, args...) where {Tv, Ti}
428439
# every thread processes an entire row
429440
leading_dim = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
@@ -449,7 +460,9 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
449460
end
450461

451462
function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
452-
output::CuDeviceArray{Tv}, offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}}, args...) where {Tv, F, N, Ti}
463+
output::CuDeviceArray{Tv},
464+
offsets::AbstractVector{Pair{Ti, NTuple{N, Ti}}},
465+
args...) where {Tv, F, N, Ti}
453466
# every thread processes an entire row
454467
row_ix = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
455468
row_ix > length(output) && return
@@ -468,7 +481,7 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
468481
return
469482
end
470483
## COV_EXCL_STOP
471-
const N_VEC_THREADS = 512
484+
472485
function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}})
473486
# find the sparse inputs
474487
bc = Broadcast.flatten(bc)
@@ -510,7 +523,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
510523

511524
# the kernels below parallelize across rows or cols, not elements, so it's unlikely
512525
# we'll launch many threads. to maximize utilization, parallelize across blocks first.
513-
rows, cols = sparse_typ <: CuSparseVector ? (length(bc), 1) : size(bc)
526+
rows, cols = get(size(bc), 1, 1), get(size(bc), 2, 1) # `size(bc, ::Int)` is missing
514527
function compute_launch_config(kernel)
515528
config = launch_configuration(kernel.fun)
516529
if sparse_typ <: CuSparseMatrixCSR
@@ -522,15 +535,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
522535
blocks = max(cld(cols, threads), config.blocks)
523536
threads = cld(cols, blocks)
524537
elseif sparse_typ <: CuSparseVector
525-
threads = N_VEC_THREADS
538+
threads = 512
526539
blocks = max(cld(rows, threads), config.blocks)
527-
threads = N_VEC_THREADS
528540
end
529541
(; threads, blocks)
530542
end
531543
# for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
532544
# but the only rows present in any sparse vector input are between 2 and 128, no need to
533-
# launch massive threads. TODO: use the difference here to set the thread count
545+
# launch a massive number of threads.
546+
# TODO: use the difference here to set the thread count
534547
overall_first_row = one(Ti)
535548
overall_last_row = Ti(rows)
536549
offsets = nothing
@@ -592,10 +605,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
592605
CuVector{Pair{Ti, NTuple{length(bc.args), Ti}}}(undef, overall_last_row - overall_first_row + 1)
593606
end
594607
let
595-
if sparse_typ <: CuSparseVector
596-
args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
608+
args = if sparse_typ <: CuSparseVector
609+
(sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc.args...)
597610
else
598-
args = (sparse_typ, offsets, bc.args...)
611+
(sparse_typ, offsets, bc.args...)
599612
end
600613
kernel = @cuda launch=false compute_offsets_kernel(args...)
601614
threads, blocks = compute_launch_config(kernel)
@@ -642,14 +655,13 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
642655
if output isa AbstractCuSparseArray
643656
args = (bc.f, output, offsets, bc.args...)
644657
kernel = @cuda launch=false sparse_to_sparse_broadcast_kernel(args...)
645-
threads, blocks = compute_launch_config(kernel)
646-
kernel(args...; threads, blocks)
647658
else
648-
args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) : (sparse_typ, bc.f, output, bc.args...)
659+
args = sparse_typ <: CuSparseVector ? (sparse_typ, bc.f, output, offsets, bc.args...) :
660+
(sparse_typ, bc.f, output, bc.args...)
649661
kernel = @cuda launch=false sparse_to_dense_broadcast_kernel(args...)
650-
threads, blocks = compute_launch_config(kernel)
651-
kernel(args...; threads, blocks)
652662
end
663+
threads, blocks = compute_launch_config(kernel)
664+
kernel(args...; threads, blocks)
653665

654666
return output
655667
end

0 commit comments

Comments
 (0)