Skip to content

Commit 9130a38

Browse files
committed
minor updates
1 parent 0181182 commit 9130a38

2 files changed

Lines changed: 9 additions & 14 deletions

File tree

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,7 @@ def _triton_fp8_rowwise_2d_fused_scale_and_cast_kernel(
811811
if is_xpu:
812812
scale = tl.math.div_rn(fp8_dtype_max, row_amax)
813813
else:
814-
scale = fp8_dtype_max / row_amax.to(tl.float64).to(tl.float32)
814+
scale = (fp8_dtype_max / row_amax.to(tl.float64)).to(tl.float32)
815815

816816
# Optionally round to power of 2 for hardware-friendly scaling.
817817
# Power-of-2 scales can be applied as exponent additions rather than

torchao/prototype/mx_formats/kernels.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
import torch
13-
from torch._inductor.runtime.triton_helpers import libdevice
1413
from torch.distributed.tensor import Replicate, Shard
1514
from torch.distributed.tensor.experimental import register_sharding
1615
from torch.utils._triton import has_triton
@@ -30,6 +29,8 @@
3029
torch_version_at_least,
3130
)
3231

32+
_is_xpu = is_XPU()
33+
3334
logger = logging.getLogger(__name__)
3435

3536

@@ -58,8 +59,6 @@ def get_bits(x: torch.Tensor) -> str:
5859
ZERO_BITS_F32 = 0x0
5960
ZERO_POINT_FIVE_BITS_F32 = 0x3F000000
6061

61-
_is_xpu = is_XPU()
62-
6362

6463
def f32_to_f4_unpacked(x):
6564
"""
@@ -173,6 +172,9 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
173172
import triton
174173
import triton.language as tl
175174
from torch.library import triton_op, wrap_triton
175+
from triton.language.extra import libdevice
176+
177+
IS_XPU = tl.constexpr(_is_xpu)
176178

177179
def triton_to_mxfp8_dim1_reference(
178180
x_hp: torch.Tensor,
@@ -231,7 +233,6 @@ def triton_mxfp8_dequant_dim0(
231233
e8m0_scales.size(1),
232234
out_dtype=out_dtype_tl,
233235
SCALE_BLOCK_SIZE=scale_block_size,
234-
is_xpu=_is_xpu,
235236
)
236237
return out_buffer.reshape(orig_shape)
237238

@@ -275,7 +276,6 @@ def _dequant_mxfp8_kernel(
275276
SCALE_BLOCK_SIZE: tl.constexpr,
276277
ROW_TILE_SIZE: tl.constexpr,
277278
COL_TILE_SIZE: tl.constexpr,
278-
is_xpu: tl.constexpr,
279279
):
280280
pid_row = tl.program_id(0)
281281
pid_col = tl.program_id(1)
@@ -307,7 +307,7 @@ def _dequant_mxfp8_kernel(
307307
e8m0_scale_block_r = e8m0_scale_block.reshape(
308308
ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, 1
309309
)
310-
fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r, is_xpu)
310+
fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r)
311311
data_hp = e4m3_data_block_r.to(tl.float32) * fp32_scale
312312

313313
# Write to output buffer
@@ -316,11 +316,11 @@ def _dequant_mxfp8_kernel(
316316
tl.store(out_buffer + block_offs, out_buffer_block, mask=mask)
317317

318318
@triton.jit
319-
def _e8m0_to_fp32(scale_e8m0, is_xpu: tl.constexpr):
319+
def _e8m0_to_fp32(scale_e8m0):
320320
e8m0_nan_val = 255
321321
e8m0_exponent_bias = 127
322322
s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
323-
if is_xpu:
323+
if IS_XPU:
324324
s_fp = libdevice.exp2(s_offset.to(tl.float32))
325325
else:
326326
s_fp = tl.exp2(s_offset.to(tl.float32))
@@ -479,12 +479,7 @@ def triton_mxfp8_dequant_dim0(
479479
)
480480

481481
if _triton_kernels_available:
482-
import triton
483-
import triton.language as tl
484-
from torch.library import triton_op, wrap_triton
485-
486482
IS_ROCM = tl.constexpr(is_ROCM())
487-
IS_XPU = tl.constexpr(_is_xpu)
488483

489484
@triton.jit
490485
def _calculate_reciprocal_scale(scale_e8m0_biased):

0 commit comments

Comments (0)