8484 end
8585 end
8686end
87- @inline function _capturescalars (arg) # this definition is just an optimization (to bottom out the recursion slightly sooner)
87+ @inline function _capturescalars (arg)
88+ # this definition is just an optimization (to bottom out the recursion slightly sooner)
8889 if scalararg (arg)
8990 return (), () -> (arg,) # add scalararg
9091 elseif scalarwrappedarg (arg)
287288end
288289
289290# helpers to index a sparse or dense array
290- @inline function _getindex (arg:: Union{CuSparseDeviceMatrixCSR{Tv},CuSparseDeviceMatrixCSC{Tv},CuSparseDeviceVector{Tv}} , I, ptr):: Tv where {Tv}
291+ @inline function _getindex (arg:: Union {CuSparseDeviceMatrixCSR{Tv},
292+ CuSparseDeviceMatrixCSC{Tv},
293+ CuSparseDeviceVector{Tv}}, I, ptr):: Tv where {Tv}
291294 if ptr == 0
292295 return zero (Tv)
293296 else
@@ -323,7 +326,9 @@ function _get_my_row(first_row)::Int32
323326 return row_ix + first_row - 1 i32
324327end
325328
326- function compute_offsets_kernel (:: Type{<:CuSparseVector} , first_row:: Ti , last_row:: Ti , fpreszeros:: Bool , offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} , args... ) where {Ti, N}
329+ function compute_offsets_kernel (:: Type{<:CuSparseVector} , first_row:: Ti , last_row:: Ti ,
330+ fpreszeros:: Bool , offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} ,
331+ args... ) where {Ti, N}
327332 row = _get_my_row (first_row)
328333 row > last_row && return
329334
@@ -343,7 +348,8 @@ function compute_offsets_kernel(::Type{<:CuSparseVector}, first_row::Ti, last_ro
343348end
344349
345350# kernel to count the number of non-zeros in a row, to determine the row offsets
346- function compute_offsets_kernel (T:: Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}} , offsets:: AbstractVector{Ti} ,
351+ function compute_offsets_kernel (T:: Type{<:Union{CuSparseMatrixCSR, CuSparseMatrixCSC}} ,
352+ offsets:: AbstractVector{Ti} ,
347353 args... ) where Ti
348354 # every thread processes an entire row
349355 leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
@@ -368,7 +374,9 @@ function compute_offsets_kernel(T::Type{<:Union{CuSparseMatrixCSR, CuSparseMatri
368374 return
369375end
370376
371- function sparse_to_sparse_broadcast_kernel (f:: F , output:: CuSparseDeviceVector{Tv,Ti} , offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} , args... ) where {Tv, Ti, N, F}
377+ function sparse_to_sparse_broadcast_kernel (f:: F , output:: CuSparseDeviceVector{Tv,Ti} ,
378+ offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} ,
379+ args... ) where {Tv, Ti, N, F}
372380 row_ix = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
373381 row_ix > output. nnz && return
374382 row_and_ptrs = @inbounds offsets[row_ix]
@@ -382,12 +390,14 @@ function sparse_to_sparse_broadcast_kernel(f::F, output::CuSparseDeviceVector{Tv
382390 _getindex (arg, row, ptr):: Tv
383391 end
384392 output_val = f (vals... )
385- @inbounds output. iPtr[row_ix] = row
393+ @inbounds output. iPtr[row_ix] = row
386394 @inbounds output. nzVal[row_ix] = output_val
387395 return
388396end
389397
390- function sparse_to_sparse_broadcast_kernel (f, output:: T , offsets:: Union{AbstractVector,Nothing} , args... ) where {Ti, T<: Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}} }
398+ function sparse_to_sparse_broadcast_kernel (f, output:: T , offsets:: Union{AbstractVector,Nothing} ,
399+ args... ) where {Ti, T<: Union {CuSparseDeviceMatrixCSR{<: Any ,Ti},
400+ CuSparseDeviceMatrixCSC{<: Any ,Ti}}}
391401 # every thread processes an entire row
392402 leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
393403 leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size (output, 1 ) : size (output, 2 )
@@ -423,7 +433,8 @@ function sparse_to_sparse_broadcast_kernel(f, output::T, offsets::Union{Abstract
423433
424434 return
425435end
426- function sparse_to_dense_broadcast_kernel (T:: Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}} , f,
436+ function sparse_to_dense_broadcast_kernel (T:: Type {<: Union {CuSparseMatrixCSR{Tv, Ti},
437+ CuSparseMatrixCSC{Tv, Ti}}}, f,
427438 output:: CuDeviceArray , args... ) where {Tv, Ti}
428439 # every thread processes an entire row
429440 leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
@@ -449,7 +460,9 @@ function sparse_to_dense_broadcast_kernel(T::Type{<:Union{CuSparseMatrixCSR{Tv,
449460end
450461
451462function sparse_to_dense_broadcast_kernel (:: Type{<:CuSparseVector} , f:: F ,
452- output:: CuDeviceArray{Tv} , offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} , args... ) where {Tv, F, N, Ti}
463+ output:: CuDeviceArray{Tv} ,
464+ offsets:: AbstractVector{Pair{Ti, NTuple{N, Ti}}} ,
465+ args... ) where {Tv, F, N, Ti}
453466 # every thread processes an entire row
454467 row_ix = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
455468 row_ix > length (output) && return
@@ -468,7 +481,7 @@ function sparse_to_dense_broadcast_kernel(::Type{<:CuSparseVector}, f::F,
468481 return
469482end
470483# # COV_EXCL_STOP
471- const N_VEC_THREADS = 512
484+
472485function Broadcast. copy (bc:: Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyle}} )
473486 # find the sparse inputs
474487 bc = Broadcast. flatten (bc)
@@ -510,7 +523,7 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
510523
511524 # the kernels below parallelize across rows or cols, not elements, so it's unlikely
512525 # we'll launch many threads. to maximize utilization, parallelize across blocks first.
513- rows, cols = sparse_typ <: CuSparseVector ? ( length (bc), 1 ) : size (bc)
526+ rows, cols = get ( size (bc), 1 , 1 ), get ( size (bc), 2 , 1 ) # ` size(bc, ::Int)` is missing
514527 function compute_launch_config (kernel)
515528 config = launch_configuration (kernel. fun)
516529 if sparse_typ <: CuSparseMatrixCSR
@@ -522,15 +535,15 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
522535 blocks = max (cld (cols, threads), config. blocks)
523536 threads = cld (cols, blocks)
524537 elseif sparse_typ <: CuSparseVector
525- threads = N_VEC_THREADS
538+ threads = 512
526539 blocks = max (cld (rows, threads), config. blocks)
527- threads = N_VEC_THREADS
528540 end
529541 (; threads, blocks)
530542 end
531543 # for CuSparseVec, figure out the actual row range we need to address, e.g. if m = 2^20
532544 # but the only rows present in any sparse vector input are between 2 and 128, no need to
533- # launch massive threads. TODO : use the difference here to set the thread count
545+ # launch massive threads.
546+ # TODO : use the difference here to set the thread count
534547 overall_first_row = one (Ti)
535548 overall_last_row = Ti (rows)
536549 offsets = nothing
@@ -592,10 +605,10 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
592605 CuVector {Pair{Ti, NTuple{length(bc.args), Ti}}} (undef, overall_last_row - overall_first_row + 1 )
593606 end
594607 let
595- if sparse_typ <: CuSparseVector
596- args = (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc. args... )
608+ args = if sparse_typ <: CuSparseVector
609+ (sparse_typ, overall_first_row, overall_last_row, fpreszeros, offsets, bc. args... )
597610 else
598- args = (sparse_typ, offsets, bc. args... )
611+ (sparse_typ, offsets, bc. args... )
599612 end
600613 kernel = @cuda launch= false compute_offsets_kernel (args... )
601614 threads, blocks = compute_launch_config (kernel)
@@ -642,14 +655,13 @@ function Broadcast.copy(bc::Broadcasted{<:Union{CuSparseVecStyle,CuSparseMatStyl
642655 if output isa AbstractCuSparseArray
643656 args = (bc. f, output, offsets, bc. args... )
644657 kernel = @cuda launch= false sparse_to_sparse_broadcast_kernel (args... )
645- threads, blocks = compute_launch_config (kernel)
646- kernel (args... ; threads, blocks)
647658 else
648- args = sparse_typ <: CuSparseVector ? (sparse_typ, bc. f, output, offsets, bc. args... ) : (sparse_typ, bc. f, output, bc. args... )
659+ args = sparse_typ <: CuSparseVector ? (sparse_typ, bc. f, output, offsets, bc. args... ) :
660+ (sparse_typ, bc. f, output, bc. args... )
649661 kernel = @cuda launch= false sparse_to_dense_broadcast_kernel (args... )
650- threads, blocks = compute_launch_config (kernel)
651- kernel (args... ; threads, blocks)
652662 end
663+ threads, blocks = compute_launch_config (kernel)
664+ kernel (args... ; threads, blocks)
653665
654666 return output
655667end
0 commit comments