|
38 | 38 | is_sm_at_least_100, |
39 | 39 | torch_version_at_least, |
40 | 40 | ) |
| 41 | +from unittest.mock import patch |
41 | 42 |
|
42 | 43 | # Needed since changing args to function causes recompiles |
43 | 44 | torch._dynamo.config.cache_size_limit = 128 |
@@ -1489,5 +1490,118 @@ def test_create_tensor_out_of_inference_mode(self): |
1489 | 1490 |
|
1490 | 1491 | common_utils.instantiate_parametrized_tests(TestFloat8Tensor) |
1491 | 1492 |
|
| 1493 | + |
class TestMI350HardwareSupport(common_utils.TestCase):
    """Tests that MI350 (gfx950) is accepted by FP8 hardware checks.

    Uses mocking so the tests run on any hardware without needing an actual
    MI350 GPU. All patchers are entered through an ``ExitStack`` so that they
    are always undone, even when one of them fails to start.
    """

    # Modules whose hardware-detection names are replaced by mocks.
    _INFERENCE = "torchao.float8.inference"
    _QUANT_API = "torchao.quantization.quant_api"

    @staticmethod
    def _capability_patches(module, *, mi350):
        """Build unstarted patchers for the capability checks of ``module``.

        Args:
            module: Dotted path of the module whose ``is_MI350``, ``is_MI300``
                and ``is_sm_at_least_89`` attributes are patched.
            mi350: Value the patched ``is_MI350`` returns.

        Returns:
            A list of patchers. ``is_MI300`` and ``is_sm_at_least_89`` always
            return ``False`` and ``torch.cuda.is_available`` returns ``True``.
        """
        return [
            patch(f"{module}.is_MI350", return_value=mi350),
            patch(f"{module}.is_MI300", return_value=False),
            patch(f"{module}.is_sm_at_least_89", return_value=False),
            patch("torch.cuda.is_available", return_value=True),
        ]

    def _patch_mi350_only(self):
        """Return unstarted patchers simulating an MI350-only environment."""
        return [
            *self._capability_patches(self._INFERENCE, mi350=True),
            patch("torch.xpu.is_available", return_value=False),
        ]

    def _patch_no_hw(self):
        """Return unstarted patchers simulating unsupported hardware."""
        return [
            *self._capability_patches(self._INFERENCE, mi350=False),
            patch("torch.xpu.is_available", return_value=False),
        ]

    @staticmethod
    def _activate(patches):
        """Start every patcher and return a context manager that stops them.

        If a patcher fails to start (for example because its target attribute
        does not exist), the ones already started are stopped before the
        exception propagates, so no mock leaks into later tests.

        Args:
            patches: Iterable of unstarted ``unittest.mock`` patchers.

        Returns:
            An ``ExitStack`` owning the started patchers; use it in a ``with``
            statement to undo them on exit.
        """
        from contextlib import ExitStack

        with ExitStack() as stack:
            for patcher in patches:
                stack.enter_context(patcher)
            # Hand ownership to a fresh stack so leaving this ``with`` block
            # does not undo the patches that were just applied.
            return stack.pop_all()

    def _quant_api_gate(self, *, mi350):
        """Evaluate the hardware predicate against patched quant_api names.

        Args:
            mi350: Value the patched ``is_MI350`` returns.

        Returns:
            ``is_sm_at_least_89() or is_MI300() or is_MI350()`` evaluated while
            the patches are active.
        """
        with self._activate(self._capability_patches(self._QUANT_API, mi350=mi350)):
            # Import inside the patched region so the names bind to the mocks.
            from torchao.quantization.quant_api import (
                is_MI300,
                is_MI350,
                is_sm_at_least_89,
            )

            return is_sm_at_least_89() or is_MI300() or is_MI350()

    def test_check_hardware_support_mi350_per_tensor(self):
        """Per-tensor granularity must not raise when only MI350 is reported."""
        from torchao.float8.inference import _check_hardware_support

        with self._activate(self._patch_mi350_only()):
            _check_hardware_support((PerTensor(), PerTensor()))

    def test_check_hardware_support_mi350_per_row(self):
        """Per-row granularity must not raise when only MI350 is reported."""
        from torchao.float8.inference import _check_hardware_support

        with self._activate(self._patch_mi350_only()):
            _check_hardware_support((PerRow(), PerRow()))

    def test_check_hardware_support_rejects_unsupported_hw(self):
        """An AssertionError is expected when no supported hardware is reported."""
        from torchao.float8.inference import _check_hardware_support

        with self._activate(self._patch_no_hw()):
            with self.assertRaises(AssertionError):
                _check_hardware_support((PerRow(), PerRow()))

    def test_quant_api_hardware_gate_mi350(self):
        """The assertion in _float8_dynamic_activation_float8_weight_transform
        should pass on MI350."""
        # NOTE(review): this evaluates the predicate on the mocked names only;
        # it does not call the transform itself -- confirm that is intended.
        self.assertTrue(self._quant_api_gate(mi350=True))

    def test_quant_api_hardware_gate_rejects_unsupported(self):
        """The same predicate must be false when no supported hardware is reported."""
        # NOTE(review): like the MI350 case, this only exercises the mocks.
        self.assertFalse(self._quant_api_gate(mi350=False))
| 1605 | + |
1492 | 1606 | if __name__ == "__main__": |
1493 | 1607 | run_tests() |
0 commit comments