Add pre-quantized activation support to MXFP8 grouped GEMM (_to_mxfp8_then_scaled_grouped_mm) #3961
Conversation
thanks for the contribution @MagellaX, can you provide some context on the use case / motivation for this change? i have considered changes like this to make composability with torch.compile easier. Is that the reason?
Yes, torch.compile composability is one benefit, but the primary motivation is avoiding redundant activation quantization in MoE pipelines. In the EP/MoE flow, activations are often already in MXFP8 before the grouped GEMM (or only quantized A is available). Previously, _to_mxfp8_then_scaled_grouped_mm always re-quantized A internally, which adds extra quant/dequant work, scale/layout conversion, and memory traffic. This change adds support for passing pre-quantized A (MXTensor or (qdata, scale)), so the grouped GEMM can consume it directly.
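To make the redundancy concrete, a toy sketch (quantize_to_mxfp8_toy below is a placeholder with amax scales, not the real e8m0 MX quantization or any torchao API):

import torch

def quantize_to_mxfp8_toy(a_hp: torch.Tensor, block_size: int = 32):
    """Toy per-block quantization along the last dim (amax scales, not real e8m0 MX scales)."""
    m, k = a_hp.shape
    blocks = a_hp.reshape(m, k // block_size, block_size)
    scales = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    qdata = (blocks / scales).to(torch.float8_e4m3fn).reshape(m, k)
    return qdata, scales.squeeze(-1)

a_hp = torch.randn(128, 256, dtype=torch.bfloat16)

# EP/MoE pipelines often already hold this pair before the grouped GEMM...
qdata, scales = quantize_to_mxfp8_toy(a_hp)

# ...but an entry point that only accepts high-precision A forces a second
# quantization of A inside the grouped GEMM, which is the extra quant work
# and memory traffic this change avoids by consuming (qdata, scales) directly.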
thanks for the details - to clarify though, this is not the case actually: in the existing EP path, A can already arrive pre-quantized as an MXTensor. Passing prequantized inputs via separate qdata/scales tensors could be useful though, for users who are working with raw qdata/scale tensors rather than MXTensor.
| "out_dtype must be bfloat16 or float32" | ||
| ) | ||
| if isinstance(input_act, MXTensor): | ||
| assert wgrad_with_hp, ( |
i don't think this assertion should be removed - if the inputs arrive pre-quantized along dim0, we don't have the bf16 inputs to save for backward, which we need in order to properly quantize them along dim1 for dW = dO.t() @ X.
to support this we would need a fast fused "dequantize along dim0 -> requantize along dim1" kernel (i'd happily review a PR for this!)
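to make the dim0/dim1 point concrete with plain torch (toy amax scales only, not real e8m0 MX scales) - blocks grouped along one dim and blocks grouped along the other cover different elements, so one set of block scales cannot be derived from the other without the original (or dequantized) high-precision values:

import torch

x = torch.randn(64, 128)
block = 32

# per-block amax over blocks along the last dim (the grouping the forward quantization produced)
scales_along_last = x.reshape(64, 128 // block, block).abs().amax(dim=-1)      # shape (64, 4)

# per-block amax over blocks along the first dim (the grouping dW = dO.t() @ X needs)
scales_along_first = x.t().reshape(128, 64 // block, block).abs().amax(dim=-1)  # shape (128, 2)

# entirely different groupings -> you need the hp values (or a dequantize ->
# requantize step) to move from one to the other
print(scales_along_last.shape, scales_along_first.shape)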
Thanks for the clarification. I revisited the docs, and it makes sense that the existing EP path already supports A as an MXTensor. On the backward contract: I re-checked the backward path, and the current PR should not relax that assertion for MXTensor inputs. For the non-wgrad_with_hp recipe we still need the original high-precision activations to quantize along dim1 for dW, and a dim0-prequantized MXTensor alone is not sufficient for that contract. I updated the PR to restore the assertion for MXTensor inputs, while keeping the new tuple-based prequantized_A=(qdata, scale) support for the case where high-precision A is still available for backward. I also added a regression test for the MXTensor -> requires wgrad_with_hp behavior.
@MagellaX what do you think the benefits are of having the api accept pre-quantized inputs via either an MXTensor or a (qdata, scales) tuple, rather than requiring the caller to wrap their qdata/scale in an MXTensor? I think it is more convenient / less friction, but it does make the api a bit messy to have multiple ways to pass the same thing...
yeah, I think the main benefit of the tuple path is just lower friction for callers that already have raw qdata/scale tensors in the existing EP path. that said, I agree having 2 ways to represent the same thing does make the API a bit messier. if you'd prefer to keep the surface area tighter, I'm happy to simplify this down to a single representation; my intent here was convenience/interoperability, not to make the API messier.
After considering this further, I think rather than having two ways to pass pre-quantized inputs, which pollutes the API surface and is a bit confusing, I would prefer to require the caller to wrap the qdata and scales in an MXTensor, and have that be the blessed path. However, we should add a user-facing helper for this, as the MXTensor constructor itself has several fields, some of which may be unclear how to set for the average user. Do you want to take a stab at this? Maybe something like:

class MXTensor:
    @classmethod
    def from_qdata_and_scales(
        cls,
        qdata: Tensor,
        scales: Tensor,
        orig_dtype: torch.dtype,
    ) -> "MXTensor":
        return cls(
            qdata,
            scales,
            block_size=32,
            elem_dtype=qdata.dtype,
            orig_dtype=orig_dtype,
            act_quant_kwargs=None,
            is_swizzled_scales=False,
        )

Alternatively it could be a standalone helper function.
@MagellaX let's go with option 2, that sounds good to me
I dropped the separate tuple-based path and updated the PR to keep MXTensor as the only prequantized input representation. I added a small MXTensor.from_qdata_and_scales(...) helper so callers who already have raw (qdata, scale) don't need to know the full constructor contract, and updated the grouped-mm tests/docs to use that path. I also kept the MXTensor -> requires wgrad_with_hp constraint in place.
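For reference, a rough usage sketch of the final path (this assumes the from_qdata_and_scales signature proposed above; the grouped-mm call itself is only described in the trailing comment since its full signature isn't shown in this thread):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

M, K, block_size = 128, 256, 32

# stand-ins for data that arrived pre-quantized (e.g. from before the all2all)
qdata = torch.zeros(M, K).to(torch.float8_e4m3fn)
scales = torch.zeros(M, K // block_size, dtype=torch.uint8).view(torch.float8_e8m0fnu)

a_mx = MXTensor.from_qdata_and_scales(qdata, scales, orig_dtype=torch.bfloat16)

# a_mx is then passed as A to _to_mxfp8_then_scaled_grouped_mm together with the
# expert weights and group offsets, skipping the internal quantization of A.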
    act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
    is_swizzled_scales: bool = False,
):
    if elem_dtype is None:
@MagellaX can you address this comment please, the elem_dtype will always be qdata.dtype, we don't need this argument
assert scales.dtype in (torch.float8_e8m0fnu, torch.uint8), (
    f"Expected scales.dtype to be torch.float8_e8m0fnu or torch.uint8, got {scales.dtype}"
)
if scales.dtype == torch.uint8:
delete this and enforce the callsite is passing the right dtype
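e.g. the callsite can do the reinterpret-view once (small sketch; uint8 and float8_e8m0fnu are both 1-byte types, so the view is zero-copy):

import torch

raw_scales = torch.randint(0, 256, (128, 8), dtype=torch.uint8)  # raw e8m0 exponent bytes
scales = raw_scales.view(torch.float8_e8m0fnu)  # reinterpret in place, no copy
assert scales.dtype == torch.float8_e8m0fnu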
if scales.dtype == torch.uint8:
    scales = scales.view(torch.float8_e8m0fnu)

if not is_swizzled_scales:
IMO move this to the MXTensor constructor if it isn't already there, same with all of the other asserts
    kernel_preference: Optional[KernelPreference] = None,
    act_quant_kwargs: Optional[QuantizeTensorToMXKwargs] = None,
    is_swizzled_scales: bool = False,
):
add return type MXTensor | DTensor
@staticmethod
@torch._dynamo.allow_in_graph
def from_qdata_and_scales(
is the only reason for this to exist to handle DTensor? if yes, can we put dtensor in the function name?
it is a convenience function/builder for pre-quantized inputs, because the MXTensor constructor has several required args that many users may be unsure how to set, so this builder just lets them create the MXTensor from their qdata, scales, and original dtype, while setting defaults for the rest.
The goal is just to reduce friction for calling _to_mxfp8_then_scaled_grouped_mm with prequantized inputs, as they are often quantized prior to the all2all.
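Roughly what the dispatch side looks like - an illustrative sketch only (equal splits, process-group setup and output split sizes omitted; fp8 tensors are viewed as uint8 for the collective since the backend may not handle fp8 dtypes directly):

import torch
import torch.distributed as dist
from torchao.prototype.mx_formats.mx_tensor import MXTensor

def dispatch_and_wrap(qdata: torch.Tensor, scales: torch.Tensor) -> MXTensor:
    # receive buffers for the dispatched tokens
    recv_q = torch.empty_like(qdata)
    recv_s = torch.empty_like(scales)
    # only the raw qdata/scales travel through the all2all
    dist.all_to_all_single(recv_q.view(torch.uint8), qdata.view(torch.uint8))
    dist.all_to_all_single(recv_s.view(torch.uint8), scales.view(torch.uint8))
    # no high-precision copy of A exists on the receiving side; re-wrap for the grouped mm
    return MXTensor.from_qdata_and_scales(recv_q, recv_s, torch.bfloat16)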
)


def _build_prequantized_a_mxtensor(
nit: can we remove this helper and just use to_mx + MXTensor.from_qdata_and_scales directly? this step of indirection makes it a bit confusing i think, and the 2 explicit calls are not much code and are much clearer imo
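i.e. something like this (the to_mx argument names and the (scale, qdata) return order here are an assumption based on the prototype mx_tensor module, worth double-checking against the source):

import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx

a_hp = torch.randn(128, 256, dtype=torch.bfloat16)
scales, qdata = to_mx(a_hp, torch.float8_e4m3fn, block_size=32)  # assumed return order
a_mx = MXTensor.from_qdata_and_scales(qdata, scales, orig_dtype=a_hp.dtype)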
     generate_jagged_offs,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.mx_tensor import MXTensor, to_mx
can you rebase? this test file doesn't exist anymore, we have separated the float8 vs mxfp8 tests
@MagellaX linter is failing, can you run ...?
i think u can merge it @danielvegamyhre
@MagellaX sorry for the delay, got sidetracked by some urgent bugs that needed attention. would you mind rebasing to resolve the merge conflict, then we can land? thanks!
I think it LGTM
landed, thanks @MagellaX!
Summary
This PR adds pre-quantized activation support to MXFP8 grouped GEMM while preserving existing behavior for dynamic quantization.
_to_mxfp8_then_scaled_grouped_mm now supports: A as an MXTensor (pre-quantized path), or prequantized_A=(qdata, scale) as an optional input when HP A is still provided.

What changed
- Extended _to_mxfp8_then_scaled_grouped_mm and added: prequantized_A: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
- Validation for pre-quantized A (sketched below): block_size == 32, qdata.dtype == torch.float8_e4m3fn, scale.dtype in {torch.float8_e8m0fnu, torch.uint8} (uint8 viewed as float8_e8m0fnu), scale shape (M, K//block_size), MXTensor requires unswizzled scales
- _MXFP8GroupedMM.forward: uses prequantized A_data / A_scale when available
- _MXFP8GroupedMM.backward: handles A as MXTensor by dequantizing when needed
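The pre-quantized A validation is roughly equivalent to the following illustrative asserts (names are placeholders; the real checks live in the grouped-mm wrapper / MXTensor constructor):

import torch

def _check_prequantized_a(qdata, scales, block_size=32):
    M, K = qdata.shape
    assert block_size == 32
    assert qdata.dtype == torch.float8_e4m3fn
    assert scales.dtype in (torch.float8_e8m0fnu, torch.uint8)  # uint8 viewed as e8m0
    assert scales.shape == (M, K // block_size)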
Tests added

- test_mxfp8_grouped_gemm_prequantized_tuple_matches_dynamic
- test_mxfp8_grouped_gemm_mxtensor_activation_forward

Validation
Ran:
python -m pytest test/prototype/moe_training/test_scaled_grouped_mm.py -k "prequantized or mxtensor_activation" -q

Results:
3 passed, 50 deselected