
Add pre-quantized activation support to MXFP8 grouped GEMM (_to_mxfp8_then_scaled_grouped_mm) #3961

Merged
danielvegamyhre merged 3 commits into pytorch:main from
MagellaX:feat/3379-prequantized-a-mxfp8-grouped-mm
Mar 13, 2026

Conversation

@MagellaX
Contributor

Summary

This PR adds pre-quantized activation support to MXFP8 grouped GEMM while preserving existing behavior for dynamic quantization.

_to_mxfp8_then_scaled_grouped_mm now supports:

  • A as an MXTensor (pre-quantized path), or
  • prequantized_A=(qdata, scale) as an optional input when HP A is still provided.
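A minimal usage sketch of the two forms above (hypothetical; the import path, the B_t/offs names, and the exact signature are assumptions, and the tuple form was later dropped during review in favor of MXTensor-only input, see the discussion below):

# Hypothetical usage sketch, not verbatim torchao code; the import path and the
# B_t/offs naming are assumptions for illustration.
import torch

from torchao.prototype.moe_training.scaled_grouped_mm import (  # assumed path
    _to_mxfp8_then_scaled_grouped_mm,
)


def grouped_mm_with_mx_activation(a_mx, B_t, offs):
    # Pre-quantized path: A is already an MXTensor.
    return _to_mxfp8_then_scaled_grouped_mm(a_mx, B_t, offs, out_dtype=torch.bfloat16)


def grouped_mm_with_qdata_scale(A_hp, a_qdata, a_scale, B_t, offs):
    # Tuple path proposed in this PR: high-precision A plus raw (qdata, scale).
    return _to_mxfp8_then_scaled_grouped_mm(
        A_hp, B_t, offs,
        out_dtype=torch.bfloat16,
        prequantized_A=(a_qdata, a_scale),
    )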

What changed

  • Replaced the alias with a wrapper function for _to_mxfp8_then_scaled_grouped_mm and added:
    • prequantized_A: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
  • Added input normalization/validation for pre-quantized A (a sketch follows this change list):
    • block_size == 32
    • qdata.dtype == torch.float8_e4m3fn
    • scale.dtype in {torch.float8_e8m0fnu, torch.uint8} (uint8 viewed as float8_e8m0fnu)
    • 2D row-wise scale shape (M, K//block_size)
    • MXTensor requires unswizzled scales
  • Extended _MXFP8GroupedMM.forward:
    • uses provided pre-quantized A_data/A_scale when available
    • otherwise keeps existing dynamic quantization behavior
  • Extended _MXFP8GroupedMM.backward:
    • handles A as MXTensor by dequantizing when needed
  • Added robust grouped-mm emulation fallback for environments where grouped/scaled grouped kernels are unavailable.
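
A self-contained sketch of the normalization/validation rules listed above (illustrative only; the helper name is hypothetical, not the merged code):

import torch


def _check_prequantized_a(qdata: torch.Tensor, scale: torch.Tensor, block_size: int = 32):
    # Sketch of the checks described above.
    assert block_size == 32, "only block_size=32 is supported"
    assert qdata.dtype == torch.float8_e4m3fn, f"expected float8_e4m3fn qdata, got {qdata.dtype}"
    assert scale.dtype in (torch.float8_e8m0fnu, torch.uint8), f"unexpected scale dtype {scale.dtype}"
    if scale.dtype == torch.uint8:
        # uint8 scales are reinterpreted as e8m0, as described above
        scale = scale.view(torch.float8_e8m0fnu)
    M, K = qdata.shape
    assert scale.shape == (M, K // block_size), (
        "expected unswizzled 2D row-wise scales of shape (M, K // block_size)"
    )
    return qdata, scale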

Tests added

  • test_mxfp8_grouped_gemm_prequantized_tuple_matches_dynamic
  • test_mxfp8_grouped_gemm_mxtensor_activation_forward

Validation

Ran:

  • python -m pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k "prequantized or mxtensor_activation" -q

Results:

  • H100: 3 passed, 50 deselected
  • B200: 3 passed, 50 deselected

@pytorch-bot

pytorch-bot Bot commented Feb 27, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3961

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (2 Unrelated Failures)

As of commit e995007 with merge base 0c13a56:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Feb 27, 2026
@danielvegamyhre danielvegamyhre self-requested a review February 27, 2026 17:02
@danielvegamyhre
Contributor

thanks for the contribution @MagellaX, can you provide some context on the use case / motivation for this change?

I have considered changes like this to make composability with torch.compile easier. Is that the reason?

@MagellaX
Contributor Author

thanks for the contribution @MagellaX, can you provide some context on the use case / motivation for this change?

I have considered changes like this to make composability with torch.compile easier. Is that the reason?

Yes, torch.compile composability is one benefit, but the primary motivation is avoiding redundant activation quantization in MoE pipelines.

In the EP/MoE flow, activations are often already in MXFP8 before grouped GEMM (or only quantized A is available). Previously, _to_mxfp8_then_scaled_grouped_mm always re-quantized A internally. This adds extra quant/dequant work, scale/layout conversion, and memory traffic.

This change adds support for passing pre-quantized A (MXTensor or (qdata, scale)), so grouped GEMM can consume it directly.
If pre-quantized A is not provided, behavior is unchanged (existing dynamic path still works).

@danielvegamyhre
Contributor

danielvegamyhre commented Feb 27, 2026

Previously, _to_mxfp8_then_scaled_grouped_mm always re-quantized A internally. This adds extra quant/dequant work, scale/layout conversion, and memory traffic.

thanks for the details - to clarify though, this is not the case actually: in _to_mxfp8_then_scaled_grouped_mm the "A" tensor can already optionally be an MXTensor subclass (pre-quantized). we currently have mxfp8 expert parallel building blocks that pre-quantize before the all2all, and stay in mxfp8 through the comms -> token permutation -> mxfp8 grouped mm - so no redundant quantization. You can read more about it here: https://docs.pytorch.org/ao/main/eager_tutorials/mxfp8_expert_parallel_training.html#mxfp8-expert-parallel-apis

Passing prequantized inputs via separate qdata/scales tensors could be useful though, for users who are using _to_mxfp8_then_scaled_grouped_mm independently (i.e., not using torchtitan or our mxfp8 EP building blocks)

"out_dtype must be bfloat16 or float32"
)
if isinstance(input_act, MXTensor):
assert wgrad_with_hp, (
Contributor

@danielvegamyhre danielvegamyhre Feb 28, 2026


i don't think this assertion should be removed - if the inputs arrive pre-quantized along dim0, we don't have the bf16 inputs to save for backward, which we need in order to properly quantize them along dim1 for dW = dO.t() @ X.

to support this we would need a fast fused "dequantize along dim0 -> requantize along dim1" kernel (i'd happily review a PR for this!)

@MagellaX
Contributor Author

Thanks for the clarification. I revisited the docs, and yeah, it makes sense that the existing EP path already supports A as an MXTensor....
The narrower motivation for this PR is to support passing prequantized activations as separate (qdata, scale) tensors for users invoking _to_mxfp8_then_scaled_grouped_mm independently of the EP building blocks, while still keeping high-precision A available for backward.

and yeah, about the backward contract, I re-checked the backward path, and the current PR should not relax that assertion for MXTensor inputs. For the non-wgrad_with_hp recipe we still need the original high-precision activations to quantize along dim1 for dW, and a dim0-prequantized MXTensor alone is not sufficient for that contract.

I updated the PR to restore the assertion for MXTensor inputs, while keeping the new tuple-based prequantized_A=(qdata, scale) support for the case where high-precision A is still available for backward. I also added a regression test for the MXTensor -> requires wgrad_with_hp behavior.

@danielvegamyhre
Contributor

@MagellaX what do you think the benefits are of having the api accept pre-quantized inputs via either an MXTensor or a (qdata, scales) tuple, rather than requiring the caller to wrap their qdata/scale in an MXTensor? I think it is more convenient / less friction, but it does make the api a bit messy to have multiple ways to pass the same thing.

@MagellaX
Contributor Author

@MagellaX what do you think the benefits are of having the api accept pre-quantized inputs via either an MXTensor or a (qdata, scales) tuple, rather than requiring the caller to wrap their qdata/scale in an MXTensor? I think it is more convenient / less friction, but it does make the api a bit messy to have multiple ways to pass the same thing.

yeah, I think the main benefit of the tuple path is just lower friction for callers that already have raw qdata/scales and are not otherwise working with the MXTensor API.

for the existing EP path, MXTensor is definitely the cleaner representation. the reason I added the tuple form is that wrapping raw qdata/scale into an MXTensor is not super lightweight/public today since the caller also needs to know the extra MX metadata/contract, so the tuple path felt like a practical interop escape hatch.

that said, I agree having 2 ways to represent the same thing does make the API a bit messier. if you’d prefer to keep the surface area tighter, I’m happy to simplify this and either:

  1. drop the tuple path and keep this MXTensor-only, or
  2. do it via a small helper / constructor for building an MXTensor from (qdata, scale) cleanly.

my intent here was convenience/interoperability, not to make MXTensor less central.

@danielvegamyhre danielvegamyhre added the topic: improvement label Mar 2, 2026
@danielvegamyhre
Contributor

danielvegamyhre commented Mar 2, 2026

After considering this further, I think rather than having two ways to pass pre-quantized inputs, which pollutes the API surface and is a bit confusing, I would prefer to require the caller wrap the qdata and scales in an MXTensor, and that be the blessed path. However, we should add a user-facing helper for this, as the MXTensor constructor itself has several fields, some of which may be unclear how to set for the average user. Do you want to take a stab at this?

Maybe something like

class MXTensor:

    @staticmethod
    def from_qdata_and_scales(
        qdata: torch.Tensor,
        scales: torch.Tensor,
        orig_dtype: torch.dtype,
    ) -> "MXTensor":
        return MXTensor(
            qdata,
            scales,
            block_size=32,
            elem_dtype=qdata.dtype,
            orig_dtype=orig_dtype,
            act_quant_kwargs=None,
            is_swizzled_scales=False,
        )

Alternatively it could be a standalone helper function
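
If a builder along these lines lands, wrapping raw tensors would look roughly like this (a sketch assuming the from_qdata_and_scales helper proposed above; not merged code at this point in the thread):

import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor


def wrap_prequantized_activation(qdata: torch.Tensor, scales: torch.Tensor) -> MXTensor:
    # qdata: float8_e4m3fn data; scales: e8m0 scales in the unswizzled (M, K//32) layout.
    return MXTensor.from_qdata_and_scales(qdata, scales, orig_dtype=torch.bfloat16)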

@danielvegamyhre
Contributor

2. do it via a small helper / constructor for building an MXTensor from (qdata, scale) cleanly.

@MagellaX let's go with option 2, that sounds good to me

@MagellaX
Contributor Author

MagellaX commented Mar 3, 2026

I dropped the separate tuple-based path and updated the PR to keep MXTensor as the only prequantized input representation.

I added a small MXTensor.from_qdata_and_scales(...) helper so callers who already have raw (qdata, scale) don’t need to know the full constructor contract, and updated the grouped-mm tests/docs to use that path.

I also kept the MXTensor -> requires wgrad_with_hp constraint in place.
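
An end-to-end sketch of the resulting flow (illustrative; the grouped-mm import path, the B_t/offs names, and the exact landed helper signature are assumptions):

import torch

from torchao.prototype.moe_training.scaled_grouped_mm import (  # assumed path
    _to_mxfp8_then_scaled_grouped_mm,
)
from torchao.prototype.mx_formats.mx_tensor import MXTensor


def grouped_mm_from_qdata_and_scales(a_qdata, a_scales, B_t, offs):
    # Wrap raw (qdata, scales) via the new helper, then run the grouped GEMM.
    # With MXTensor activations the wgrad_with_hp recipe is required, per the
    # constraint kept in place above.
    a_mx = MXTensor.from_qdata_and_scales(a_qdata, a_scales, orig_dtype=torch.bfloat16)
    return _to_mxfp8_then_scaled_grouped_mm(a_mx, B_t, offs, out_dtype=torch.bfloat16)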

act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
is_swizzled_scales: bool = False,
):
if elem_dtype is None:
Contributor


just delete this variable?

Contributor


@MagellaX can you address this comment please, the elem_dtype will always be qdata.dtype, we don't need this argument

assert scales.dtype in (torch.float8_e8m0fnu, torch.uint8), (
f"Expected scales.dtype to be torch.float8_e8m0fnu or torch.uint8, got {scales.dtype}"
)
if scales.dtype == torch.uint8:
Contributor


delete this and enforce the callsite is passing the right dtype

if scales.dtype == torch.uint8:
scales = scales.view(torch.float8_e8m0fnu)

if not is_swizzled_scales:
Contributor


IMO move this to the MXTensor constructor if it isn't already there, same with all of the other asserts

kernel_preference: Optional[KernelPreference] = None,
act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
is_swizzled_scales: bool = False,
):
Contributor


add return type MXTensor | DTensor


@staticmethod
@torch._dynamo.allow_in_graph
def from_qdata_and_scales(
Contributor


is the only reason for this to exist to handle dtensor? if yes, can we put dtensor in the function name?

Contributor


it is a convenience function/builder for pre-quantized inputs, because the MXTensor constructor has several required args that many users may be unsure how to set, so this builder just lets them create the MXTensor from their qdata, scales, and original dtype, while setting defaults for the rest.

The goal is just to reduce friction for calling _to_mxfp8_then_scaled_grouped_mm with prequantized inputs, as they are often quantized prior to the all2all

)


def _build_prequantized_a_mxtensor(
Contributor

@danielvegamyhre danielvegamyhre Mar 4, 2026


nit: can we remove this helper and just use to_mx + MXTensor.from_qdata_and_scales directly? this step of indirection makes it a bit confusing i think, and the 2 explicit calls are not much code and are much clearer imo

Contributor Author


yea..makes sense

@MagellaX MagellaX requested a review from danielvegamyhre March 4, 2026 18:01
generate_jagged_offs,
)
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
Contributor


can you rebase? this test file doesn't exist anymore, we have separated the float8 vs mxfp8 tests

@MagellaX MagellaX force-pushed the feat/3379-prequantized-a-mxfp8-grouped-mm branch from 2b2219d to a8cb745 March 4, 2026 18:38
@danielvegamyhre
Contributor

@MagellaX linter is failing, can you run ruff check --fix <dirs> and ruff format <dirs>

@danielvegamyhre danielvegamyhre added the module: training label Mar 5, 2026
@MagellaX MagellaX requested a review from danielvegamyhre March 5, 2026 06:44
@MagellaX
Contributor Author

MagellaX commented Mar 8, 2026

i think u can merge it @danielvegamyhre

@danielvegamyhre
Contributor

@MagellaX sorry for the delay, got side tracked by some urgent bugs that needed attention. would you mind rebasing to resolve the merge conflict, then we can land? thanks!

@MagellaX MagellaX force-pushed the feat/3379-prequantized-a-mxfp8-grouped-mm branch from 439fdfa to 2d70b30 March 12, 2026 07:35
@MagellaX
Contributor Author

I think it LGTM

@danielvegamyhre danielvegamyhre added this to the MXFP8 Training milestone Mar 12, 2026
@danielvegamyhre danielvegamyhre merged commit e654d74 into pytorch:main Mar 13, 2026
17 of 19 checks passed
@danielvegamyhre
Contributor

landed, thanks @MagellaX!
