Skip to content

Commit a6be48f

Browse files
authored
fix undefined values for tail elements in act quant kernels (#4186)
1 parent 9051a2f commit a6be48f

1 file changed

Lines changed: 6 additions & 4 deletions

File tree

  • torchao/prototype/blockwise_fp8_training

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -318,7 +318,7 @@ def triton_fp8_blockwise_act_quant_lhs_kernel(
318318
k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
319319
x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
320320
x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
321-
x = tl.load(x_ptr + x_offs, mask=x_mask)
321+
x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0)
322322

323323
# Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
324324
amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64)
@@ -333,7 +333,8 @@ def triton_fp8_blockwise_act_quant_lhs_kernel(
333333

334334
# Write reciprocal scales
335335
scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
336-
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
336+
scale_mask = m_offs[:, None] < M
337+
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
337338

338339

339340
@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={})
@@ -412,7 +413,7 @@ def triton_fp8_blockwise_act_quant_rhs_kernel(
412413
k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
413414
x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
414415
x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
415-
x = tl.load(x_ptr + x_offs, mask=x_mask)
416+
x = tl.load(x_ptr + x_offs, mask=x_mask, other=0.0)
416417

417418
# Column-wise scales for RHS operand, shape (1, block_size)
418419
amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64)
@@ -427,7 +428,8 @@ def triton_fp8_blockwise_act_quant_rhs_kernel(
427428

428429
# Write reciprocal scales
429430
scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1
430-
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
431+
scale_mask = k_offs[None, :] < K
432+
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
431433

432434

433435
@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})

0 commit comments

Comments
 (0)