@@ -300,153 +300,87 @@ _getindex(arg, I, ptr) = Broadcast._broadcast_getindex(arg, I)
300300
301301# # sparse broadcast implementation
302302
303- # TODO : unify CSC/CSR kernels
303+ iter_type (:: Type{<:CuSparseMatrixCSC} , :: Type{Ti} ) where {Ti} = CSCIterator{Ti}
304+ iter_type (:: Type{<:CuSparseMatrixCSR} , :: Type{Ti} ) where {Ti} = CSRIterator{Ti}
305+ iter_type (:: Type{<:CuSparseDeviceMatrixCSC} , :: Type{Ti} ) where {Ti} = CSCIterator{Ti}
306+ iter_type (:: Type{<:CuSparseDeviceMatrixCSR} , :: Type{Ti} ) where {Ti} = CSRIterator{Ti}
307+
304308# kernel to count the number of non-zeros in a row (CSR) or column (CSC), to determine the row/column offsets
305- function compute_offsets_kernel (:: Type{<:CuSparseMatrixCSR} , offsets:: AbstractVector{Ti} ,
309+ function compute_offsets_kernel (T :: Type{<:Union{ CuSparseMatrixCSR, CuSparseMatrixCSC} } , offsets:: AbstractVector{Ti} ,
306310 args... ) where Ti
307311 # every thread processes an entire row (CSR) or column (CSC)
308- row = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
309- row > length (offsets)- 1 && return
310- iter = @inbounds CSRIterator {Ti} (row, args... )
311-
312- # count the nonzero columns of all inputs
313- accum = zero (Ti)
314- for (col, vals) in iter
315- accum += one (Ti)
316- end
317-
318- # the way we write the nnz counts is a bit strange, but done so that the result
319- # after accumulation can be directly used as the rowPtr array of a CSR matrix.
320- @inbounds begin
321- if row == 1
322- offsets[1 ] = 1
323- end
324- offsets[row+ 1 ] = accum
325- end
312+ leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
313+ leading_dim > length (offsets)- 1 && return
314+ iter = @inbounds iter_type (T, Ti)(leading_dim, args... )
326315
327- return
328- end
329- function compute_offsets_kernel (:: Type{<:CuSparseMatrixCSC} , offsets:: AbstractVector{Ti} ,
330- args... ) where Ti
331- # every thread processes an entire column
332- col = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
333- col > length (offsets)- 1 && return
334- iter = @inbounds CSCIterator {Ti} (col, args... )
335-
336- # count the nonzero columns of all inputs
316+ # count the nonzero columns (CSR) or rows (CSC) of all inputs
337317 accum = zero (Ti)
338- for (col , vals) in iter
318+ for (leading_dim , vals) in iter
339319 accum += one (Ti)
340320 end
341321
342322 # the way we write the nnz counts is a bit strange, but done so that the result
343- # after accumulation can be directly used as the colPtr array of a CSC matrix.
323+ # after accumulation can be directly used as the rowPtr/colPtr array of a CSR/CSC matrix.
344324 @inbounds begin
345- if col == 1
325+ if leading_dim == 1
346326 offsets[1 ] = 1
347327 end
348- offsets[col + 1 ] = accum
328+ offsets[leading_dim + 1 ] = accum
349329 end
350330
351331 return
352332end
353333
354334# broadcast kernels that iterate the elements of sparse arrays
355- function sparse_to_sparse_broadcast_kernel (f, output:: CuSparseDeviceMatrixCSR{<:Any,Ti} ,
356- offsets:: Union{AbstractVector,Nothing} ,
357- args... ) where {Ti}
335+ function sparse_to_sparse_broadcast_kernel (f, output:: T , offsets:: Union{AbstractVector,Nothing} , args... ) where {Ti, T<: Union{CuSparseDeviceMatrixCSR{<:Any,Ti},CuSparseDeviceMatrixCSC{<:Any,Ti}} }
358336 # every thread processes an entire row (CSR) or column (CSC)
359- row = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
360- row > size (output, 1 ) && return
361- iter = @inbounds CSRIterator {Ti} (row, args... )
337+ leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
338+ leading_dim_size = output isa CuSparseDeviceMatrixCSR ? size (output, 1 ) : size (output, 2 )
339+ leading_dim > leading_dim_size && return
340+ iter = @inbounds iter_type (T, Ti)(leading_dim, args... )
341+
362342
343+ output_ptrs = output isa CuSparseDeviceMatrixCSR ? output. rowPtr : output. colPtr
344+ output_ivals = output isa CuSparseDeviceMatrixCSR ? output. colVal : output. rowVal
363345 # fetch the row/column offset, and write it to the output
364346 @inbounds begin
365- output_ptr = output . rowPtr[row ] = offsets[row ]
366- if row == size (output, 1 )
367- output . rowPtr[row + 1 i32] = offsets[row + 1 i32]
347+ output_ptr = output_ptrs[leading_dim ] = offsets[leading_dim ]
348+ if leading_dim == leading_dim_size
349+ output_ptrs[leading_dim + 1 i32] = offsets[leading_dim + 1 i32]
368350 end
369351 end
370352
371353 # set the values for this row (CSR) or column (CSC)
372- for (col, ptrs) in iter
373- I = CartesianIndex (row, col)
354+ for (sub_leading_dim, ptrs) in iter
355+ index_first = output isa CuSparseDeviceMatrixCSR ? leading_dim : sub_leading_dim
356+ index_second = output isa CuSparseDeviceMatrixCSR ? sub_leading_dim : leading_dim
357+ I = CartesianIndex (index_first, index_second)
374358 vals = ntuple (Val (length (args))) do i
375359 arg = @inbounds args[i]
376360 ptr = @inbounds ptrs[i]
377361 _getindex (arg, I, ptr)
378362 end
379363
380- @inbounds output . colVal [output_ptr] = col
364+ @inbounds output_ivals [output_ptr] = sub_leading_dim
381365 @inbounds output. nzVal[output_ptr] = f (vals... )
382366 output_ptr += one (Ti)
383367 end
384368
385369 return
386370end
387- function sparse_to_sparse_broadcast_kernel (f, output:: CuSparseDeviceMatrixCSC{<:Any,Ti} ,
388- offsets:: Union{AbstractVector,Nothing} ,
389- args... ) where {Ti}
390- # every thread processes an entire column
391- col = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
392- col > size (output, 2 ) && return
393- iter = @inbounds CSCIterator {Ti} (col, args... )
394-
395- # fetch the column offset, and write it to the output
396- @inbounds begin
397- output_ptr = output. colPtr[col] = offsets[col]
398- if col == size (output, 2 )
399- output. colPtr[col+ 1 i32] = offsets[col+ 1 i32]
400- end
401- end
402-
403- # set the values for this col
404- for (row, ptrs) in iter
405- I = CartesianIndex (col, row)
406- vals = ntuple (Val (length (args))) do i
407- arg = @inbounds args[i]
408- ptr = @inbounds ptrs[i]
409- _getindex (arg, I, ptr)
410- end
411-
412- @inbounds output. rowVal[output_ptr] = row
413- @inbounds output. nzVal[output_ptr] = f (vals... )
414- output_ptr += one (Ti)
415- end
416-
417- return
418- end
419- function sparse_to_dense_broadcast_kernel (:: Type{<:CuSparseMatrixCSR} , f,
420- output:: CuDeviceArray , args... )
371+ function sparse_to_dense_broadcast_kernel (T:: Type{<:Union{CuSparseMatrixCSR{Tv, Ti}, CuSparseMatrixCSC{Tv, Ti}}} , f,
372+ output:: CuDeviceArray , args... ) where {Tv, Ti}
421373 # every thread processes an entire row
422- row = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
423- row > size (output, 1 ) && return
424- iter = @inbounds CSRIterator {Int} (row, args... )
374+ leading_dim = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
375+ leading_dim_size = T <: CuSparseMatrixCSR ? size (output, 1 ) : size (output, 2 )
376+ leading_dim > leading_dim_size && return
377+ iter = @inbounds iter_type (T, Ti)(leading_dim, args... )
425378
426379 # set the values for this row
427- for (col, ptrs) in iter
428- I = CartesianIndex (row, col)
429- vals = ntuple (Val (length (args))) do i
430- arg = @inbounds args[i]
431- ptr = @inbounds ptrs[i]
432- _getindex (arg, I, ptr)
433- end
434-
435- @inbounds output[I] = f (vals... )
436- end
437-
438- return
439- end
440- function sparse_to_dense_broadcast_kernel (:: Type{<:CuSparseMatrixCSC} , f,
441- output:: CuDeviceArray , args... )
442- # every thread processes an entire column
443- col = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
444- col > size (output, 2 ) && return
445- iter = @inbounds CSCIterator {Int} (col, args... )
446-
447- # set the values for this col
448- for (row, ptrs) in iter
449- I = CartesianIndex (row, col)
380+ for (sub_leading_dim, ptrs) in iter
381+ index_first = T <: CuSparseMatrixCSR ? leading_dim : sub_leading_dim
382+ index_second = T <: CuSparseMatrixCSR ? sub_leading_dim : leading_dim
383+ I = CartesianIndex (index_first, index_second)
450384 vals = ntuple (Val (length (args))) do i
451385 arg = @inbounds args[i]
452386 ptr = @inbounds ptrs[i]
0 commit comments