Skip to content

Commit a4ae9cc

Browse files
Use relaxed memory ordering for Triton atomics on AMDGPU. (#3945)
1 parent 4e18d87 commit a4ae9cc

4 files changed

Lines changed: 26 additions & 4 deletions

File tree

torchao/prototype/common/triton/matmul.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,11 @@ def _kernel(
259259
if SPLIT_K == 1:
260260
tl.store(C, acc, mask=mask)
261261
else:
262-
tl.atomic_add(C, acc, mask=mask)
262+
# AMD GPUs need relaxed semantics for better performance
263+
if tl.constexpr(torch.version.hip is not None):
264+
tl.atomic_add(C, acc, mask=mask, sem="relaxed")
265+
else:
266+
tl.atomic_add(C, acc, mask=mask)
263267

264268

265269
class _matmul(torch.autograd.Function):

torchao/prototype/float8nocompile/kernels/fp8_dynamic_tensorwise.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,11 @@ def _amax_atomic(
473473
block_mask = block_offs < num_elements
474474
vals = tl.load(input_ptr + block_offs, mask=block_mask).to(input_dtype)
475475
block_amax = tl.max(tl.abs(vals))
476-
tl.atomic_max(amax_ptr, block_amax)
476+
# AMD GPUs need relaxed semantics for better performance
477+
if tl.constexpr(torch.version.hip is not None):
478+
tl.atomic_max(amax_ptr, block_amax, sem="relaxed")
479+
else:
480+
tl.atomic_max(amax_ptr, block_amax)
477481

478482

479483
@triton.jit

torchao/prototype/hqq/kernels.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
6+
import torch
67
import triton
78
import triton.language as tl
89
from triton import Config
@@ -389,7 +390,11 @@ def _mixed_mm_kernel(
389390
if SPLIT_K == 1:
390391
tl.store(C, acc, mask=mask)
391392
else:
392-
tl.atomic_add(C, acc, mask=mask)
393+
# AMD GPUs need relaxed semantics for better performance
394+
if tl.constexpr(torch.version.hip is not None):
395+
tl.atomic_add(C, acc, mask=mask, sem="relaxed")
396+
else:
397+
tl.atomic_add(C, acc, mask=mask)
393398

394399

395400
_mixed_mm = triton.heuristics(MIXED_MM_HEURISTICS)(_mixed_mm_kernel)

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,16 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
207207
+ k_offs[None, :] * stride_scales_dim1
208208
)
209209
scales_mask = k_offs[None, :] < K
210-
tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask)
210+
# AMD GPUs need relaxed semantics for better performance
211+
if tl.constexpr(torch.version.hip is not None):
212+
tl.atomic_min(
213+
scales_ptr + scales_offs,
214+
scales[None, :],
215+
mask=scales_mask,
216+
sem="relaxed",
217+
)
218+
else:
219+
tl.atomic_min(scales_ptr + scales_offs, scales[None, :], mask=scales_mask)
211220

212221
@triton.autotune(configs=atomic_kernel_configs_2D, key=["num_elements"])
213222
@triton.jit

0 commit comments

Comments (0)