🐛 Describe the bug
Today, torch.compile does not generate optimal kernels for casting a tensor from bfloat16 to nvfp4. Using a torchao microbenchmark, we see the following:
```
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (main)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_nvfp4
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.10.0a0+gitd884ac4
triton version: 3.5.1
mode: dim0_nvfp4
time_us 879.6639740467072
mem_bw_gbps 781.9643367178256
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (main)]$ python benchmarks/mx_formats/cast_bench.py --mode dim0_nvfp4_triton_swizzle
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA B200
torch version: 2.10.0a0+gitd884ac4
triton version: 3.5.1
mode: dim0_nvfp4_triton_swizzle
time_us 206.84799551963806
mem_bw_gbps 3325.465418564785
(pytorch) [vasiliy@devgpu023.atn1 ~/local/ao (main)]$ TORCH_LOGS_FORMAT=short TORCH_LOGS=output_code python benchmarks/mx_formats/cast_bench.py --mode dim0_nvfp4 2>&1 | pastry
P2024733149: https://www.internalfb.com/intern/paste/P2024733149/
```
The compiled version hits ~0.8 TB/s (~10% of peak memory bandwidth on a B200). For reference, a very naive fused Triton kernel (with tiny tiles and lots of room to optimize) hits ~3.3 TB/s on the same benchmark.
Logs of what compile is currently generating: https://www.internalfb.com/phabricator/paste/view/P2024733149
Fixing this will be really useful as nvfp4 adoption picks up, since people will want to apply torch.compile to fuse the quantization cast into the preceding op.
Versions
NVIDIA B200, PyTorch 2.10.0a0+gitd884ac4
cc @chauhang @penguinwu