 
 import numpy as np
 import torch
-from torch._inductor.runtime.triton_helpers import libdevice
 from torch.distributed.tensor import Replicate, Shard
 from torch.distributed.tensor.experimental import register_sharding
 from torch.utils._triton import has_triton
@@ -30,6 +29,8 @@
     torch_version_at_least,
 )
 
+_is_xpu = is_XPU()
+
 logger = logging.getLogger(__name__)
 
 
@@ -58,8 +59,6 @@ def get_bits(x: torch.Tensor) -> str:
 ZERO_BITS_F32 = 0x0
 ZERO_POINT_FIVE_BITS_F32 = 0x3F000000
 
-_is_xpu = is_XPU()
-
 
 def f32_to_f4_unpacked(x):
     """
@@ -173,6 +172,9 @@ def pack_uint4(uint8_data: torch.Tensor) -> torch.Tensor:
     import triton
     import triton.language as tl
     from torch.library import triton_op, wrap_triton
+    from triton.language.extra import libdevice
+
+    IS_XPU = tl.constexpr(_is_xpu)
 
     def triton_to_mxfp8_dim1_reference(
         x_hp: torch.Tensor,
@@ -231,7 +233,6 @@ def triton_mxfp8_dequant_dim0(
             e8m0_scales.size(1),
             out_dtype=out_dtype_tl,
             SCALE_BLOCK_SIZE=scale_block_size,
-            is_xpu=_is_xpu,
         )
         return out_buffer.reshape(orig_shape)
 
@@ -275,7 +276,6 @@ def _dequant_mxfp8_kernel(
         SCALE_BLOCK_SIZE: tl.constexpr,
         ROW_TILE_SIZE: tl.constexpr,
         COL_TILE_SIZE: tl.constexpr,
-        is_xpu: tl.constexpr,
     ):
         pid_row = tl.program_id(0)
         pid_col = tl.program_id(1)
@@ -307,7 +307,7 @@ def _dequant_mxfp8_kernel(
         e8m0_scale_block_r = e8m0_scale_block.reshape(
             ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, 1
         )
-        fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r, is_xpu)
+        fp32_scale = _e8m0_to_fp32(e8m0_scale_block_r)
         data_hp = e4m3_data_block_r.to(tl.float32) * fp32_scale
 
         # Write to output buffer
@@ -316,11 +316,11 @@ def _dequant_mxfp8_kernel(
         tl.store(out_buffer + block_offs, out_buffer_block, mask=mask)
 
     @triton.jit
-    def _e8m0_to_fp32(scale_e8m0, is_xpu: tl.constexpr):
+    def _e8m0_to_fp32(scale_e8m0):
         e8m0_nan_val = 255
         e8m0_exponent_bias = 127
         s_offset = scale_e8m0.to(tl.int16) - e8m0_exponent_bias
-        if is_xpu:
+        if IS_XPU:
             s_fp = libdevice.exp2(s_offset.to(tl.float32))
         else:
             s_fp = tl.exp2(s_offset.to(tl.float32))
@@ -479,12 +479,7 @@ def triton_mxfp8_dequant_dim0(
 )
 
 if _triton_kernels_available:
-    import triton
-    import triton.language as tl
-    from torch.library import triton_op, wrap_triton
-
     IS_ROCM = tl.constexpr(is_ROCM())
-    IS_XPU = tl.constexpr(_is_xpu)
 
     @triton.jit
     def _calculate_reciprocal_scale(scale_e8m0_biased):
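
The net effect of this diff is that XPU dispatch moves from a runtime kernel argument to a module-level `tl.constexpr` flag: Triton specializes the kernel at compile time, the untaken branch is folded away, and callers no longer have to thread `is_xpu` through every launch. Below is a minimal standalone sketch of that pattern; the kernel and names are hypothetical (not the torchao code), and `torch.xpu.is_available()` is an assumed stand-in for the commit's `is_XPU()` probe.

```python
import torch
import triton
import triton.language as tl
from triton.language.extra import libdevice

# Assumption: probe the device once at import time, as the commit does with is_XPU().
_is_xpu = hasattr(torch, "xpu") and torch.xpu.is_available()
IS_XPU = tl.constexpr(_is_xpu)  # captured as a compile-time constant by @triton.jit


@triton.jit
def _exp2_kernel(x_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # Because IS_XPU is a constexpr, this branch is resolved at compile time,
    # mirroring how _e8m0_to_fp32 picks libdevice.exp2 vs tl.exp2 above.
    if IS_XPU:
        y = libdevice.exp2(x)
    else:
        y = tl.exp2(x)
    tl.store(out_ptr + offs, y, mask=mask)
```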
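For reference, `_e8m0_to_fp32` decodes an E8M0 scale, which stores only a biased 8-bit exponent: the fp32 value is 2^(e - 127), with 255 reserved as the NaN sentinel (`e8m0_nan_val`; its handling falls outside the shown hunk). A hedged eager-mode equivalent, omitting that NaN case:

```python
import torch

# Eager-mode sketch of the e8m0 -> fp32 decode performed by _e8m0_to_fp32
# (the 255 NaN sentinel is not handled here).
def e8m0_to_fp32_reference(scale_e8m0: torch.Tensor) -> torch.Tensor:
    e8m0_exponent_bias = 127
    s_offset = scale_e8m0.to(torch.int16) - e8m0_exponent_bias
    return torch.exp2(s_offset.to(torch.float32))


# Example: a biased exponent of 130 decodes to 2**(130 - 127) = 8.0.
print(e8m0_to_fp32_reference(torch.tensor([130], dtype=torch.uint8)))  # tensor([8.])
```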