Support 32x32 scaling for weights in MXFP8 weight quantization kernel#4254

Merged
danielvegamyhre merged 1 commit into main from add-32x32-support on Apr 26, 2026

Conversation

alexsamardzic (Collaborator) commented Apr 9, 2026:

Fixes #4185

pytorch-bot (Bot) commented Apr 9, 2026:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4254

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 1 New Failure

As of commit b7809e3 with merge base 25ca6b8:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla (Bot) added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Apr 9, 2026
alexsamardzic (Collaborator, Author) commented:

Fixes #4185.

To test:

pytest -q test/prototype/moe_training/test_kernels.py -k "test_cuda_mx_3d_cutedsl_numerics and 32x32"

To benchmark:

python benchmarks/prototype/moe_training/mxfp8/bench_quantize_3d.py

Benchmarking results:

input_shape       scaling_mode                  scale_block_k    cuda_2d_us    cutedsl_3d_us    to_mx_us    cuda_2d_gbps    cutedsl_3d_gbps    to_mx_gbps
----------------  --------------------------  ---------------  ------------  ---------------  ----------  --------------  -----------------  ------------
(1, 8192, 5120)   ScaleCalculationMode.FLOOR                1        36.864           29.728     119.728        3448.89             4276.77      1061.91
(1, 8192, 5120)   ScaleCalculationMode.FLOOR               32       nan               30.752      73.664         nan                4134.36      1725.94
(1, 8192, 5120)   ScaleCalculationMode.RCEIL                1        35.872           28.704     118.752        3544.26             4429.34      1070.63
(1, 8192, 5120)   ScaleCalculationMode.RCEIL               32       nan               28.704      72.704         nan                4429.34      1748.73
(1, 7168, 2048)   ScaleCalculationMode.FLOOR                1        38.304           24.64       83.872        1161.73             1805.96       530.558
(1, 7168, 2048)   ScaleCalculationMode.FLOOR               32       nan               20.096      70.72          nan                2214.32       629.227
(1, 7168, 2048)   ScaleCalculationMode.RCEIL                1        33.984           18.624      83.968        1309.41             2389.33       529.951
(1, 7168, 2048)   ScaleCalculationMode.RCEIL               32       nan               19.808      69.824         nan                2246.51       637.302
(8, 8192, 5120)   ScaleCalculationMode.FLOOR                1      1524.64           168.864    1689.5           667.121            6023.3        602.022
(8, 8192, 5120)   ScaleCalculationMode.FLOOR               32       nan              168        1816.64          nan                6054.28       559.89
(8, 8192, 5120)   ScaleCalculationMode.RCEIL                1      1505.44           163.68     1688.42          675.629            6214.07       602.41
(8, 8192, 5120)   ScaleCalculationMode.RCEIL               32       nan              162.976    1817.6           nan                6240.91       559.594
(8, 7168, 2048)   ScaleCalculationMode.FLOOR                1       541.536           66.56      602.08          657.374            5348.43       591.27
(8, 7168, 2048)   ScaleCalculationMode.FLOOR               32       nan               65.568     644             nan                5429.35       552.782
(8, 7168, 2048)   ScaleCalculationMode.RCEIL                1       535.488           65.376     602.112         664.798            5445.29       591.238
(8, 7168, 2048)   ScaleCalculationMode.RCEIL               32       nan               64.544     644.032         nan                5515.49       552.754
(32, 7168, 2048)  ScaleCalculationMode.FLOOR                1      2128.03           228.512    2358.21          669.147            6231.47       603.834
(32, 7168, 2048)  ScaleCalculationMode.FLOOR               32       nan              229.312    2544.69          nan                6209.73       559.584
(32, 7168, 2048)  ScaleCalculationMode.RCEIL                1      2101.38           220.144    2358.29          677.635            6468.34       603.814
(32, 7168, 2048)  ScaleCalculationMode.RCEIL               32       nan              220.192    2535.33          nan                6466.93       561.65
(32, 8192, 5120)  ScaleCalculationMode.FLOOR                1      6057.66           651.264    6706.26          671.624            6247.04       606.669
(32, 8192, 5120)  ScaleCalculationMode.FLOOR               32       nan              651.536    7228.51          nan                6244.44       562.837
(32, 8192, 5120)  ScaleCalculationMode.RCEIL                1      5978.61           623.824    6707.81          680.505            6521.83       606.528
(32, 8192, 5120)  ScaleCalculationMode.RCEIL               32       nan              624.304    7229.57          nan                6516.82       562.755

To test with grouped MM, apply this patch and run the command below (tolerances are updated because 32x32 scaling gives "worse" quantization quality than 32x1):

pytest -q test/prototype/moe_training/test_mxfp8_grouped_mm.py -k test_mxfp8_grouped_gemm_with_dq_fwd_bwd

    self,
    amax: cutlass.Float32,
):
    amax = cute.arch.fmax(
danielvegamyhre (Contributor):

Using warp reduction for this makes sense. One nit: I think this is better done via a for loop; see this example from quack:

import math
from typing import Callable

@cute.jit
def warp_reduce(val: cute.Numeric,
                op: Callable,
                width: cutlass.Constexpr = 32) -> cute.Numeric:
    # Butterfly reduction: after log2(width) rounds, every lane
    # holds the reduction of all `width` lanes' values.
    for i in range(int(math.log2(width))):
        # cute.arch.shuffle_sync_bfly reads from another thread's registers
        val = op(val, cute.arch.shuffle_sync_bfly(val, offset=1 << i))
    return val
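
For context, a hedged sketch of how such a helper could be invoked at the amax reduction site quoted above (the call shape is an assumption, not the merged code):

    # Reduce each thread's partial amax across the warp with fmax,
    # so every lane ends up holding the warp-wide maximum.
    amax = warp_reduce(amax, cute.arch.fmax)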

alexsamardzic (Collaborator, Author):

Updated.

def mxfp8_quantize_cuda_3d(
    x: torch.Tensor,
    block_size: int = 32,
    scale_block_k: int = 1,
danielvegamyhre (Contributor):

I think this API is a bit confusing: the function is just called "mxfp8_quantize_cuda_3d", and combined with the "scale_block_k" param (without a corresponding scale_block_n param), it seems to communicate that this is a kernel for scaling across K.

I think we should update this to be:

def mxfp8_quantize_3d_cutedsl(
   x: torch.Tensor,
   scale_block_n: Literal[1, 32],
   scale_block_k: Literal[1, 32],
) ...
   # .. (validate only 32x1 and 32x32 are currently supported)
   # .. (in the future would be nice to support 1x32 as well)

alexsamardzic (Collaborator, Author):

Changed. (1x32 will probably not be as straightforward an extension as 32x32 was.)

danielvegamyhre (Contributor):

Yeah, I agree, based on my experience with the kernels for 1x32 vs 32x1 scaling on 2D tensors.

This can just be the user-facing API that dispatches to the appropriate kernel based on the scale_k/scale_n values.
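
For illustration, such a dispatching wrapper might look like the sketch below; the kernel entry-point names _quantize_3d_32x1 and _quantize_3d_32x32 are hypothetical, not the merged code:

    def mxfp8_quantize_3d_cutedsl(x, scale_block_n, scale_block_k):
        # Dispatch on the requested scaling granularity; only 32x1 and
        # 32x32 are currently supported (1x32 may come later).
        if (scale_block_n, scale_block_k) == (32, 1):
            return _quantize_3d_32x1(x)
        if (scale_block_n, scale_block_k) == (32, 32):
            return _quantize_3d_32x32(x)
        raise ValueError(
            f"Unsupported scaling granularity {scale_block_n}x{scale_block_k}"
        )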

)
else:
    padded_scale_cols = cute.round_up(n_blocks, 4)
    k_block_tiles = cute.ceil_div(K, 128)
danielvegamyhre (Contributor):

I just noticed: shouldn't we always be padding the scale tensor rows/cols such that it divides evenly into 128x4 tiles (or 4x128, depending on the scaling dim)? For both 32x1 scaling and 32x32 scaling?

This seems to assume that in the case of 32x1 scaling, K is divisible by 128, but do we actually validate that, or will it just silently fail/crash?

I would prefer we don't do if/else here and just unconditionally use padded rows/cols.
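
A minimal sketch of the unconditional padding being suggested, reusing the cute.round_up helper visible in the diff (the variable names here are assumptions):

    # Always pad so the scale tensor tiles evenly into 128x4 blocks,
    # for both 32x1 and 32x32 scaling; no if/else on the granularity.
    padded_scale_rows = cute.round_up(scale_rows, 128)
    padded_scale_cols = cute.round_up(scale_cols, 4)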

alexsamardzic (Collaborator, Author):

Fixed.

danielvegamyhre removed this from the MXFP8 Training milestone on Apr 9, 2026
alexsamardzic force-pushed the add-32x32-support branch 2 times, most recently from cd4c46f to 3cc1b4a on April 10, 2026 at 11:30
weight_e4m3, weight_scales_blocked = mxfp8_quantize_cuda_3d(
    weight._data if hasattr(weight, "_data") else weight,
    block_size,
    scale_block_n=block_size,
danielvegamyhre (Contributor):

Can you make this more explicit by also specifying scale_block_k=1, so it is clear at a glance that we are still using 32x1 scaling here, not 32x32 (yet)?

alexsamardzic (Collaborator, Author):

Fixed.

Comment on lines 793 to +794
    block_size: int = 32,
    scale_block_k: int = 1,
danielvegamyhre (Contributor) commented Apr 13, 2026:

Can we make everything consistent with the user-facing API, where scale_block_n and scale_block_k are both explicit, rather than N being implicit and K being explicit? I.e., in the kernel as well.

alexsamardzic (Collaborator, Author):

Sure, fixed - sorry about that. For the kernel, as discussed above, considerable changes would probably be needed for scale_block_n other than 32, but it still makes sense to have it explicit.

alexsamardzic force-pushed the add-32x32-support branch 2 times, most recently from bb90db2 to fb9bd5f on April 13, 2026 at 20:18
    K: int,
    scale_block_k: int,
) -> Tuple[str, Tuple[int, int, int, int]]:
    del scale_block_n
danielvegamyhre (Contributor):

One last nit: in this particular helper function, I think it is confusing to pass the K and scale_block_n args and then immediately delete them via del K and del scale_block_n. Let's only pass the args we need here.

The other places where you added explicit scale_block_n/scale_block_k look good (main kernel entry point, user-facing API/wrapper, etc.).
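
A sketch of the suggested cleanup; the helper name _select_config is hypothetical, and the return type follows the diff above:

    from typing import Tuple

    def _select_config(scale_block_k: int) -> Tuple[str, Tuple[int, int, int, int]]:
        # Only the argument that is actually used is passed in; the unused
        # K / scale_block_n parameters are dropped instead of del'd.
        ...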

alexsamardzic (Collaborator, Author):

Changed.

# (compute_warps, tile_n, tile_k, k_tiles_per_cta)
_CUTEDSL_CONFIGS = {
    "bf16_default": (6, 32, 128, 4),
    "bf16_1x32": (4, 32, 128, 4),
danielvegamyhre (Contributor):

Should this one be 32x1 (not 1x32), since it is quantizing shape (E, N, K) with 32x1 scaling along N?

alexsamardzic (Collaborator, Author):

That was a typo, fixed.

amax = self._warp_reduce_max(amax)
scale_biased, inv_scale = compute_scale_from_amax(amax, USE_RCEIL)
if cutlass.const_expr(SCALE_DIM_K_VALUE == 32):
    if lane == cutlass.Int32(0):
danielvegamyhre (Contributor):

Based on your store_scale_32x32 implementation here:

            if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
                n_row = n_block * cutlass.Int64(32) + cutlass.Int64(lane)
                scales_expert[n_row, k_block] = scale_u8

Don't we need all 32 lanes to store the warp-reduced scale factor, not just thread 0, in order to replicate it to all 32 rows of the SF tile?

alexsamardzic (Collaborator, Author):

Indeed, that was a bug. I actually encountered it when I tried the end-to-end test and fixed it, but instead of adding the fix to the commit, it stayed in the patch attached to my first comment. I've pushed this change now, and here is the proper patch for end-to-end testing: diff-end-to-end-test.txt. Shall we add this patch as an option for the given grouped MM test?

padded_scale_cols = cute.round_up(n_blocks, 4)
k_block_tiles = cute.ceil_div(K, 128)
n_block_tiles = padded_scale_cols // cutlass.Int64(4)
scale_rows = K if cutlass.const_expr(SCALE_DIM_K_VALUE == 1) else N
danielvegamyhre (Contributor):

I'm a bit confused: why does the number of rows in the scale factor tensor depend on the granularity of our scale factor size (32x32 vs 32x1)? Same question for scale_cols below. Is this right? If so, could you please add a brief comment explaining it?

alexsamardzic (Collaborator, Author):

I think the confusion comes from the fact that I tried to keep a single scale factor per 32x32 block as an option on the quantization side of things, to be ready in case we decide to try a grouped MM kernel that can handle this situation. This variant is under the blocked_scale_output = False condition. I've added some comments to clarify it, and I've also added an explicit test variant for this case. Does this explain it, or did you have something else in mind here?

danielvegamyhre (Contributor):

Would this complexity of swapping K/N for rows/cols be eliminated if we only supported blocked layout for 32x32?

alexsamardzic (Collaborator, Author):

Sorry, I think I answered a different question above. You were asking about the blocked_scale_output = True case, while I replied about the compact blocked_scale_output = False 32x32 representation. In the blocked path this is intentional: 32x1 and 32x32 map to different logical scale layouts before blocking, so scale_rows / scale_cols do depend on the granularity there.

danielvegamyhre (Contributor) commented Apr 22, 2026:

I see, so to summarize my understanding:

  • for the 32x1 blocked layout, the scale shape is (K, N//32), since the grouped GEMM expects this shape
  • for the 32x32 blocked layout, we conceptually have one scale per (N//32, K//32) block, then replicate it by having every lane write it to its associated row, giving us (N, K//32)

Is that correct? Also, have you tested the 32x32 blocked outputs with the torch._scaled_grouped_mm mxfp8 grouped GEMM?
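
To make the two blocked layouts concrete, a small illustrative shape check (the example dimensions are taken from the benchmark table above, for a single expert):

    # One expert of shape (N, K) = (8192, 5120), block size 32:
    N, K = 8192, 5120
    scale_shape_32x1 = (K, N // 32)   # (5120, 256): one scale per 32x1 block
    scale_shape_32x32 = (N, K // 32)  # (8192, 160): one scale per 32x32 block,
                                      # replicated across its 32 rows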

alexsamardzic (Collaborator, Author):

Yes, this is correct.

As for the testing: What is tested for 32x32 at the moment is the quantization itself in test_cuda_mx_3d_cutedsl_numerics. If you want, I can add a grouped MM test that consumes the blocked 32x32 scales.

danielvegamyhre (Contributor) commented Apr 22, 2026:

Yes, please add a test that confirms A @ B with 1x32 scaling for A and 32x32 for B works; that would be great.

alexsamardzic (Collaborator, Author):

Done.

    USE_RCEIL,
    BLOCKED_SCALE_OUTPUT_VALUE,
)
self._quantize_store_tail(
danielvegamyhre (Contributor):

Comparing the if cutlass.const_expr(IS_FULL_K_TILES) "true" path vs the "false" path, the code looks identical except:

  • in the 'false' path, there is an extra condition k_in_bounds (e.g. if k_rel < TILE_K and k_in_bounds: ...)
  • at the end we call self._quantize_store_full vs self._quantize_store_tail

Could we simplify this by:

  1. always doing the k_in_bounds check (unnecessary for full blocks, but I doubt one boolean check will hurt perf)
  2. executing the shared code (which would now be one unified path)
  3. doing the if/else at the end for full vs tail store?

If I missed something, please let me know.
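
A structural sketch of the proposed unification (the method and variable names follow the comment; the bodies are placeholders, not the merged code):

    # One unified path: the bounds check always runs, and only the
    # final store differs between full and tail tiles.
    if k_rel < TILE_K and k_in_bounds:
        ...  # shared quantize logic
    if cutlass.const_expr(IS_FULL_K_TILES):
        self._quantize_store_full(...)
    else:
        self._quantize_store_tail(...)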

alexsamardzic (Collaborator, Author):

Updated.

danielvegamyhre (Contributor):

Benchmarks still look the same after the change?

alexsamardzic (Collaborator, Author):

Yes, I benchmarked with and without this change; no noticeable difference.

alexsamardzic force-pushed the add-32x32-support branch 2 times, most recently from 5b89ea9 to e465aec on April 21, 2026 at 21:13
scale_biased, inv_scale = compute_scale_from_amax(amax, USE_RCEIL)
if cutlass.const_expr(SCALE_DIM_K_VALUE == 32):
    if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
        self._store_scale_32x32(
danielvegamyhre (Contributor) commented Apr 21, 2026:

Can you please add a comment to clarify that the blocked path uses all 32 lanes to replicate the warp-reduced scale, whereas the non-blocked (row-major) path uses a single lane to write a single, non-replicated scale factor for the 32x32 block?

Just to make this calling logic more immediately understandable, so the reader doesn't have to dive deeper into the code to understand the higher-level calling logic.
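
For illustration, the requested comment might read roughly as follows at the call site quoted above (placement and wording are assumptions):

    if cutlass.const_expr(BLOCKED_SCALE_OUTPUT):
        # Blocked layout: all 32 lanes write the warp-reduced scale,
        # replicating it across the 32 rows of the scale-factor tile.
        self._store_scale_32x32(...)
    else:
        # Row-major layout: a single lane writes one non-replicated
        # scale factor for the whole 32x32 block.
        self._store_scale(...)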

alexsamardzic (Collaborator, Author):

Done.

    BLOCKED_SCALE_OUTPUT,
)
else:
    self._store_scale(
danielvegamyhre (Contributor):

Nit, but can we call this _store_scale_32x1 now to be explicit, since this kernel now supports different scale granularities? It would be preferable for both 32x32 and 32x1 to be explicit, rather than one being implicit and the other explicit.

alexsamardzic (Collaborator, Author):

Updated.

@pytest.mark.parametrize(
    "scale_mode", (ScaleCalculationMode.FLOOR, ScaleCalculationMode.RCEIL)
)
def test_emulate_mxfp8_grouped_gemm_2d_3d(
danielvegamyhre (Contributor):

@alexsamardzic it seems like this test was changed from testing the emulated path to testing the non-emulated/real-SM100 path? Instead, we should keep the test for the emulated path and make a new test for this real 1x32 @ 32x32 case (make sure to add a skip if not on SM100).

alexsamardzic (Collaborator, Author):

Done. We now have test_emulate_mxfp8_grouped_gemm_2d_3d and test_mxfp8_grouped_gemm_2d_3d, both of which test the 32x1 and 32x32 variants that we have at the moment.

@danielvegamyhre danielvegamyhre merged commit a9f24af into main Apr 26, 2026
22 of 23 checks passed

Labels

CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) · module: training · quantize_ api · training flow · moe · mx

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Support 32x32 scaling for weights in MXFP8 weight quantization kernel

2 participants