Skip to content

Commit f7465ff

Browse files
kshyattamontoison
authored andcommitted
Set neutral element to zero for sparse reduce
1 parent 6180d2c commit f7465ff

2 files changed

Lines changed: 51 additions & 7 deletions

File tree

lib/cusparse/reduce.jl

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@ function Base.mapreduce(f, op, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC};
1010
(ET === Union{} || ET === Any) &&
1111
error("mapreduce cannot figure the output element type, please pass an explicit init value")
1212

13-
init = GPUArrays.neutral_element(op, ET)
13+
init = zero(ET)
1414
else
1515
ET = typeof(init)
1616
end
1717

18+
f_preserves_zeros = ( f(zero(ET)) == zero(ET) )
1819
# we only handle reducing along one of the two dimensions,
1920
# or a complete reduction (requiring an additional pass)
2021
in(dims, [Colon(), 1, 2]) || error("only dims=:, dims=1 or dims=2 is supported")
@@ -29,29 +30,29 @@ function Base.mapreduce(f, op, A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC};
2930
if A isa CuSparseMatrixCSR
3031
output = CuArray{ET}(undef, m)
3132

32-
kernel = @cuda launch=false csr_reduce_kernel(f, op, init, output, A)
33+
kernel = @cuda launch=false csr_reduce_kernel(f, op, init, f_preserves_zeros, output, A)
3334
config = launch_configuration(kernel.fun)
3435
threads = min(m, config.threads)
3536
blocks = cld(m, threads)
3637
elseif A isa CuSparseMatrixCSC
3738
output = CuArray{ET}(undef, (1, n))
3839

39-
kernel = @cuda launch=false csc_reduce_kernel(f, op, init, output, A)
40+
kernel = @cuda launch=false csc_reduce_kernel(f, op, init, f_preserves_zeros, output, A)
4041
config = launch_configuration(kernel.fun)
4142
threads = min(n, config.threads)
4243
blocks = cld(n, threads)
4344
end
44-
kernel(f, op, init, output, A; threads, blocks)
45+
kernel(f, op, init, f_preserves_zeros, output, A; threads, blocks)
4546

4647
if dims == Colon()
47-
mapreduce(f, op, output; init)
48+
mapreduce(identity, op, output; init)
4849
else
4950
output
5051
end
5152
end
5253

5354
## COV_EXCL_START
54-
function csr_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...) where {F, OP}
55+
function csr_reduce_kernel(f::F, op::OP, neutral, zeros_preserved::Bool, output::CuDeviceArray, args...) where {F, OP}
5556
# every thread processes an entire row
5657
row = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
5758
row > size(output, 1) && return
@@ -69,12 +70,23 @@ function csr_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...
6970
end
7071
val = op(val, f(vals...))
7172
end
73+
if !zeros_preserved
74+
f_zero_val = f(zero(neutral))
75+
next_row_ind = row+1i32
76+
nzs_this_row = ntuple(Val(length(args))) do i
77+
max_n_zeros = size(args[i], 2)
78+
arg_row_ptr = args[i].rowPtr
79+
nz_this_row = max_n_zeros - (@inbounds(arg_row_ptr[next_row_ind]) - @inbounds(arg_row_ptr[row]))
80+
return nz_this_row * f_zero_val
81+
end
82+
val = op(val, nzs_this_row...)
83+
end
7284

7385
@inbounds output[row] = val
7486
return
7587
end
7688

77-
function csc_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...) where {F, OP}
89+
function csc_reduce_kernel(f::F, op::OP, neutral, zeros_preserved::Bool, output::CuDeviceArray, args...) where {F, OP}
7890
# every thread processes an entire column
7991
col = threadIdx().x + (blockIdx().x - 1i32) * blockDim().x
8092
col > size(output, 2) && return
@@ -92,6 +104,17 @@ function csc_reduce_kernel(f::F, op::OP, neutral, output::CuDeviceArray, args...
92104
end
93105
val = op(val, f(vals...))
94106
end
107+
if !zeros_preserved
108+
f_zero_val = f(zero(neutral))
109+
next_col_ind = col+1i32
110+
nzs_this_col = ntuple(Val(length(args))) do i
111+
max_n_zeros = size(args[i], 1)
112+
arg_col_ptr = args[i].colPtr
113+
nz_this_col = max_n_zeros - (@inbounds(arg_col_ptr[next_col_ind]) - @inbounds(arg_col_ptr[col]))
114+
return nz_this_col * f_zero_val
115+
end
116+
val = op(val, nzs_this_col...)
117+
end
95118

96119
@inbounds output[col] = val
97120
return

test/libraries/cusparse/reduce.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,41 @@ for elty in [Int32, Int64, Float32, Float64]
1212
y = sum(x)
1313
dy = sum(dx)
1414
@test y dy
15+
16+
y = mapreduce(exp, +, x)
17+
dy = mapreduce(exp, +, dx)
18+
@test y dy
1519

1620
# dim=1
1721
y = sum(x, dims=1)
1822
dy = sum(dx, dims=1)
1923
@test y Array(dy)
24+
25+
y = mapreduce(exp, +, x, dims=1)
26+
dy = mapreduce(exp, +, dx, dims=1)
27+
@test y Array(dy)
2028

2129
# dim=2
2230
y = sum(x, dims=2)
2331
dy = sum(dx, dims=2)
2432
@test y Array(dy)
33+
34+
y = mapreduce(exp, +, x, dims=2)
35+
dy = mapreduce(exp, +, dx, dims=2)
36+
@test y Array(dy)
2537
if elty in (Float32, Float64)
2638
dy = mapreduce(abs, +, dx; init=zero(elty))
2739
y = mapreduce(abs, +, x; init=zero(elty))
2840
@test y dy
2941
end
42+
43+
# test with a matrix with fully empty rows
44+
x = zeros(elty, m, n)
45+
x[2, :] .= -one(elty)
46+
x[2, end] = -elty(16)
47+
dx = typ(sparse(x))
48+
y = mapreduce(abs, max, x)
49+
dy = mapreduce(abs, max, dx)
50+
@test y dy
3051
end
3152
end

0 commit comments

Comments
 (0)