Skip to content

Add "32x1 transposed" variant to MXFP8 3D quantization kernel#4356

Open
alexsamardzic wants to merge 2 commits intogh/alexsamardzic/1/basefrom
gh/alexsamardzic/1/head
Open

Add "32x1 transposed" variant to MXFP8 3D quantization kernel#4356
alexsamardzic wants to merge 2 commits intogh/alexsamardzic/1/basefrom
gh/alexsamardzic/1/head

Conversation

@alexsamardzic
Copy link
Copy Markdown
Collaborator

@alexsamardzic alexsamardzic commented Apr 30, 2026

[ghstack-poisoned]
@pytorch-bot
Copy link
Copy Markdown

pytorch-bot Bot commented Apr 30, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4356

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 04b49da with merge base 9052ece (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 30, 2026
@alexsamardzic alexsamardzic added module: training quantize_ api training flow moe mx labels Apr 30, 2026
[ghstack-poisoned]
alexsamardzic added a commit that referenced this pull request Apr 30, 2026
@alexsamardzic
Copy link
Copy Markdown
Collaborator Author

Benchmarking results:

input_shape       scaling_mode                variant      cuda_2d_us    cutedsl_3d_us    to_mx_us    cuda_2d_gbps    cutedsl_3d_gbps    to_mx_gbps
----------------  --------------------------  ---------  ------------  ---------------  ----------  --------------  -----------------  ------------
(1, 8192, 5120)   ScaleCalculationMode.FLOOR  32x1_t          nan               34.816      26.624         nan                3651.76      4775.39
(1, 8192, 5120)   ScaleCalculationMode.FLOOR  32x1_n           30.72            28.736     118.784        4138.67             4424.41      1070.35
(1, 8192, 5120)   ScaleCalculationMode.FLOOR  32x32_n         nan               30.72       73.696         nan                4138.67      1725.19
(1, 8192, 5120)   ScaleCalculationMode.RCEIL  32x1_t          nan               34.816      36.864         nan                3651.76      3448.89
(1, 8192, 5120)   ScaleCalculationMode.RCEIL  32x1_n           28.672           28.672     120.864        4434.29             4434.29      1051.92
(1, 8192, 5120)   ScaleCalculationMode.RCEIL  32x32_n         nan               28.704      73.728         nan                4429.34      1724.44
(1, 7168, 2048)   ScaleCalculationMode.FLOOR  32x1_t          nan               16.384      59.392         nan                2716          749.241
(1, 7168, 2048)   ScaleCalculationMode.FLOOR  32x1_n           16.384           14.336      83.968        2716                3104          529.951
(1, 7168, 2048)   ScaleCalculationMode.FLOOR  32x32_n         nan               14.336      69.632         nan                3104          639.059
(1, 7168, 2048)   ScaleCalculationMode.RCEIL  32x1_t          nan               16.384      63.488         nan                2716          700.903
(1, 7168, 2048)   ScaleCalculationMode.RCEIL  32x1_n           14.368           14.336      86.016        3097.09             3104          517.333
(1, 7168, 2048)   ScaleCalculationMode.RCEIL  32x32_n         nan               14.368      71.648         nan                3097.09       621.077
(8, 8192, 5120)   ScaleCalculationMode.FLOOR  32x1_t          nan              186.368    1208.48          nan                5457.58       841.651
(8, 8192, 5120)   ScaleCalculationMode.FLOOR  32x1_n         1524.74           169.984    1689.63          667.079            5983.61       601.976
(8, 8192, 5120)   ScaleCalculationMode.FLOOR  32x32_n         nan              174.08     1816.58          nan                5842.82       559.91
(8, 8192, 5120)   ScaleCalculationMode.RCEIL  32x1_t          nan              184.288    1289.22          nan                5519.18       788.944
(8, 8192, 5120)   ScaleCalculationMode.RCEIL  32x1_n         1506.27           163.872    1764.35          675.256            6206.79       576.483
(8, 8192, 5120)   ScaleCalculationMode.RCEIL  32x32_n         nan              166.976    1672.19          nan                6091.41       608.255
(8, 7168, 2048)   ScaleCalculationMode.FLOOR  32x1_t          nan               70.656     433.2           nan                5038.38       821.772
(8, 7168, 2048)   ScaleCalculationMode.FLOOR  32x1_n          540.704           65.536     602.112         658.385            5432          591.238
(8, 7168, 2048)   ScaleCalculationMode.FLOOR  32x32_n         nan               67.584     643.072         nan                5267.39       553.58
(8, 7168, 2048)   ScaleCalculationMode.RCEIL  32x1_t          nan               69.664     458.752         nan                5110.12       776
(8, 7168, 2048)   ScaleCalculationMode.RCEIL  32x1_n          534.56            65.536     626.72          665.952            5432          568.023
(8, 7168, 2048)   ScaleCalculationMode.RCEIL  32x32_n         nan               65.536     593.92          nan                5432          599.393
(32, 7168, 2048)  ScaleCalculationMode.FLOOR  32x1_t          nan              252        1686.56          nan                5650.66       844.302
(32, 7168, 2048)  ScaleCalculationMode.FLOOR  32x1_n         2128.9            235.52     2368.54          668.875            6046.05       601.199
(32, 7168, 2048)  ScaleCalculationMode.FLOOR  32x32_n         nan              245.792    2544.61          nan                5793.38       559.601
(32, 7168, 2048)  ScaleCalculationMode.RCEIL  32x1_t          nan              247.84     1807.36          nan                5745.51       787.871
(32, 7168, 2048)  ScaleCalculationMode.RCEIL  32x1_n         2102.24           223.392    2471.97          677.357            6374.29       576.046
(32, 7168, 2048)  ScaleCalculationMode.RCEIL  32x32_n         nan              231.456    2354.19          nan                6152.21       604.864
(32, 8192, 5120)  ScaleCalculationMode.FLOOR  32x1_t          nan              740.912    4806.78          nan                5491.17       846.403
(32, 8192, 5120)  ScaleCalculationMode.FLOOR  32x1_n         6059.01           666.624    6730.75          671.475            6103.1        604.461
(32, 8192, 5120)  ScaleCalculationMode.FLOOR  32x32_n         nan              708.576    7233.54          nan                5741.76       562.446
(32, 8192, 5120)  ScaleCalculationMode.RCEIL  32x1_t          nan              718.848    5125.71          nan                5659.72       793.738
(32, 8192, 5120)  ScaleCalculationMode.RCEIL  32x1_n         5980.18           633.792    7013.41          680.327            6419.26       580.1
(32, 8192, 5120)  ScaleCalculationMode.RCEIL  32x32_n         nan              649.216    6655.12          nan                6266.75       611.33

@danielvegamyhre
Copy link
Copy Markdown
Contributor

@alexsamardzic can you benchmark this against the 2 stage approach we do here in _compute_fwd_sm100():

weight_e4m3, weight_scales = triton_to_mxfp8_dim0(
weight_t.transpose(-2, -1), block_size, scale_calculation_mode.value.lower()
)
weight_scales_blocked = triton_mx_block_rearrange_per_group_3d(weight_scales)

@alexsamardzic
Copy link
Copy Markdown
Collaborator Author

can you benchmark this against the 2 stage approach we do here in _compute_fwd_sm100():

weight_e4m3, weight_scales = triton_to_mxfp8_dim0(
weight_t.transpose(-2, -1), block_size, scale_calculation_mode.value.lower()
)
weight_scales_blocked = triton_mx_block_rearrange_per_group_3d(weight_scales)

Here is an adapted benchmarking script to compare between the two: bench_quantize_3d_vs_triton.py.

And here are the results:

input_shape       scaling_mode                  cutedsl_3d_us    triton_two_stage_us  speedup
----------------  --------------------------  ---------------  ---------------------  ---------
(1, 8192, 5120)   ScaleCalculationMode.FLOOR           34.816                 34.848  1.00x
(1, 8192, 5120)   ScaleCalculationMode.RCEIL           32.832                 34.848  1.06x
(1, 7168, 2048)   ScaleCalculationMode.FLOOR           16.224                 20.736  1.28x
(1, 7168, 2048)   ScaleCalculationMode.RCEIL           14.336                 20.352  1.42x
(8, 8192, 5120)   ScaleCalculationMode.FLOOR          186.4                  196.608  1.05x
(8, 8192, 5120)   ScaleCalculationMode.RCEIL          184.384                198.656  1.08x
(8, 7168, 2048)   ScaleCalculationMode.FLOOR           71.68                  72.72   1.01x
(8, 7168, 2048)   ScaleCalculationMode.RCEIL           69.632                 73.728  1.06x
(32, 7168, 2048)  ScaleCalculationMode.FLOOR          261.12                 285.632  1.09x
(32, 7168, 2048)  ScaleCalculationMode.RCEIL          256                    287.712  1.12x
(32, 8192, 5120)  ScaleCalculationMode.FLOOR          759.808               1029.09   1.35x
(32, 8192, 5120)  ScaleCalculationMode.RCEIL          737.312                990.304  1.34x

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: training quantize_ api training flow moe mx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants