@@ -10,11 +10,12 @@ function Base.mapreduce(f, op, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC};
1010 (ET === Union{} || ET === Any) &&
1111 error (" mapreduce cannot figure the output element type, please pass an explicit init value" )
1212
13- init = GPUArrays . neutral_element (op, ET)
13+ init = zero ( ET)
1414 else
1515 ET = typeof (init)
1616 end
1717
18+ f_preserves_zeros = ( f (zero (ET)) == zero (ET) )
1819 # we only handle reducing along one of the two dimensions,
1920 # or a complete reduction (requiring an additional pass)
2021 in (dims, [Colon (), 1 , 2 ]) || error (" only dims=:, dims=1 or dims=2 is supported" )
@@ -29,29 +30,29 @@ function Base.mapreduce(f, op, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC};
2930 if A isa CuSparseMatrixCSR
3031 output = CuArray {ET} (undef, m)
3132
32- kernel = @cuda launch= false csr_reduce_kernel (f, op, init, output, A)
33+ kernel = @cuda launch= false csr_reduce_kernel (f, op, init, f_preserves_zeros, output, A)
3334 config = launch_configuration (kernel. fun)
3435 threads = min (m, config. threads)
3536 blocks = cld (m, threads)
3637 elseif A isa CuSparseMatrixCSC
3738 output = CuArray {ET} (undef, (1 , n))
3839
39- kernel = @cuda launch= false csc_reduce_kernel (f, op, init, output, A)
40+ kernel = @cuda launch= false csc_reduce_kernel (f, op, init, f_preserves_zeros, output, A)
4041 config = launch_configuration (kernel. fun)
4142 threads = min (n, config. threads)
4243 blocks = cld (n, threads)
4344 end
44- kernel (f, op, init, output, A; threads, blocks)
45+ kernel (f, op, init, f_preserves_zeros, output, A; threads, blocks)
4546
4647 if dims == Colon ()
47- mapreduce (f , op, output; init)
48+ mapreduce (identity , op, output; init)
4849 else
4950 output
5051 end
5152end
5253
5354# # COV_EXCL_START
54- function csr_reduce_kernel (f:: F , op:: OP , neutral, output:: CuDeviceArray , args... ) where {F, OP}
55+ function csr_reduce_kernel (f:: F , op:: OP , neutral, zeros_preserved :: Bool , output:: CuDeviceArray , args... ) where {F, OP}
5556 # every thread processes an entire row
5657 row = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
5758 row > size (output, 1 ) && return
@@ -69,12 +70,23 @@ function csr_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...
6970 end
7071 val = op (val, f (vals... ))
7172 end
73+ if ! zeros_preserved
74+ f_zero_val = f (zero (neutral))
75+ next_row_ind = row+ 1 i32
76+ nzs_this_row = ntuple (Val (length (args))) do i
77+ max_n_zeros = size (args[i], 2 )
78+ arg_row_ptr = args[i]. rowPtr
79+ nz_this_row = max_n_zeros - (@inbounds (arg_row_ptr[next_row_ind]) - @inbounds (arg_row_ptr[row]))
80+ return nz_this_row * f_zero_val
81+ end
82+ val = op (val, nzs_this_row... )
83+ end
7284
7385 @inbounds output[row] = val
7486 return
7587end
7688
77- function csc_reduce_kernel (f:: F , op:: OP , neutral, output:: CuDeviceArray , args... ) where {F, OP}
89+ function csc_reduce_kernel (f:: F , op:: OP , neutral, zeros_preserved :: Bool , output:: CuDeviceArray , args... ) where {F, OP}
7890 # every thread processes an entire column
7991 col = threadIdx (). x + (blockIdx (). x - 1 i32) * blockDim (). x
8092 col > size (output, 2 ) && return
@@ -92,6 +104,17 @@ function csc_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...
92104 end
93105 val = op (val, f (vals... ))
94106 end
107+ if ! zeros_preserved
108+ f_zero_val = f (zero (neutral))
109+ next_col_ind = col+ 1 i32
110+ nzs_this_col = ntuple (Val (length (args))) do i
111+ max_n_zeros = size (args[i], 1 )
112+ arg_col_ptr = args[i]. colPtr
113+ nz_this_col = max_n_zeros - (@inbounds (arg_col_ptr[next_col_ind]) - @inbounds (arg_col_ptr[col]))
114+ return nz_this_col * f_zero_val
115+ end
116+ val = op (val, nzs_this_col... )
117+ end
95118
96119 @inbounds output[col] = val
97120 return
0 commit comments