
Commit 8c24bb9

Update on "Add support for flashinfer quantize kernel option for nvfp4"
Summary: Added the flashinfer option for better performance on some of the workflows we are interested in. Also added a numerical equivalence test between the different nvfp4_quantize_kernel_choice options.

Test Plan: pytest test/prototype/mx_formats/test_nvfp4_tensor.py -k test_kernel_preference_numerical_equivalence

We'll test speedup a bit later.

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
2 parents 311704c + 04782bf commit 8c24bb9
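
As a rough sketch of the numerical-equivalence idea behind the test plan (not the actual test: the module path, argument defaults, and the exact spelling of the new kernel-choice option are assumptions), the same tensor can be quantized through two NVFP4 kernel paths and the packed payloads compared:

```python
# Hedged sketch only: compare two NVFP4 quantization kernel paths for
# bit-equivalence, mirroring the idea behind
# test_kernel_preference_numerical_equivalence. Module path and keyword
# arguments are assumptions based on this diff; the flashinfer kernel
# choice itself is selected via the option added in this commit.
import torch
from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

ref = NVFP4Tensor.to_nvfp4(x, per_tensor_scale=per_tensor_scale)
tri = NVFP4Tensor.to_nvfp4(
    x,
    per_tensor_scale=per_tensor_scale,
    use_triton_kernel=True,
    is_swizzled_scales=True,
)

# the packed fp4 payload should be identical regardless of which kernel produced it
assert torch.equal(ref.qdata.view(torch.uint8), tri.qdata.view(torch.uint8))
```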

188 files changed

Lines changed: 14810 additions & 4256 deletions


.github/workflows/1xH100_tests.yml

Lines changed: 2 additions & 1 deletion
@@ -57,5 +57,6 @@ jobs:
       python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
       ./test/float8/test_everything_single_gpu.sh
       pytest test/prototype/mx_formats/ --verbose -s
-      pytest test/prototype/moe_training/test_scaled_grouped_mm.py --verbose -s
+      pytest test/prototype/moe_training/test_fp8_grouped_mm.py --verbose -s
+      pytest test/prototype/moe_training/test_mxfp8_grouped_mm.py --verbose -s
       pytest test/prototype/moe_training/test_training.py --verbose -s

.github/workflows/4xH100_tests.yml

Lines changed: 0 additions & 2 deletions
@@ -47,6 +47,4 @@ jobs:
       uv pip install -r dev-requirements.txt
       pip install . --no-build-isolation
       ./test/float8/test_everything_multi_gpu.sh
-      ./test/prototype/mx_formats/test_mx_dtensor.sh
       ./test/prototype/mx_formats/test_mxfp8_allgather.sh
-      ./test/prototype/moe_training/test_distributed.sh

.github/workflows/claude-code.yml

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
+name: Claude Code
+
+on:
+  issue_comment:
+    types: [created]
+  issues:
+    types: [opened]
+
+jobs:
+  claude-code:
+    uses: pytorch/test-infra/.github/workflows/_claude-code.yml@main
+    permissions:
+      contents: read
+      pull-requests: write
+      issues: write
+      id-token: write
+    secrets: inherit

CLAUDE.md

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
+# TorchAO Claude Instructions
+
+Fill me in

README.md

Lines changed: 11 additions & 0 deletions
@@ -110,6 +110,17 @@ pip install torchao
 
 Please see the [torchao compability table](https://github.com/pytorch/ao/issues/2919) for version requirements for dependencies.
 
+### Optional Dependencies
+
+[MSLK](https://github.com/pytorch/MSLK) is an optional runtime dependency that provides accelerated kernels for some of the workflows in torchao. Stable MSLK should be used with stable torchao, and nightly MSLK with nightly torchao.
+```bash
+# Stable
+pip install mslk-cuda==1.0.0
+
+# Nightly
+pip install --pre mslk --index-url https://download.pytorch.org/whl/nightly/cu128
+```
+
 ## 🔎 Inference
 
 TorchAO delivers substantial performance gains with minimal code changes:

benchmarks/float8/bench_matmul.py

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ def run(
     assert recipe in (
         "tensorwise",
         "rowwise",
-        "mxfp8_cublas",
         "mxfp4_cutlass",
         "nvfp4",
     ), "unsupported"

benchmarks/float8/float8_inference_roofline.py

Lines changed: 61 additions & 9 deletions
@@ -58,6 +58,7 @@
 )
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.testing.training.roofline_utils import (
+    get_inference_bf16_activation_mem_sympy,
     get_inference_float8_mem_sympy,
     get_inference_gemm_time_sympy,
 )
@@ -111,7 +112,7 @@ def get_gemm_times(
 
     bf16_time_s = get_gpu_kernel_gemm_time_s(torch.mm, x_bf16, w_bf16)
 
-    if recipe_name in ("mxfp4_cutlass", "nvfp4"):
+    if recipe_name in ("mxfp4_cutlass", "nvfp4", "nvfp4_static"):
         d1, d2, d3 = torch.float4_e2m1fn_x2, torch.float4_e2m1fn_x2, torch.bfloat16
         A = torch.randint(0, 255, (M, K // 2), device=device, dtype=torch.uint8).view(
             d1
@@ -150,7 +151,7 @@ def get_gemm_times(
         scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
         scale_a = to_blocked(scale_a)
         scale_b = to_blocked(scale_b)
-    elif recipe_name == "nvfp4":
+    elif recipe_name in ("nvfp4", "nvfp4_static"):
         scale_a = torch.ones(M, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_b = torch.ones(N, K // 16, device=device, dtype=torch.float8_e4m3fn)
         scale_a = to_blocked(scale_a)
@@ -176,7 +177,7 @@ def do_matmul(A, B):
                 swizzle_b=SwizzleType.SWIZZLE_32_4_4,
                 output_dtype=d3,
             )
-        if recipe_name == "nvfp4":
+        if recipe_name in ("nvfp4", "nvfp4_static"):
             return torch._scaled_mm(
                 A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=False
             )
@@ -468,8 +469,8 @@ def _stack_layers_conv(
 
 
 def run(
-    outfile: str,
     recipe_name: str,
+    outfile: str | None = None,
     do_benchmarks: bool = True,
     shape_gen_name: str = "pow2",
     M: Optional[int] = None,
@@ -485,6 +486,7 @@ def run(
     kernel_size: Optional[int] = None,
     stride: int = 1,
     padding: int = 0,
+    skip_printing_detailed_metrics: bool = False,
 ):
     """
     Args:
@@ -500,6 +502,8 @@ def run(
     * `kernel_size`: kernel_size for conv3d / conv2d
     * `stride`: stride for conv ops (default: 1)
     * `padding`: padding for conv ops (default: 0)
+    * `skip_printing_detailed_metrics`: if True, prints e2e roofline
+      and observed speedups only, skipping all other intermediate metrics
     """
     _SUPPORTED_OPS = ["linear", "conv2d", "conv3d"]
     assert op_name in _SUPPORTED_OPS, (
@@ -561,6 +565,11 @@ def run(
         # TODO(future): also enable fusion modeling here
     )
     bf16_gemm_time_sympy = get_inference_gemm_time_sympy(M, K, N, torch.bfloat16, None)
+    if enable_fusion_modeling and op_name == "linear":
+        bf16_ovhd_time_sympy = get_inference_bf16_activation_mem_sympy(M, K, N)
+    else:
+        # multiply by M to ensure we get a sympy symbol
+        bf16_ovhd_time_sympy = M * 0
 
     if recipe_name and recipe_name.startswith(("nvfp4", "mxfp4")):
         fp8_gemm_time_sympy = get_inference_gemm_time_sympy(
@@ -572,6 +581,7 @@ def run(
             M, K, N, torch.float8_e4m3fn, gemm_recipe_name
         )
     print("bf16_gemm_time_sympy", bf16_gemm_time_sympy)
+    print("bf16_ovhd_time_sympy", bf16_ovhd_time_sympy)
    print("fp8_gemm_time_sympy", fp8_gemm_time_sympy)
    print("fp8_ovhd_time_sympy", fp8_ovhd_time_sympy)
    print()
@@ -587,6 +597,8 @@ def run(
        # roofline - gemm time (fwd + bwd, 3 gemms; for conv: using equivalent implicit gemm dims)
        "r_bf16_gemm_s",
        "r_fp8_gemm_s",
+       # roofline - bf16 overhead time (read-write prev activation, only if fusion modeling is on)
+       "r_bf16_ovhd_s",
        # roofline - fp8 overhead time (by counting reads/writes in the ideal case)
        "r_fp8_ovhd_s",
        # roofline - fp8 gemm + fp8 overhead time (does not include LN or sigmoid)
@@ -628,11 +640,16 @@ def run(
            )
 
            # note: cast from sympy.core.numbers.Float to float to make pandas formatting work
+           r_bf16_ovhd_time_s = float(
+               bf16_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+           )
            r_fp8_ovhd_time_s = float(
                fp8_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
            )
            r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
-           r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+           r_speedup = (r_bf16_gemm_time_s + r_bf16_ovhd_time_s) / (
+               r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+           )
 
            # if enabled, also measured observed gemm time
            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
@@ -679,11 +696,16 @@ def run(
            r_fp8_gemm_time_s = float(
                fp8_gemm_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
            )
+           r_bf16_ovhd_time_s = float(
+               bf16_ovhd_time_sympy.subs(M, M_val).subs(K, K_val).subs(N, N_val)
+           )
            r_fp8_ovhd_time_s = float(
                fp8_ovhd_time_sympy.subs(M, gemm_M).subs(K, gemm_K).subs(N, gemm_N)
            )
            r_fp8_gemm_and_ovhd_s = r_fp8_gemm_time_s + r_fp8_ovhd_time_s
-           r_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
+           r_speedup = (r_bf16_gemm_time_s + r_bf16_ovhd_time_s) / (
+               r_fp8_gemm_time_s + r_fp8_ovhd_time_s
+           )
 
            # measure actual conv kernel times (without quant overhead)
            b_bf16_gemm_time_s, b_fp8_gemm_time_s = 0, 0
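
Note on the roofline change above: with fusion modeling on, the bf16 baseline now also pays for reading and writing the previous activation, so the modeled speedup becomes (bf16_gemm + bf16_ovhd) / (fp8_gemm + fp8_ovhd) rather than bf16_gemm / (fp8_gemm + fp8_ovhd). A tiny worked example with made-up numbers:

```python
# Illustrative numbers only (not measured), to show the effect of counting
# bf16 activation overhead on the baseline side of the roofline speedup.
r_bf16_gemm_time_s = 100e-6
r_bf16_ovhd_time_s = 20e-6  # read/write of the previous activation
r_fp8_gemm_time_s = 60e-6
r_fp8_ovhd_time_s = 30e-6

old_speedup = r_bf16_gemm_time_s / (r_fp8_gemm_time_s + r_fp8_ovhd_time_s)
new_speedup = (r_bf16_gemm_time_s + r_bf16_ovhd_time_s) / (
    r_fp8_gemm_time_s + r_fp8_ovhd_time_s
)
print(f"{old_speedup:.2f} -> {new_speedup:.2f}")  # 1.11 -> 1.33
```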
@@ -773,12 +795,29 @@ def run(
         )
     elif recipe_name == "nvfp4":
         config = NVFP4DynamicActivationNVFP4WeightConfig(
-            use_dynamic_per_tensor_scale=False,
+            use_dynamic_per_tensor_scale=True,
+        )
+    elif recipe_name == "nvfp4_static":
+        config_calib = NVFP4DynamicActivationNVFP4WeightConfig(
+            step="prepare",
+        )
+        config = NVFP4DynamicActivationNVFP4WeightConfig(
+            step="convert",
         )
     else:
         assert False, "unsupported"
 
     m_fp8_dyn = copy.deepcopy(m_orig)
+
+    if recipe_name == "nvfp4_static":
+        # calibrate with sample data
+        # this benchmark is performance-only, so a toy datum is fine
+        quantize_(m_fp8_dyn, config_calib)
+        toy_datum = torch.randn(
+            M_val, K_val, dtype=torch.bfloat16, device="cuda"
+        )
+        m_fp8_dyn(toy_datum)
+
     if op_name == "linear":
         quantize_(m_fp8_dyn, config)
     elif op_name == "conv2d":
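
Pulled out of the benchmark code above, the `nvfp4_static` path is a two-step prepare/convert flow with a calibration forward pass in between. A minimal standalone sketch (the config's import path is an assumption; the `step` values and `quantize_` calls mirror the diff):

```python
# Hedged sketch of the static NVFP4 flow used above: prepare -> calibrate -> convert.
# The config import path is an assumption; step="prepare"/"convert" and the
# quantize_ calls mirror this diff.
import copy
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4DynamicActivationNVFP4WeightConfig  # path assumed

model = torch.nn.Sequential(torch.nn.Linear(4096, 4096)).cuda().bfloat16()
m_static = copy.deepcopy(model)

# step 1: prepare - set up the model so activation statistics can be collected
quantize_(m_static, NVFP4DynamicActivationNVFP4WeightConfig(step="prepare"))

# step 2: calibrate with representative data (a single toy batch here)
m_static(torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda"))

# step 3: convert to the final statically-quantized representation
quantize_(m_static, NVFP4DynamicActivationNVFP4WeightConfig(step="convert"))
```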
@@ -813,7 +852,8 @@ def run(
                 # roofline - gemm
                 r_bf16_gemm_time_s,
                 r_fp8_gemm_time_s,
-                # roofline - fp8 overhead
+                # roofline - overhead
+                r_bf16_ovhd_time_s,
                 r_fp8_ovhd_time_s,
                 # roofline - gemm + overhead, and speedup
                 r_fp8_gemm_and_ovhd_s,
@@ -833,8 +873,20 @@ def run(
 
     pd.set_option("display.precision", 2)
     df = pd.DataFrame(results, columns=headers)
+
+    if outfile is not None:
+        df.to_csv(outfile)
+
+    if op_name == "linear":
+        # drop conv-only columns to simplify linear results
+        df = df.drop(columns=["D", "H", "W", "kernel_size"])
+
+    if skip_printing_detailed_metrics:
+        df = df[
+            ["fwd_M", "fwd_K", "fwd_N", "r_fp8_gemm_and_ovhd_spdp", "b_fp8_e2e_spdp"]
+        ]
+
     print(df)
-    df.to_csv(outfile)
     print("done")
 
 
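
Relatedly, `outfile` is now optional (the CSV is only written when a path is given) and `skip_printing_detailed_metrics` trims the printed table down to shapes and speedups. A hedged usage sketch, assuming `run()` is imported directly rather than driven from the command line:

```python
# Hedged sketch: exercising the new optional arguments of run().
# Assumes float8_inference_roofline is importable from the working directory;
# normally the benchmark is launched from the command line.
from float8_inference_roofline import run

run(
    recipe_name="nvfp4",
    # outfile omitted -> no CSV is written
    skip_printing_detailed_metrics=True,  # print only shapes and speedups
)
```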

benchmarks/float8/float8_roofline.py

Lines changed: 14 additions & 6 deletions
@@ -61,7 +61,10 @@
     Float8LinearConfig,
     convert_to_float8_training,
 )
-from torchao.prototype.mx_formats import MXLinearConfig
+from torchao.prototype.moe_training.config import (
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
 from torchao.quantization import quantize_
 from torchao.testing.training.roofline_utils import (
     get_float8_mem_sympy,
@@ -253,10 +256,7 @@ def run(
     print(f"enable_fusion_modeling: {enable_fusion_modeling}")
 
     assert mx_recipe_name in (
-        # real mxfp8_cublas recipe
-        "mxfp8_cublas",
-        # real mxfp8_cublas_rceil recipe
-        "mxfp8_cublas_rceil",
+        None,
         # modeling of what mxfp8 with 32x32 block size and without gemm
         # operand layout restrictions would look like
         "mxfp8_32x32_flexible_gemm_layout",
@@ -429,7 +429,15 @@ def run(
         )
     else:
         assert mx_recipe_name is not None
-        config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+        try:
+            config = MXFP8TrainingOpConfig.from_recipe(
+                MXFP8TrainingRecipe(mx_recipe_name)
+            )
+        except ValueError:
+            raise ValueError(
+                f"Unsupported mx_recipe_name: {mx_recipe_name}. "
+                f"Supported values: {[r.value for r in MXFP8TrainingRecipe]}"
+            )
     m_fp8_dyn = copy.deepcopy(m_orig)
     quantize_(m_fp8_dyn, config=config)
     m_fp8_dyn = torch.compile(m_fp8_dyn)

benchmarks/float8/profile_lowp_training.py

Lines changed: 13 additions & 2 deletions
@@ -45,7 +45,10 @@
 from torchao.float8.float8_linear_utils import (
     convert_to_float8_training,
 )
-from torchao.prototype.mx_formats.config import MXLinearConfig
+from torchao.prototype.moe_training.config import (
+    MXFP8TrainingOpConfig,
+    MXFP8TrainingRecipe,
+)
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
 from torchao.quantization import quantize_
@@ -320,7 +323,15 @@ def main(
     elif float8_recipe_name is not None:
         config = Float8LinearConfig.from_recipe_name(float8_recipe_name)
     elif mx_recipe_name is not None:
-        config = MXLinearConfig.from_recipe_name(mx_recipe_name)
+        try:
+            config = MXFP8TrainingOpConfig.from_recipe(
+                MXFP8TrainingRecipe(mx_recipe_name)
+            )
+        except ValueError:
+            raise ValueError(
+                f"Unsupported mx_recipe_name: {mx_recipe_name}. "
+                f"Supported values: {[r.value for r in MXFP8TrainingRecipe]}"
+            )
 
     print(f"Compile is set to | {compile}")
     print(f"model_type is set to | {model_type}")

benchmarks/mx_formats/cast_bench.py

Lines changed: 38 additions & 4 deletions
@@ -83,8 +83,12 @@ def to_nvfp4_reference(x_hp):
 
 
 def to_nvfp4_reference_triton_swizzle(x_hp):
+    per_tensor_scale = torch.tensor(1.0, dtype=torch.float32, device=x_hp.device)
     nvfp4_tensor = NVFP4Tensor.to_nvfp4(
-        x_hp, use_triton_kernel=True, is_swizzled_scales=True
+        x_hp,
+        per_tensor_scale=per_tensor_scale,
+        use_triton_kernel=True,
+        is_swizzled_scales=True,
     )
     return nvfp4_tensor.qdata, nvfp4_tensor.scale
 
@@ -118,6 +122,7 @@ def run(
         "dim1_mxfp8_floor",
         "dim1_mxfp8_rceil",
         "dim1_mxfp8_triton_floor",
+        "dim1_mxfp8_triton_rceil",
         "dim1_mxfp8_cuda_floor",
         "dim1_mxfp8_cuda_rceil",
     )
@@ -350,12 +355,41 @@ def run(
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
     elif mode == "dim1_mxfp8_triton_floor":
-        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        y_d1, s_d1 = triton_to_mxfp8_dim1(
+            x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+        )
 
         for _ in range(2):
-            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+            __ = triton_to_mxfp8_dim1(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            )
         time_us = benchmark_cuda_function_in_microseconds(
-            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            lambda x, b: triton_to_mxfp8_dim1(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="floor"
+            ),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mxfp8_triton_rceil":
+        y_d1, s_d1 = triton_to_mxfp8_dim1(
+            x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+        )
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim1(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            )
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim1(
+                x, inner_block_size=BLOCK_SIZE, scaling_mode="rceil"
+            ),
             x,
             BLOCK_SIZE,
         )