11using GPUArrays: neutral_element
22
3+ # Map reduction ops to their atomic counterparts for a given element type.
4+ # Returns nothing when no hardware atomic is available for the (op, dtype) combo.
5+ _atomic_op (:: typeof (+ ), :: Type{<:Union{Float32, Int32, Int64}} ) = atomic_add
6+ _atomic_op (:: typeof (Base. add_sum), :: Type{<:Union{Float32, Int32, Int64}} ) = atomic_add
7+ _atomic_op (:: typeof (max), :: Type{<:Integer} ) = atomic_max
8+ _atomic_op (:: typeof (min), :: Type{<:Integer} ) = atomic_min
9+ _atomic_op (:: typeof (| ), :: Type{<:Integer} ) = atomic_or
10+ _atomic_op (:: typeof (& ), :: Type{<:Integer} ) = atomic_and
11+ _atomic_op (_, :: Type ) = nothing
12+
313@generated function mapreduce_kernel (
414 dest:: TileArray{TD, N} , src:: TileArray{TS, N} ,
515 f, op, tile_size, reduce_dims, overflow_grids, init_val, pad_mode,
6- reduce_stride
16+ reduce_stride, atomic_op
717) where {TD, TS, N}
8- f_func = f. instance
9- op_func = op. instance
1018 quote
1119 bids = _unflatten_bids (Val {$N} (), overflow_grids)
1220
@@ -25,15 +33,28 @@ using GPUArrays: neutral_element
2533 d -> (idx_d = idx_d + reduce_stride[d]),
2634 begin
2735 tile = load (src, (@ntuple $ N d -> idx_d), tile_size; padding_mode= pad_mode)
28- acc = $ op_func .(acc, $ f_func .(tile))
36+ acc = op .(acc, f .(tile))
2937 end )
3038
3139 # Collapse each reduced dimension within the accumulated tile
3240 @nexprs $ N d -> if d in reduce_dims
33- acc = reduce ($ op_func, acc; dims= d, init= init_val)
41+ acc = reduce (op, acc; dims= d, init= init_val)
42+ end
43+
44+ if atomic_op != = nothing
45+ scalar = reshape (acc, ())
46+ @nexprs $ N d -> out_d = d in reduce_dims ? Int32 (1 ) : bids[d]
47+ linear_idx = Int32 (1 )
48+ stride = Int32 (1 )
49+ @nexprs $ N d -> begin
50+ linear_idx = linear_idx + (out_d - Int32 (1 )) * stride
51+ stride = stride * size (dest, d)
52+ end
53+ atomic_op (dest, linear_idx, scalar)
54+ else
55+ store (dest, bids, acc)
3456 end
3557
36- store (dest, bids, acc)
3758 return
3859 end
3960end
@@ -72,29 +93,43 @@ function _mapreducedim!(f, op, R::AbstractArray, A::AbstractArray, reduce_dims::
7293
7394 _dim_size (d) = d == par_dim ? par_blocks : d in reduce_dims ? 1 : size (A, d)
7495
96+ # Atomics only work when each block produces one scalar output element,
97+ # i.e., all non-reduce dims have tile size 1
98+ scalar_output = all (d -> d in reduce_dims || ts[d] == 1 , 1 : N)
99+ atomic_op = scalar_output ? _atomic_op (op, eltype (R)) : nothing
100+
75101 if par_blocks > 1
76- # Two-pass: parallelize along par_dim, then reduce partials
77- tmp = similar (A, eltype (R), ntuple (_dim_size, N))
78102 grid = ntuple (N) do d
79103 d == par_dim ? par_blocks : d in reduce_dims ? 1 : cld (size (A, d), ts[d])
80104 end
81105 reduce_stride = ntuple (d -> d == par_dim ? Int32 (par_blocks) : Int32 (1 ), N)
82- _launch_mapreduce! (grid, TileArray (tmp), src_ta, f, op, ts, reduce_dims,
83- init, pad_mode, reduce_stride)
84- _mapreducedim! (identity, op, R, tmp, (par_dim,); init)
106+
107+ if atomic_op != = nothing
108+ fill! (R, init)
109+ _launch_mapreduce! (grid, TileArray (R), src_ta, f, op, ts, reduce_dims,
110+ init, pad_mode, reduce_stride, atomic_op)
111+ else
112+ # Two-pass: parallelize along par_dim, then reduce partials
113+ tmp = similar (A, eltype (R), ntuple (_dim_size, N))
114+ _launch_mapreduce! (grid, TileArray (tmp), src_ta, f, op, ts, reduce_dims,
115+ init, pad_mode, reduce_stride, nothing )
116+ _mapreducedim! (identity, op, R, tmp, (par_dim,); init)
117+ end
85118 else
86119 grid = ntuple (d -> d in reduce_dims ? 1 : cld (size (A, d), ts[d]), N)
87120 reduce_stride = ntuple (d -> Int32 (1 ), N)
88121 _launch_mapreduce! (grid, TileArray (R), src_ta, f, op, ts, reduce_dims,
89- init, pad_mode, reduce_stride)
122+ init, pad_mode, reduce_stride, nothing )
90123 end
91124end
92125
93- function _launch_mapreduce! (grid, dest_ta, src_ta, f, op, ts, reduce_dims, init, pad_mode, reduce_stride)
126+ function _launch_mapreduce! (grid, dest_ta, src_ta, f, op, ts, reduce_dims,
127+ init, pad_mode, reduce_stride, atomic_op)
94128 launch_grid, overflow = _flatten_grid (grid)
95129 launch (mapreduce_kernel, launch_grid, dest_ta, src_ta,
96130 f, op, Constant (ts), Constant (reduce_dims), Constant (overflow),
97- Constant (init), Constant (pad_mode), Constant (reduce_stride))
131+ Constant (init), Constant (pad_mode), Constant (reduce_stride),
132+ atomic_op === nothing ? Constant (nothing ) : atomic_op)
98133end
99134
100135function _mapreduce (f, op, A:: AbstractArray ; dims, init)
0 commit comments