Skip to content

Commit 1c3b180

Browse files
bbeckca and facebook-github-bot
authored and committed
Add Sparse2x4HIPSPARSELTFloat8Tensor (#4277)
Summary: X-link: pytorch/pytorch#180312 What: Adding a new tensor subclass for FP8 2:4 sparsity via hipSPARSELt (ROCm only). Packs compressed values + metadata into a single tensor with `_cslt_compress` and dispatches through `_cslt_sparse_mm` with `A_scale * B_scale` as `alpha`. Why: This hipSPARSELt path differs enough in packing and kernel routing from CUTLASS to warrant a dedicated path. Reference: https://rocm.blogs.amd.com/artificial-intelligence/introduce_hipsparselt/README.html Differential Revision: D100640267
1 parent 6529fca commit 1c3b180

5 files changed

Lines changed: 450 additions & 0 deletions

File tree

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import copy
7+
import logging
8+
import unittest
9+
10+
import torch
11+
from torch import nn
12+
from torch.testing._internal import common_utils
13+
14+
try:
15+
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8_SPARSE
16+
except ImportError:
17+
PLATFORM_SUPPORTS_FP8_SPARSE = False
18+
from torchao.quantization import (
19+
Float8DynamicActivationFloat8WeightConfig,
20+
)
21+
from torchao.quantization.granularity import PerTensor
22+
from torchao.quantization.quant_api import (
23+
quantize_,
24+
)
25+
from torchao.quantization.quantize_.workflows import (
26+
Float8PackingFormat,
27+
)
28+
from torchao.quantization.utils import compute_error
29+
from torchao.sparsity import apply_fake_sparsity
30+
from torchao.utils import (
31+
torch_version_at_least,
32+
)
33+
34+
# Configure root logging once at import time so test runs emit timestamped,
# logger-named INFO records (useful when diagnosing skips/failures on ROCm CI).
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
37+
38+
39+
@unittest.skipIf(
    not torch_version_at_least("2.10.0"),
    "Need torch >= 2.10.0",
)
class TestFloat8Sparse2x4_1DData1DMetadataTensor(common_utils.TestCase):
    """Exercises the FP8 2:4-sparse tensor subclass packed as 1D data + 1D metadata.

    The packing format under test routes matmuls through hipSPARSELt, so every
    test is gated on a ROCm build with FP8 sparse support.
    """

    def setUp(self):
        # Guard clauses: bail out early on any environment that cannot run
        # the hipSPARSELt FP8 2:4 sparse path.
        if not torch.cuda.is_available():
            self.skipTest("Need CUDA available")
        if not torch.version.hip:
            self.skipTest("hipSPARSELt path is ROCm-only")
        if not PLATFORM_SUPPORTS_FP8_SPARSE:
            self.skipTest("Need platform with FP8 sparse support (hipSPARSELt)")

    @common_utils.parametrize("compile", [True, False])
    def test_fp8_hipsparselt_sparse(self, compile):
        """Compare SQNR of dense-FP8 vs sparse-FP8 quantization on a 2:4-pruned model."""
        with torch.inference_mode():
            activations = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
            net = (
                nn.Sequential(
                    nn.Linear(256, 1024),
                    nn.Linear(1024, 256),
                )
                .bfloat16()
                .cuda()
                .eval()
            )

            # Prune the weights to a 2:4 pattern so the sparse path is applicable,
            # and capture the (pruned, unquantized) reference output.
            apply_fake_sparsity(net)
            reference_out = net(activations)
            dense_net = copy.deepcopy(net)

            # Path 1: plain dense FP8 dynamic-activation / FP8 weight quantization.
            dense_cfg = Float8DynamicActivationFloat8WeightConfig(
                granularity=PerTensor(),
            )
            quantize_(dense_net, dense_cfg)
            dense_out = dense_net(activations)
            dense_sqnr = compute_error(reference_out, dense_out)

            # Path 2: FP8 quantization with the hipSPARSELt 2:4 sparse packing.
            sparse_cfg = Float8DynamicActivationFloat8WeightConfig(
                version=2,
                packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                granularity=PerTensor(),
            )
            quantize_(net, sparse_cfg)
            if compile:
                net = torch.compile(net)
            sparse_out = net(activations)
            sparse_sqnr = compute_error(reference_out, sparse_out)

            # Sparse and dense quantization should agree in quality (within
            # TestCase's default floating-point tolerance).
            self.assertEqual(dense_sqnr, sparse_sqnr)

    def test_fp8_hipsparselt_sparse_lowering_op_clone(self):
        """Validates clone dispatch correctly copies both sparse data and scale metadata."""
        with torch.inference_mode():
            net = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(net)
            quantize_(
                net,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                    granularity=PerTensor(),
                ),
            )

            # A clone must round-trip through dequantize to the same values
            # as the original subclass tensor.
            before = net.weight.dequantize()
            after = net.weight.clone().dequantize()

            for lhs, rhs in zip(before, after):
                self.assertEqual(lhs, rhs)

    def test_fp8_hipsparselt_sparse_lowering_op_to(self):
        """Validates both to.dtype_layout and to.dtype dispatch paths correctly dequantize the sparse tensor."""
        with torch.inference_mode():
            net = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(net)
            pristine = copy.deepcopy(net)
            # Expected float weights come from the un-quantized copy.
            expected = pristine.weight.to(dtype=torch.float)

            quantize_(
                net,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                    granularity=PerTensor(),
                ),
            )

            # Dispatch path 1: aten.to.dtype_layout.
            via_dtype_layout = torch.ops.aten.to.dtype_layout(
                net.weight,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(
                expected, via_dtype_layout, atol=1e-1, rtol=1e-1
            )

            # Dispatch path 2: aten.to.dtype.
            via_dtype = torch.ops.aten.to.dtype(
                net.weight,
                torch.float,
            )
            torch.testing.assert_close(
                expected, via_dtype, atol=1e-1, rtol=1e-1
            )
149+
150+
151+
common_utils.instantiate_parametrized_tests(TestFloat8Sparse2x4_1DData1DMetadataTensor)
152+
153+
if __name__ == "__main__":
154+
unittest.main()

torchao/quantization/quant_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from torchao.quantization.quantize_.workflows import (
5454
Float8PackingFormat,
55+
Float8Sparse2x4_1DData1DMetadataTensor,
5556
Float8Tensor,
5657
Int4ChooseQParamsAlgorithm,
5758
Int4PackingFormat,
@@ -1268,6 +1269,17 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
12681269
act_quant_kwargs=act_quant_kwargs,
12691270
)
12701271
return quantized_weight
1272+
elif packing_format == Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA:
1273+
assert isinstance(weight_granularity, PerTensor), (
1274+
"Sparse 1D data 1D metadata packing format only supports per-tensor quantization"
1275+
)
1276+
quantized_weight = Float8Sparse2x4_1DData1DMetadataTensor.from_hp(
1277+
weight,
1278+
float8_dtype=weight_dtype,
1279+
granularity=weight_granularity,
1280+
act_quant_kwargs=act_quant_kwargs,
1281+
)
1282+
return quantized_weight
12711283

12721284

12731285
@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from .float8.float8_packing_format import (
22
Float8PackingFormat,
33
)
4+
from .float8.float8_sparse_2x4_1d_data_1d_metadata_tensor import (
5+
Float8Sparse2x4_1DData1DMetadataTensor,
6+
)
47
from .float8.float8_tensor import (
58
Float8Tensor,
69
QuantizeTensorToFloat8Kwargs,
@@ -45,6 +48,7 @@
4548
"QuantizeTensorToInt8Kwargs",
4649
"Float8Tensor",
4750
"Sparse2x4CUTLASSFloat8Tensor",
51+
"Float8Sparse2x4_1DData1DMetadataTensor",
4852
"Float8PackingFormat",
4953
"QuantizeTensorToFloat8Kwargs",
5054
"Int8Tensor",

torchao/quantization/quantize_/workflows/float8/float8_packing_format.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ class Float8PackingFormat(str, Enum):
3333
This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
3434
"""
3535
SPARSE_CUTLASS = "sparse_cutlass"
36+
"""
37+
Sparse packing format for 2:4 sparsity + FP8 quantization using hipSPARSELt (ROCm/AMD only).
38+
39+
SPARSE_1D_DATA_1D_METADATA will pack the quantized_data into a single tensor containing both the quantized data and metadata
40+
as a 1D tensor of r*c/2 + r*c/8 bytes with the following layout: [compressed_data | metadata]
41+
42+
- compressed_data: r*c/2 bytes
43+
The 2 non-zero FP8 values per group of 4 elements, stored row-major:
44+
row0_group0_val0, row0_group0_val1, row0_group1_val0, row0_group1_val1, ..., row1_group0_val0, ...
45+
- metadata: r*c/8 bytes
46+
4 bits per group of 4 elements encoding the positions of the 2 kept values
47+
(2 bits per kept element index), groups packed contiguously row-major:
48+
row0_group0_meta, row0_group1_meta, ..., row1_group0_meta, ...
49+
50+
This packing format will dispatch to torch._cslt_sparse_mm for matmul, with per-tensor scaling passed as alpha.
51+
"""
52+
SPARSE_1D_DATA_1D_METADATA = "sparse_1d_data_1d_metadata"
3653

3754

3855
torch.serialization.add_safe_globals([Float8PackingFormat])

0 commit comments

Comments
 (0)