|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +import copy |
| 7 | +import logging |
| 8 | +import unittest |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch import nn |
| 12 | +from torch.testing._internal import common_utils |
| 13 | + |
| 14 | +try: |
| 15 | + from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8_SPARSE |
| 16 | +except ImportError: |
| 17 | + PLATFORM_SUPPORTS_FP8_SPARSE = False |
| 18 | +from torchao.quantization import ( |
| 19 | + Float8DynamicActivationFloat8WeightConfig, |
| 20 | +) |
| 21 | +from torchao.quantization.granularity import PerTensor |
| 22 | +from torchao.quantization.quant_api import ( |
| 23 | + quantize_, |
| 24 | +) |
| 25 | +from torchao.quantization.quantize_.workflows import ( |
| 26 | + Float8PackingFormat, |
| 27 | +) |
| 28 | +from torchao.quantization.utils import compute_error |
| 29 | +from torchao.sparsity import apply_fake_sparsity |
| 30 | +from torchao.utils import ( |
| 31 | + is_ROCM, |
| 32 | + torch_version_at_least, |
| 33 | +) |
| 34 | + |
| 35 | +logging.basicConfig( |
| 36 | + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO |
| 37 | +) |
| 38 | + |
| 39 | + |
| 40 | +@unittest.skipIf( |
| 41 | + not torch_version_at_least("2.10.0"), |
| 42 | + "Need torch >= 2.10.0", |
| 43 | +) |
| 44 | +class TestFloat8Sparse2x4_1DData1DMetadataTensor(common_utils.TestCase): |
| 45 | + def setUp(self): |
| 46 | + if not is_ROCM(): |
| 47 | + self.skipTest("hipSPARSELt path requires ROCm") |
| 48 | + if not PLATFORM_SUPPORTS_FP8_SPARSE: |
| 49 | + self.skipTest("Need platform with FP8 sparse support (hipSPARSELt)") |
| 50 | + |
| 51 | + @common_utils.parametrize("compile", [True, False]) |
| 52 | + def test_fp8_hipsparselt_sparse(self, compile): |
| 53 | + with torch.inference_mode(): |
| 54 | + input = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda") |
| 55 | + model = ( |
| 56 | + nn.Sequential( |
| 57 | + nn.Linear(256, 1024), |
| 58 | + nn.Linear(1024, 256), |
| 59 | + ) |
| 60 | + .bfloat16() |
| 61 | + .cuda() |
| 62 | + .eval() |
| 63 | + ) |
| 64 | + |
| 65 | + apply_fake_sparsity(model) |
| 66 | + baseline_result = model(input) |
| 67 | + model_copy = copy.deepcopy(model) |
| 68 | + |
| 69 | + # Quantized (dense) |
| 70 | + quantize_( |
| 71 | + model_copy, |
| 72 | + Float8DynamicActivationFloat8WeightConfig( |
| 73 | + granularity=PerTensor(), |
| 74 | + ), |
| 75 | + ) |
| 76 | + dense_result = model_copy(input) |
| 77 | + dense_sqnr = compute_error(baseline_result, dense_result) |
| 78 | + |
| 79 | + # Sparse + quantized |
| 80 | + quantize_( |
| 81 | + model, |
| 82 | + Float8DynamicActivationFloat8WeightConfig( |
| 83 | + version=2, |
| 84 | + packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA, |
| 85 | + granularity=PerTensor(), |
| 86 | + ), |
| 87 | + ) |
| 88 | + if compile: |
| 89 | + model = torch.compile(model) |
| 90 | + sparse_result = model(input) |
| 91 | + sparse_sqnr = compute_error(baseline_result, sparse_result) |
| 92 | + |
| 93 | + self.assertEqual(dense_sqnr, sparse_sqnr) |
| 94 | + |
| 95 | + def test_fp8_hipsparselt_sparse_lowering_op_clone(self): |
| 96 | + """Validates clone dispatch correctly copies both sparse data and scale metadata.""" |
| 97 | + with torch.inference_mode(): |
| 98 | + model = nn.Linear(256, 1024).half().cuda().eval() |
| 99 | + apply_fake_sparsity(model) |
| 100 | + quantize_( |
| 101 | + model, |
| 102 | + Float8DynamicActivationFloat8WeightConfig( |
| 103 | + version=2, |
| 104 | + packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA, |
| 105 | + granularity=PerTensor(), |
| 106 | + ), |
| 107 | + ) |
| 108 | + |
| 109 | + original = model.weight.dequantize() |
| 110 | + cloned = model.weight.clone().dequantize() |
| 111 | + |
| 112 | + for o, c in zip(original, cloned): |
| 113 | + self.assertEqual(o, c) |
| 114 | + |
| 115 | + def test_fp8_hipsparselt_sparse_lowering_op_to(self): |
| 116 | + """Validates both to.dtype_layout and to.dtype dispatch paths correctly dequantize the sparse tensor.""" |
| 117 | + with torch.inference_mode(): |
| 118 | + model = nn.Linear(256, 1024).half().cuda().eval() |
| 119 | + apply_fake_sparsity(model) |
| 120 | + model_copy = copy.deepcopy(model) |
| 121 | + expected = model_copy.weight.to(dtype=torch.float) |
| 122 | + |
| 123 | + quantize_( |
| 124 | + model, |
| 125 | + Float8DynamicActivationFloat8WeightConfig( |
| 126 | + version=2, |
| 127 | + packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA, |
| 128 | + granularity=PerTensor(), |
| 129 | + ), |
| 130 | + ) |
| 131 | + |
| 132 | + original_by_to_dtype_layout = torch.ops.aten.to.dtype_layout( |
| 133 | + model.weight, |
| 134 | + dtype=torch.float, |
| 135 | + layout=torch.strided, |
| 136 | + ) |
| 137 | + torch.testing.assert_close( |
| 138 | + expected, original_by_to_dtype_layout, atol=1e-1, rtol=1e-1 |
| 139 | + ) |
| 140 | + |
| 141 | + original_by_to_dtype = torch.ops.aten.to.dtype( |
| 142 | + model.weight, |
| 143 | + torch.float, |
| 144 | + ) |
| 145 | + torch.testing.assert_close( |
| 146 | + expected, original_by_to_dtype, atol=1e-1, rtol=1e-1 |
| 147 | + ) |
| 148 | + |
| 149 | + |
| 150 | +common_utils.instantiate_parametrized_tests(TestFloat8Sparse2x4_1DData1DMetadataTensor) |
| 151 | + |
| 152 | +if __name__ == "__main__": |
| 153 | + unittest.main() |
0 commit comments