@@ -318,7 +318,7 @@ def triton_fp8_blockwise_act_quant_lhs_kernel(
318318 k_offs = pid_k * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
319319 x_offs = m_offs [:, None ] * x_stride_dim_0 + k_offs [None , :] * x_stride_dim_1
320320 x_mask = (m_offs [:, None ] < M ) & (k_offs [None , :] < K )
321- x = tl .load (x_ptr + x_offs , mask = x_mask )
321+ x = tl .load (x_ptr + x_offs , mask = x_mask , other = 0.0 )
322322
323323 # Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
324324 amax = tl .clamp (tl .max (tl .abs (x ), axis = 1 ), min = EPS , max = float ("inf" )).to (tl .float64 )
@@ -333,7 +333,8 @@ def triton_fp8_blockwise_act_quant_lhs_kernel(
333333
334334 # Write reciprocal scales
335335 scale_offs = m_offs [:, None ] * s_stride_dim_0 + pid_k * s_stride_dim_1
336- tl .store (s_ptr + scale_offs , tl .div_rn (1.0 , scale ))
336+ scale_mask = m_offs [:, None ] < M
337+ tl .store (s_ptr + scale_offs , tl .div_rn (1.0 , scale ), mask = scale_mask )
337338
338339
339340@triton_op ("torchao::triton_fp8_blockwise_act_quant_lhs" , mutates_args = {})
@@ -412,7 +413,7 @@ def triton_fp8_blockwise_act_quant_rhs_kernel(
412413 k_offs = pid_k * NUM_GROUPS + tl .arange (0 , NUM_GROUPS )
413414 x_offs = m_offs [:, None ] * x_stride_dim_0 + k_offs [None , :] * x_stride_dim_1
414415 x_mask = (m_offs [:, None ] < M ) & (k_offs [None , :] < K )
415- x = tl .load (x_ptr + x_offs , mask = x_mask )
416+ x = tl .load (x_ptr + x_offs , mask = x_mask , other = 0.0 )
416417
417418 # Column-wise scales for RHS operand, shape (1, block_size)
418419 amax = tl .clamp (tl .max (tl .abs (x ), axis = 0 ), min = EPS , max = float ("inf" )).to (tl .float64 )
@@ -427,7 +428,8 @@ def triton_fp8_blockwise_act_quant_rhs_kernel(
427428
428429 # Write reciprocal scales
429430 scale_offs = pid_m * s_stride_dim_0 + k_offs [None , :] * s_stride_dim_1
430- tl .store (s_ptr + scale_offs , tl .div_rn (1.0 , scale ))
431+ scale_mask = k_offs [None , :] < K
432+ tl .store (s_ptr + scale_offs , tl .div_rn (1.0 , scale ), mask = scale_mask )
431433
432434
433435@triton_op ("torchao::triton_fp8_blockwise_act_quant_rhs" , mutates_args = {})
0 commit comments