Commit d26bbae
[moe training] fix fp8 grouped mm compile issue (#4233)
[moe training] register Float8TrainingOpConfig as pytree constant for torch.compile
Register Float8TrainingOpConfig as a pytree constant (matching
MXFP8TrainingOpConfig) so torch.compile can properly handle the config
stored in tensor subclass metadata via __tensor_flatten__.
- Add @register_as_pytree_constant decorator to Float8TrainingOpConfig
- Add __eq__ and __hash__ methods required for pytree constant registration
Related: #4048
Made-with: Cursor
Co-authored-by: DCCS-5881 <rissinha@chi-mi325x-pod2-103.ord.vultr.cpe.ice.amd.com>
1 parent 6e7a6e9 commit d26bbae
1 file changed: +20 -0
Hunk 1 (old lines 44-49, new lines 44-50): 1 line added at new line 47.
Hunk 2 (old lines 74-79, new lines 75-99): 19 lines added at new lines 78-96.