Skip to content

Commit e8a2ccc

Browse files
bbeckca authored and facebook-github-bot committed
Add Sparse2x4HIPSPARSELTFloat8Tensor (pytorch#4277)
Summary: X-link: pytorch/pytorch#180312 What: Adding a new tensor subclass for FP8 2:4 sparsity via hipSPARSELt (ROCm only). Packs compressed values + metadata into a single tensor with `_cslt_compress` and dispatches through `_cslt_sparse_mm` with `A_scale * B_scale` as `alpha`. Why: This hipSPARSELt path differs enough in packing and kernel routing from CUTLASS to warrant a dedicated path. Reference: https://rocm.blogs.amd.com/artificial-intelligence/introduce_hipsparselt/README.html Reviewed By: RandySheriff Differential Revision: D100640267
1 parent 6529fca commit e8a2ccc

5 files changed

Lines changed: 457 additions & 0 deletions

File tree

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import copy
7+
import logging
8+
import unittest
9+
10+
import torch
11+
from torch import nn
12+
from torch.testing._internal import common_utils
13+
14+
try:
15+
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8_SPARSE
16+
except ImportError:
17+
PLATFORM_SUPPORTS_FP8_SPARSE = False
18+
from torchao.quantization import (
19+
Float8DynamicActivationFloat8WeightConfig,
20+
)
21+
from torchao.quantization.granularity import PerTensor
22+
from torchao.quantization.quant_api import (
23+
quantize_,
24+
)
25+
from torchao.quantization.quantize_.workflows import (
26+
Float8PackingFormat,
27+
)
28+
from torchao.quantization.utils import compute_error
29+
from torchao.sparsity import apply_fake_sparsity
30+
from torchao.utils import (
31+
is_ROCM,
32+
torch_version_at_least,
33+
)
34+
35+
logging.basicConfig(
36+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
37+
)
38+
39+
40+
@unittest.skipIf(
    not torch_version_at_least("2.10.0"),
    "Need torch >= 2.10.0",
)
class TestFloat8Sparse2x4_1DData1DMetadataTensor(common_utils.TestCase):
    """Tests for the FP8 2:4-sparse tensor subclass routed through hipSPARSELt.

    All tests are skipped unless running on a ROCm build whose platform
    reports FP8 sparse support, so the suite is a no-op elsewhere.
    """

    def setUp(self):
        # Both guards must hold: the SPARSE_1D_DATA_1D_METADATA packing
        # format dispatches to hipSPARSELt kernels, which are ROCm-only.
        if not is_ROCM():
            self.skipTest("hipSPARSELt path requires ROCm")
        if not PLATFORM_SUPPORTS_FP8_SPARSE:
            self.skipTest("Need platform with FP8 sparse support (hipSPARSELt)")

    @common_utils.parametrize("compile", [True, False])
    def test_fp8_hipsparselt_sparse(self, compile):
        """End-to-end accuracy check of the sparse FP8 path vs. dense FP8.

        Builds a small two-layer MLP, fake-prunes it to a 2:4 pattern, then
        compares the SQNR (against the unquantized pruned baseline) of the
        dense FP8 quantized model with that of the hipSPARSELt sparse FP8
        quantized model, optionally under torch.compile.
        """
        with torch.inference_mode():
            x = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
            mlp = nn.Sequential(
                nn.Linear(256, 1024),
                nn.Linear(1024, 256),
            )
            mlp = mlp.bfloat16().cuda().eval()

            # Prune to 2:4 before quantizing either copy so both paths see
            # identical weights.
            apply_fake_sparsity(mlp)
            reference_out = mlp(x)
            dense_model = copy.deepcopy(mlp)

            # Dense FP8 baseline (default packing format).
            quantize_(
                dense_model,
                Float8DynamicActivationFloat8WeightConfig(
                    granularity=PerTensor(),
                ),
            )
            dense_sqnr = compute_error(reference_out, dense_model(x))

            # Sparse FP8 via the hipSPARSELt 1D-data / 1D-metadata packing.
            quantize_(
                mlp,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                    granularity=PerTensor(),
                ),
            )
            if compile:
                mlp = torch.compile(mlp)
            sparse_sqnr = compute_error(reference_out, mlp(x))

            # The sparse kernel must match the dense FP8 path's accuracy.
            self.assertEqual(dense_sqnr, sparse_sqnr)

    def test_fp8_hipsparselt_sparse_lowering_op_clone(self):
        """Validates clone dispatch correctly copies both sparse data and scale metadata."""
        with torch.inference_mode():
            linear = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(linear)
            quantize_(
                linear,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                    granularity=PerTensor(),
                ),
            )

            # Dequantize the weight and its clone; a correct clone must
            # reproduce every row exactly.
            before = linear.weight.dequantize()
            after = linear.weight.clone().dequantize()

            for row_orig, row_clone in zip(before, after):
                self.assertEqual(row_orig, row_clone)

    def test_fp8_hipsparselt_sparse_lowering_op_to(self):
        """Validates both to.dtype_layout and to.dtype dispatch paths correctly dequantize the sparse tensor."""
        with torch.inference_mode():
            linear = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(linear)
            # Keep an unquantized copy to produce the float reference weight.
            pristine = copy.deepcopy(linear)
            reference = pristine.weight.to(dtype=torch.float)

            quantize_(
                linear,
                Float8DynamicActivationFloat8WeightConfig(
                    version=2,
                    packing_format=Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA,
                    granularity=PerTensor(),
                ),
            )

            # aten.to.dtype_layout overload.
            via_dtype_layout = torch.ops.aten.to.dtype_layout(
                linear.weight,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(
                reference, via_dtype_layout, atol=1e-1, rtol=1e-1
            )

            # aten.to.dtype overload.
            via_dtype = torch.ops.aten.to.dtype(
                linear.weight,
                torch.float,
            )
            torch.testing.assert_close(
                reference, via_dtype, atol=1e-1, rtol=1e-1
            )
# Expand the @parametrize-decorated tests into concrete test methods.
common_utils.instantiate_parametrized_tests(TestFloat8Sparse2x4_1DData1DMetadataTensor)

if __name__ == "__main__":
    unittest.main()

torchao/quantization/quant_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
)
5353
from torchao.quantization.quantize_.workflows import (
5454
Float8PackingFormat,
55+
Float8Sparse2x4_1DData1DMetadataTensor,
5556
Float8Tensor,
5657
Int4ChooseQParamsAlgorithm,
5758
Int4PackingFormat,
@@ -1268,6 +1269,17 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
12681269
act_quant_kwargs=act_quant_kwargs,
12691270
)
12701271
return quantized_weight
1272+
elif packing_format == Float8PackingFormat.SPARSE_1D_DATA_1D_METADATA:
1273+
assert isinstance(weight_granularity, PerTensor), (
1274+
"Sparse 1D data 1D metadata packing format only supports per-tensor quantization"
1275+
)
1276+
quantized_weight = Float8Sparse2x4_1DData1DMetadataTensor.from_hp(
1277+
weight,
1278+
float8_dtype=weight_dtype,
1279+
granularity=weight_granularity,
1280+
act_quant_kwargs=act_quant_kwargs,
1281+
)
1282+
return quantized_weight
12711283

12721284

12731285
@register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig)

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from .float8.float8_packing_format import (
22
Float8PackingFormat,
33
)
4+
from .float8.float8_sparse_2x4_1d_data_1d_metadata_tensor import (
5+
Float8Sparse2x4_1DData1DMetadataTensor,
6+
)
47
from .float8.float8_tensor import (
58
Float8Tensor,
69
QuantizeTensorToFloat8Kwargs,
@@ -45,6 +48,7 @@
4548
"QuantizeTensorToInt8Kwargs",
4649
"Float8Tensor",
4750
"Sparse2x4CUTLASSFloat8Tensor",
51+
"Float8Sparse2x4_1DData1DMetadataTensor",
4852
"Float8PackingFormat",
4953
"QuantizeTensorToFloat8Kwargs",
5054
"Int8Tensor",

torchao/quantization/quantize_/workflows/float8/float8_packing_format.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,23 @@ class Float8PackingFormat(str, Enum):
3333
This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
3434
"""
3535
SPARSE_CUTLASS = "sparse_cutlass"
36+
"""
37+
Sparse packing format for 2:4 sparsity + FP8 quantization using hipSPARSELt (ROCm/AMD only).
38+
39+
SPARSE_1D_DATA_1D_METADATA will pack the quantized_data into a single tensor containing both the quantized data and metadata
40+
as a 1D tensor of r*c/2 + r*c/8 bytes with the following layout: [compressed_data | metadata]
41+
42+
- compressed_data: r*c/2 bytes
43+
The 2 non-zero FP8 values per group of 4 elements, stored row-major:
44+
row0_group0_val0, row0_group0_val1, row0_group1_val0, row0_group1_val1, ..., row1_group0_val0, ...
45+
- metadata: r*c/8 bytes
46+
4 bits per group of 4 elements encoding the positions of the 2 kept values
47+
(2 bits per kept element index), groups packed contiguously row-major:
48+
row0_group0_meta, row0_group1_meta, ..., row1_group0_meta, ...
49+
50+
This packing format will dispatch to torch._cslt_sparse_mm for matmul, with per-tensor scaling passed as alpha.
51+
"""
52+
SPARSE_1D_DATA_1D_METADATA = "sparse_1d_data_1d_metadata"
3653

3754

3855
torch.serialization.add_safe_globals([Float8PackingFormat])

0 commit comments

Comments
 (0)