Skip to content

Commit 7dd8be5

Browse files
Create recipe for flux2pro running on AMD (#4200)
Differential Revision: D98537991 Pull Request resolved: #4200
1 parent da257b5 commit 7dd8be5

3 files changed

Lines changed: 118 additions & 2 deletions

File tree

test/quantization/quantize_/workflows/float8/test_float8_tensor.py

Lines changed: 114 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
is_sm_at_least_100,
3939
torch_version_at_least,
4040
)
41+
from unittest.mock import patch
4142

4243
# Needed since changing args to function causes recompiles
4344
torch._dynamo.config.cache_size_limit = 128
@@ -1489,5 +1490,118 @@ def test_create_tensor_out_of_inference_mode(self):
14891490

14901491
common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
14911492

1493+
1494+
class TestMI350HardwareSupport(common_utils.TestCase):
    """Tests that MI350 (gfx950) is accepted by FP8 hardware checks.

    Uses mocking so the tests run on any hardware without needing an actual
    MI350 GPU.
    """

    def _hardware_patches(self, *, mi350):
        """Build the patchers for a simulated CUDA-only environment.

        Args:
            mi350: when True, simulate a machine whose only FP8-capable
                accelerator is MI350; when False, simulate a CUDA machine
                with no supported FP8 accelerator at all.

        Returns:
            A list of *unstarted* ``unittest.mock.patch`` objects (not a
            context manager); activate them with ``_start_patches``.
        """
        return [
            patch("torchao.float8.inference.is_MI350", return_value=mi350),
            patch("torchao.float8.inference.is_MI300", return_value=False),
            patch("torchao.float8.inference.is_sm_at_least_89", return_value=False),
            patch("torch.cuda.is_available", return_value=True),
            patch("torch.xpu.is_available", return_value=False),
        ]

    def _start_patches(self, patches):
        """Start each patcher, registering teardown via ``addCleanup`` so the
        patches are undone even if the test body (or a later ``start``) raises."""
        for p in patches:
            p.start()
            self.addCleanup(p.stop)

    def test_check_hardware_support_mi350_per_tensor(self):
        from torchao.float8.inference import _check_hardware_support

        self._start_patches(self._hardware_patches(mi350=True))
        # Must not raise: MI350 satisfies the FP8 hardware requirement.
        _check_hardware_support((PerTensor(), PerTensor()))

    def test_check_hardware_support_mi350_per_row(self):
        from torchao.float8.inference import _check_hardware_support

        self._start_patches(self._hardware_patches(mi350=True))
        # Must not raise: per-row FP8 is also gated by the same hardware check.
        _check_hardware_support((PerRow(), PerRow()))

    def test_check_hardware_support_rejects_unsupported_hw(self):
        from torchao.float8.inference import _check_hardware_support

        self._start_patches(self._hardware_patches(mi350=False))
        with self.assertRaises(AssertionError):
            _check_hardware_support((PerRow(), PerRow()))

    def test_quant_api_hardware_gate_mi350(self):
        """The assertion in _float8_dynamic_activation_float8_weight_transform
        should pass on MI350.

        NOTE(review): this patches the ``quant_api`` module attributes and
        re-evaluates the same boolean gate the transform uses, so it verifies
        the gate expression rather than the transform end to end.
        """
        with (
            patch("torchao.quantization.quant_api.is_MI350", return_value=True),
            patch("torchao.quantization.quant_api.is_MI300", return_value=False),
            patch(
                "torchao.quantization.quant_api.is_sm_at_least_89",
                return_value=False,
            ),
            patch("torch.cuda.is_available", return_value=True),
        ):
            from torchao.quantization.quant_api import (
                is_MI300,
                is_MI350,
                is_sm_at_least_89,
            )

            self.assertTrue(is_sm_at_least_89() or is_MI300() or is_MI350())

    def test_quant_api_hardware_gate_rejects_unsupported(self):
        """With no supported accelerator, the quant_api hardware gate is False."""
        with (
            patch("torchao.quantization.quant_api.is_MI350", return_value=False),
            patch("torchao.quantization.quant_api.is_MI300", return_value=False),
            patch(
                "torchao.quantization.quant_api.is_sm_at_least_89",
                return_value=False,
            ),
            patch("torch.cuda.is_available", return_value=True),
        ):
            from torchao.quantization.quant_api import (
                is_MI300,
                is_MI350,
                is_sm_at_least_89,
            )

            self.assertFalse(is_sm_at_least_89() or is_MI300() or is_MI350())
14921606
# Allow running this test file directly (outside the pytest/buck runners).
if __name__ == "__main__":
    run_tests()

torchao/float8/inference.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616
from torchao.float8.types import FP8Granularity
1717
from torchao.utils import (
1818
is_MI300,
19+
is_MI350,
1920
is_sm_at_least_89,
2021
)
2122

@@ -295,7 +296,7 @@ def _check_hardware_support(
295296

296297
if is_per_tensor or is_per_row:
297298
assert torch.xpu.is_available() or (
298-
torch.cuda.is_available() and is_sm_at_least_89() or is_MI300()
299+
torch.cuda.is_available() and is_sm_at_least_89() or is_MI300() or is_MI350()
299300
), (
300301
"Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+ or XPU."
301302
)

torchao/quantization/quant_api.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -91,6 +91,7 @@
9191
)
9292
from torchao.utils import (
9393
is_MI300,
94+
is_MI350,
9495
is_sm_at_least_89,
9596
)
9697

@@ -1527,7 +1528,7 @@ def _float8_dynamic_activation_float8_weight_transform(
15271528
parameter_name: str = "weight",
15281529
):
15291530
if torch.cuda.is_available():
1530-
assert is_sm_at_least_89() or is_MI300(), (
1531+
assert is_sm_at_least_89() or is_MI300() or is_MI350(), (
15311532
"Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+"
15321533
)
15331534
if config.set_inductor_config:

0 commit comments

Comments
 (0)