Skip to content

Commit d26bbae

Browse files
rishisinha and DCCS-5881
authored
[moe training] fix fp8 grouped mm compile issue (#4233)
[moe training] register Float8TrainingOpConfig as pytree constant for torch.compile

Register Float8TrainingOpConfig as a pytree constant (matching MXFP8TrainingOpConfig) so torch.compile can properly handle the config stored in tensor subclass metadata via __tensor_flatten__.

- Add @register_as_pytree_constant decorator to Float8TrainingOpConfig
- Add __eq__ and __hash__ methods required for pytree constant registration

Related: #4048
Made-with: Cursor
Co-authored-by: DCCS-5881 <rissinha@chi-mi325x-pod2-103.ord.vultr.cpe.ice.amd.com>
1 parent 6e7a6e9 commit d26bbae

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

torchao/prototype/moe_training/config.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,7 @@ class TrainingOpBaseConfig(AOBaseConfig):
4444
pass
4545

4646

47+
@register_as_pytree_constant
4748
@dataclass
4849
class Float8TrainingOpConfig(TrainingOpBaseConfig):
4950
"""
@@ -74,6 +75,25 @@ def from_recipe(
7475
else:
7576
raise ValueError(f"Unsupported FP8 recipe: {recipe}")
7677

78+
def __eq__(self, other):
79+
if isinstance(other, Float8TrainingOpConfig):
80+
return (
81+
self.float8_dtype == other.float8_dtype
82+
and self.out_dtype == other.out_dtype
83+
and self.pad_token_groups_for_grouped_mm
84+
== other.pad_token_groups_for_grouped_mm
85+
)
86+
return NotImplemented
87+
88+
def __hash__(self):
89+
return hash(
90+
(
91+
self.float8_dtype,
92+
self.out_dtype,
93+
self.pad_token_groups_for_grouped_mm,
94+
)
95+
)
96+
7797

7898
# register as pytree constant so we can use dynamo nonstrict trace in torchao.prototype.moe_training.ep
7999
@register_as_pytree_constant

0 commit comments

Comments (0)