Skip to content

Commit b1bcda2

Browse files
bbeckcafacebook-github-bot
authored andcommitted
Rename Sparse2x4CUTLASSFloat8Tensor to Float8Sparse2x4_2DData2DMetadataTensor
Summary: Rename the CUTLASS float8 sparse tensor class to describe the memory layout: - Class: Sparse2x4CUTLASSFloat8Tensor → Float8Sparse2x4_2DData2DMetadataTensor - Enum: SPARSE_CUTLASS → SPARSE_2D_DATA_2D_METADATA (old value kept for backward compatibility) The old identifiers to Sparse2x4CUTLASSFloat8Tensor will remain importable using backward compatible aliases. Differential Revision: D102374347
1 parent 4a5e3be commit b1bcda2

6 files changed

Lines changed: 268 additions & 254 deletions

File tree

test/quantization/quantize_/workflows/float8/test_sparse_2x4_cutlass_float8_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_fp8_cutlass_sparse(self, compile):
6868
model,
6969
Float8DynamicActivationFloat8WeightConfig(
7070
version=2,
71-
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
71+
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
7272
granularity=PerRow(),
7373
),
7474
)
@@ -89,7 +89,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self):
8989
model,
9090
Float8DynamicActivationFloat8WeightConfig(
9191
version=2,
92-
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
92+
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
9393
granularity=PerRow(),
9494
),
9595
)
@@ -114,7 +114,7 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
114114
model,
115115
Float8DynamicActivationFloat8WeightConfig(
116116
version=2,
117-
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
117+
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
118118
granularity=PerRow(),
119119
),
120120
)

torchao/quantization/quant_api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
IntxUnpackedToInt8Tensor,
6868
QuantizeTensorToFloat8Kwargs,
6969
QuantizeTensorToInt8Kwargs,
70+
Float8Sparse2x4_2DData2DMetadataTensor,
7071
Sparse2x4CUTLASSFloat8Tensor,
7172
)
7273
from torchao.quantization.transform_module import (
@@ -1258,11 +1259,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
12581259
act_quant_kwargs=act_quant_kwargs,
12591260
)
12601261
return quantized_weight
1261-
elif packing_format == Float8PackingFormat.SPARSE_CUTLASS:
1262+
elif packing_format == Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA:
12621263
assert isinstance(weight_granularity, PerRow), (
12631264
"Sparse packing format only supports per-row quantization"
12641265
)
1265-
quantized_weight = Sparse2x4CUTLASSFloat8Tensor.from_hp(
1266+
quantized_weight = Float8Sparse2x4_2DData2DMetadataTensor.from_hp(
12661267
weight,
12671268
float8_dtype=weight_dtype,
12681269
granularity=weight_granularity,

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
Float8Tensor,
99
QuantizeTensorToFloat8Kwargs,
1010
)
11+
from .float8.sparse_2x4_2d_data_2d_metadata_float8_tensor import (
12+
Float8Sparse2x4_2DData2DMetadataTensor,
13+
)
1114
from .float8.sparse_2x4_cutlass_float8_tensor import (
1215
Sparse2x4CUTLASSFloat8Tensor,
1316
)
@@ -47,6 +50,7 @@
4750
"Int8Tensor",
4851
"QuantizeTensorToInt8Kwargs",
4952
"Float8Tensor",
53+
"Float8Sparse2x4_2DData2DMetadataTensor",
5054
"Sparse2x4CUTLASSFloat8Tensor",
5155
"Float8Sparse2x4_1DData1DMetadataTensor",
5256
"Float8PackingFormat",

torchao/quantization/quantize_/workflows/float8/float8_packing_format.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ class Float8PackingFormat(str, Enum):
2929
"""
3030
Sparse packing format for 2:4 sparsity + FP8 quantization
3131
32-
SPARSE_CUTLASS will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
32+
SPARSE_2D_DATA_2D_METADATA will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
3333
This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
3434
"""
35-
SPARSE_CUTLASS = "sparse_cutlass"
35+
SPARSE_2D_DATA_2D_METADATA = "sparse_2d_data_2d_metadata"
3636
"""
3737
Sparse packing format for 2:4 sparsity + FP8 quantization using hipSPARSELt (ROCm/AMD only).
3838
Lines changed: 254 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,254 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
from typing import List, Optional
9+
10+
import torch
11+
12+
from torchao.float8.inference import (
13+
FP8Granularity,
14+
)
15+
from torchao.ops import (
16+
rowwise_scaled_linear_sparse_cutlass_f8f8,
17+
to_sparse_semi_structured_cutlass_sm9x_f8,
18+
)
19+
from torchao.quantization.granularity import PerRow
20+
from torchao.quantization.quant_primitives import (
21+
_choose_scale_float8,
22+
_quantize_affine_float8,
23+
)
24+
from torchao.quantization.quantize_.common import (
25+
_choose_quant_func_and_quantize_tensor,
26+
)
27+
from torchao.quantization.utils import get_block_size
28+
from torchao.utils import (
29+
TorchAOBaseTensor,
30+
is_sm_at_least_90,
31+
)
32+
33+
__all__ = [
34+
"Float8Sparse2x4_2DData2DMetadataTensor",
35+
]
36+
37+
aten = torch.ops.aten
38+
39+
40+
from .float8_tensor import QuantizeTensorToFloat8Kwargs
41+
42+
43+
class Float8Sparse2x4_2DData2DMetadataTensor(TorchAOBaseTensor):
44+
"""
45+
Float8 Quantized + 2:4 sparse (weight) Tensor using CUTLASS kernels, with float8 dynamic quantization for activation.
46+
47+
Tensor Attributes:
48+
qdata: float8 raw data
49+
sparse_metadata: metadata for 2:4 sparse tensor
50+
scale: the scale for float8 Tensor
51+
52+
Non-Tensor Attributes:
53+
block_size (List[int]): the block size for float8 quantization, meaning the shape of the elements
54+
sharing the same set of quantization parameters (scale), have the same rank as qdata or
55+
is an empty list (representing per tensor quantization)
56+
act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Sparse2x4Float8Tensor.from_hp
57+
packing_format (Float8PackingFormat): the preference for quantize, mm etc. kernel to use,
58+
by default, this will be chosen for user based on hardware, library availabilities etc.
59+
dtype: Original Tensor dtype
60+
"""
61+
62+
tensor_data_names = ["qdata", "sparse_metadata", "scale"]
63+
tensor_attribute_names = []
64+
optional_tensor_attribute_names = [
65+
"block_size",
66+
"act_quant_kwargs",
67+
"dtype",
68+
]
69+
70+
def __new__(
71+
cls,
72+
qdata: torch.Tensor,
73+
sparse_metadata: torch.Tensor,
74+
scale: torch.Tensor,
75+
block_size: Optional[List[int]] = None,
76+
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
77+
dtype: Optional[torch.dtype] = None,
78+
):
79+
shape = qdata.shape[0], 2 * qdata.shape[1]
80+
81+
kwargs = {}
82+
kwargs["device"] = qdata.device
83+
kwargs["dtype"] = dtype
84+
kwargs["requires_grad"] = False
85+
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
86+
87+
def __init__(
88+
self,
89+
qdata: torch.Tensor,
90+
sparse_metadata: torch.Tensor,
91+
scale: torch.Tensor,
92+
block_size: Optional[List[int]] = None,
93+
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
94+
dtype: Optional[torch.dtype] = None,
95+
):
96+
super().__init__()
97+
self.qdata = qdata
98+
self.sparse_metadata = sparse_metadata
99+
self.scale = scale
100+
self.block_size = block_size
101+
self.act_quant_kwargs = act_quant_kwargs
102+
103+
def __repr__(self):
104+
return (
105+
f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.sparse_metadata=}, {self.scale=}, "
106+
f"{self.block_size=}, "
107+
f"{self.shape=}, {self.device=}, {self.dtype=})"
108+
)
109+
110+
def _quantization_type(self):
111+
return f"{self.act_quant_kwargs=}, {self.block_size=}, {self.scale.shape=}"
112+
113+
def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
114+
# No support in CUTLASS to convert back to dense from sparse
115+
# semi-structured format, so multiplying with identity matrix,
116+
# and using identity scale factors, for the conversion.
117+
cols = self.shape[1]
118+
input = torch.eye(cols, dtype=self.qdata.dtype, device=self.qdata.device)
119+
input_scale = torch.ones(
120+
(cols,), dtype=self.scale.dtype, device=self.qdata.device
121+
)
122+
123+
out_dtype = torch.bfloat16
124+
dense = (
125+
rowwise_scaled_linear_sparse_cutlass_f8f8(
126+
input,
127+
input_scale,
128+
self.qdata,
129+
self.sparse_metadata,
130+
self.scale,
131+
out_dtype=out_dtype,
132+
)
133+
.to(output_dtype)
134+
.t()
135+
.contiguous()
136+
)
137+
return dense
138+
139+
@classmethod
140+
def from_hp(
141+
cls,
142+
hp_tensor: torch.Tensor,
143+
float8_dtype: torch.dtype = torch.float8_e4m3fn,
144+
granularity: FP8Granularity = PerRow(),
145+
hp_value_lb: Optional[float] = None,
146+
hp_value_ub: Optional[float] = None,
147+
act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None,
148+
):
149+
block_size = get_block_size(hp_tensor.shape, granularity)
150+
block_size = list(block_size)
151+
scale = _choose_scale_float8(
152+
hp_tensor,
153+
float8_dtype=float8_dtype,
154+
block_size=block_size,
155+
hp_value_lb=hp_value_lb,
156+
hp_value_ub=hp_value_ub,
157+
)
158+
data = _quantize_affine_float8(hp_tensor, scale, float8_dtype)
159+
hp_dtype = hp_tensor.dtype
160+
161+
assert is_sm_at_least_90(), (
162+
"CUTLASS sparse kernel requires hardware >= SM 9.0 (>= H100)"
163+
)
164+
assert isinstance(granularity, PerRow), (
165+
"CUTLASS sparse kernel only supports per-row quantization"
166+
)
167+
# CUTLASS path only supports quantizing along the last dim
168+
assert granularity.dim in (-1, len(hp_tensor.shape) - 1), (
169+
"CUTLASS sparse kernel only supports quantizing along the last dimension"
170+
)
171+
assert float8_dtype == torch.float8_e4m3fn, (
172+
"CUTLASS sparse kernel only supports float8_e4m3fn dtype"
173+
)
174+
assert hp_value_lb is None, "CUTLASS sparse kernel does not support hp_value_lb"
175+
176+
# Use CUTLASS rowwise fp8 + 2:4 sparse mm kernel
177+
qdata, sparse_metadata = to_sparse_semi_structured_cutlass_sm9x_f8(data)
178+
179+
return Float8Sparse2x4_2DData2DMetadataTensor(
180+
qdata,
181+
sparse_metadata,
182+
scale,
183+
block_size=block_size,
184+
act_quant_kwargs=act_quant_kwargs,
185+
dtype=hp_dtype,
186+
)
187+
188+
189+
implements = Float8Sparse2x4_2DData2DMetadataTensor.implements
190+
implements_torch_function = Float8Sparse2x4_2DData2DMetadataTensor.implements_torch_function
191+
192+
193+
@implements(aten.linear.default)
194+
@implements_torch_function(torch.nn.functional.linear)
195+
def _(func, types, args, kwargs):
196+
input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None)
197+
weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None)
198+
bias = kwargs.get("bias", args[2] if len(args) > 2 else None)
199+
200+
assert input_tensor is not None, "input tensor must not be None"
201+
assert weight_tensor is not None, "weight tensor must not be None"
202+
203+
act_quant_kwargs = weight_tensor.act_quant_kwargs
204+
# quantize activation, if `act_quant_kwargs` is specified
205+
if act_quant_kwargs is not None:
206+
assert not isinstance(input_tensor, TorchAOBaseTensor), (
207+
"input tensor was already quantized"
208+
)
209+
input_tensor = _choose_quant_func_and_quantize_tensor(
210+
input_tensor, act_quant_kwargs
211+
)
212+
input = input_tensor.qdata
213+
input_scale = input_tensor.scale.squeeze(1)
214+
weight = weight_tensor.qdata
215+
weight_meta = weight_tensor.sparse_metadata
216+
weight_scale = weight_tensor.scale.squeeze(1)
217+
out_dtype = input_tensor.dtype
218+
219+
out = rowwise_scaled_linear_sparse_cutlass_f8f8(
220+
input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype
221+
)
222+
return out
223+
224+
225+
@implements(aten.to.dtype_layout)
226+
def _(func, types, args, kwargs):
227+
return (
228+
args[0]
229+
.dequantize()
230+
.to(
231+
*args[1:],
232+
dtype=kwargs.get("dtype", args[0].dtype),
233+
device=kwargs.get("device", args[0].device),
234+
)
235+
)
236+
237+
238+
# implement to.dtype for cases where dtype specified in args[1]
239+
@implements(aten.to.dtype)
240+
def _(func, types, args, kwargs):
241+
dtype = kwargs.get("dtype", args[1] if len(args) > 1 else None)
242+
assert dtype is not None, "dtype must not be None"
243+
244+
return (
245+
args[0]
246+
.dequantize()
247+
.to(
248+
dtype=dtype,
249+
)
250+
)
251+
252+
253+
# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
254+
torch.serialization.add_safe_globals([Float8Sparse2x4_2DData2DMetadataTensor])

0 commit comments

Comments
 (0)