Skip to content

Commit 9d6ba70

Browse files
maleadt and claude authored
Use atomic reductions to eliminate two-pass overhead (#141)
When a reduction op has a hardware atomic counterpart for the output dtype, use a single-kernel atomic accumulation instead of the two-pass temp-buffer approach. Each block reduces its chunk and atomically accumulates into the pre-initialized output, eliminating the temp allocation and second kernel launch.

Supported atomic ops (type-dispatched):
- atomic_add: Float32, Int32, Int64
- atomic_max/min: integers only
- atomic_or/and: integers only

Falls back to two-pass for unsupported (op, dtype) combinations (e.g., prod, or max on Float32).

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 3962d2f commit 9d6ba70

1 file changed

Lines changed: 49 additions & 14 deletions

File tree

src/mapreduce.jl

Lines changed: 49 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
using GPUArrays: neutral_element
22

3+
# Map reduction ops to their atomic counterparts for a given element type.
4+
# Returns nothing when no hardware atomic is available for the (op, dtype) combo.
5+
_atomic_op(::typeof(+), ::Type{<:Union{Float32, Int32, Int64}}) = atomic_add
6+
_atomic_op(::typeof(Base.add_sum), ::Type{<:Union{Float32, Int32, Int64}}) = atomic_add
7+
_atomic_op(::typeof(max), ::Type{<:Integer}) = atomic_max
8+
_atomic_op(::typeof(min), ::Type{<:Integer}) = atomic_min
9+
_atomic_op(::typeof(|), ::Type{<:Integer}) = atomic_or
10+
_atomic_op(::typeof(&), ::Type{<:Integer}) = atomic_and
11+
_atomic_op(_, ::Type) = nothing
12+
313
@generated function mapreduce_kernel(
414
dest::TileArray{TD, N}, src::TileArray{TS, N},
515
f, op, tile_size, reduce_dims, overflow_grids, init_val, pad_mode,
6-
reduce_stride
16+
reduce_stride, atomic_op
717
) where {TD, TS, N}
8-
f_func = f.instance
9-
op_func = op.instance
1018
quote
1119
bids = _unflatten_bids(Val{$N}(), overflow_grids)
1220

@@ -25,15 +33,28 @@ using GPUArrays: neutral_element
2533
d -> (idx_d = idx_d + reduce_stride[d]),
2634
begin
2735
tile = load(src, (@ntuple $N d -> idx_d), tile_size; padding_mode=pad_mode)
28-
acc = $op_func.(acc, $f_func.(tile))
36+
acc = op.(acc, f.(tile))
2937
end)
3038

3139
# Collapse each reduced dimension within the accumulated tile
3240
@nexprs $N d -> if d in reduce_dims
33-
acc = reduce($op_func, acc; dims=d, init=init_val)
41+
acc = reduce(op, acc; dims=d, init=init_val)
42+
end
43+
44+
if atomic_op !== nothing
45+
scalar = reshape(acc, ())
46+
@nexprs $N d -> out_d = d in reduce_dims ? Int32(1) : bids[d]
47+
linear_idx = Int32(1)
48+
stride = Int32(1)
49+
@nexprs $N d -> begin
50+
linear_idx = linear_idx + (out_d - Int32(1)) * stride
51+
stride = stride * size(dest, d)
52+
end
53+
atomic_op(dest, linear_idx, scalar)
54+
else
55+
store(dest, bids, acc)
3456
end
3557

36-
store(dest, bids, acc)
3758
return
3859
end
3960
end
@@ -72,29 +93,43 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray, reduce_dims::
7293

7394
_dim_size(d) = d == par_dim ? par_blocks : d in reduce_dims ? 1 : size(A, d)
7495

96+
# Atomics only work when each block produces one scalar output element,
97+
# i.e., all non-reduce dims have tile size 1
98+
scalar_output = all(d -> d in reduce_dims || ts[d] == 1, 1:N)
99+
atomic_op = scalar_output ? _atomic_op(op, eltype(R)) : nothing
100+
75101
if par_blocks > 1
76-
# Two-pass: parallelize along par_dim, then reduce partials
77-
tmp = similar(A, eltype(R), ntuple(_dim_size, N))
78102
grid = ntuple(N) do d
79103
d == par_dim ? par_blocks : d in reduce_dims ? 1 : cld(size(A, d), ts[d])
80104
end
81105
reduce_stride = ntuple(d -> d == par_dim ? Int32(par_blocks) : Int32(1), N)
82-
_launch_mapreduce!(grid, TileArray(tmp), src_ta, f, op, ts, reduce_dims,
83-
init, pad_mode, reduce_stride)
84-
_mapreducedim!(identity, op, R, tmp, (par_dim,); init)
106+
107+
if atomic_op !== nothing
108+
fill!(R, init)
109+
_launch_mapreduce!(grid, TileArray(R), src_ta, f, op, ts, reduce_dims,
110+
init, pad_mode, reduce_stride, atomic_op)
111+
else
112+
# Two-pass: parallelize along par_dim, then reduce partials
113+
tmp = similar(A, eltype(R), ntuple(_dim_size, N))
114+
_launch_mapreduce!(grid, TileArray(tmp), src_ta, f, op, ts, reduce_dims,
115+
init, pad_mode, reduce_stride, nothing)
116+
_mapreducedim!(identity, op, R, tmp, (par_dim,); init)
117+
end
85118
else
86119
grid = ntuple(d -> d in reduce_dims ? 1 : cld(size(A, d), ts[d]), N)
87120
reduce_stride = ntuple(d -> Int32(1), N)
88121
_launch_mapreduce!(grid, TileArray(R), src_ta, f, op, ts, reduce_dims,
89-
init, pad_mode, reduce_stride)
122+
init, pad_mode, reduce_stride, nothing)
90123
end
91124
end
92125

93-
function _launch_mapreduce!(grid, dest_ta, src_ta, f, op, ts, reduce_dims, init, pad_mode, reduce_stride)
126+
function _launch_mapreduce!(grid, dest_ta, src_ta, f, op, ts, reduce_dims,
127+
init, pad_mode, reduce_stride, atomic_op)
94128
launch_grid, overflow = _flatten_grid(grid)
95129
launch(mapreduce_kernel, launch_grid, dest_ta, src_ta,
96130
f, op, Constant(ts), Constant(reduce_dims), Constant(overflow),
97-
Constant(init), Constant(pad_mode), Constant(reduce_stride))
131+
Constant(init), Constant(pad_mode), Constant(reduce_stride),
132+
atomic_op === nothing ? Constant(nothing) : atomic_op)
98133
end
99134

100135
function _mapreduce(f, op, A::AbstractArray; dims, init)

0 commit comments

Comments (0)