
Commit f0135d7

bbeckca authored and facebook-github-bot committed
Rename Sparse2x4CUTLASSFloat8Tensor to Float8Sparse2x4_2DData2DMetadataTensor (pytorch#4343)
Summary: Rename the CUTLASS float8 sparse tensor class to describe the memory layout:

- Class: Sparse2x4CUTLASSFloat8Tensor → Float8Sparse2x4_2DData2DMetadataTensor
- Enum: SPARSE_CUTLASS → SPARSE_2D_DATA_2D_METADATA (old value kept for backward compatibility)

The old Sparse2x4CUTLASSFloat8Tensor identifier remains importable via a backward-compatible alias.

Reviewed By: RandySheriff

Differential Revision: D102374347
1 parent e8a2ccc commit f0135d7

5 files changed

Lines changed: 19 additions & 16 deletions


test/quantization/quantize_/workflows/float8/test_sparse_2x4_cutlass_float8_tensor.py

Lines changed: 3 additions & 3 deletions
@@ -68,7 +68,7 @@ def test_fp8_cutlass_sparse(self, compile):
             model,
             Float8DynamicActivationFloat8WeightConfig(
                 version=2,
-                packing_format=Float8PackingFormat.SPARSE_CUTLASS,
+                packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
                 granularity=PerRow(),
             ),
         )
@@ -89,7 +89,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self):
             model,
             Float8DynamicActivationFloat8WeightConfig(
                 version=2,
-                packing_format=Float8PackingFormat.SPARSE_CUTLASS,
+                packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
                 granularity=PerRow(),
             ),
         )
@@ -114,7 +114,7 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
             model,
             Float8DynamicActivationFloat8WeightConfig(
                 version=2,
-                packing_format=Float8PackingFormat.SPARSE_CUTLASS,
+                packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
                 granularity=PerRow(),
             ),
         )
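For orientation, a minimal end-to-end sketch of what these tests exercise, written against the names visible in this diff; the toy model, shapes, and the assumption of a CUTLASS-capable (sm9x) CUDA GPU are illustrative, not from the commit:

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)
from torchao.quantization.quantize_.workflows import Float8PackingFormat

# Hypothetical toy model standing in for the test fixture.
model = torch.nn.Sequential(torch.nn.Linear(128, 256)).cuda().to(torch.bfloat16)

# Swap the Linear weight for a float8 + 2:4 sparse tensor using the renamed
# packing format; activations are quantized dynamically with per-row scales.
quantize_(
    model,
    Float8DynamicActivationFloat8WeightConfig(
        version=2,
        packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
        granularity=PerRow(),
    ),
)

out = model(torch.randn(16, 128, dtype=torch.bfloat16, device="cuda"))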

torchao/quantization/quant_api.py

Lines changed: 3 additions & 3 deletions
@@ -67,7 +67,7 @@
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
     QuantizeTensorToInt8Kwargs,
-    Sparse2x4CUTLASSFloat8Tensor,
+    Float8Sparse2x4_2DData2DMetadataTensor,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1258,11 +1258,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
             act_quant_kwargs=act_quant_kwargs,
         )
         return quantized_weight
-    elif packing_format == Float8PackingFormat.SPARSE_CUTLASS:
+    elif packing_format == Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA:
         assert isinstance(weight_granularity, PerRow), (
             "Sparse packing format only supports per-row quantization"
         )
-        quantized_weight = Sparse2x4CUTLASSFloat8Tensor.from_hp(
+        quantized_weight = Float8Sparse2x4_2DData2DMetadataTensor.from_hp(
             weight,
             float8_dtype=weight_dtype,
             granularity=weight_granularity,
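The handler above funnels the config into the tensor subclass's from_hp constructor. Where the config machinery is not needed, the hunk suggests direct construction is possible; a sketch assuming only the keyword names visible in the hunk (float8_dtype, granularity), a hypothetical dense weight, and the same sm9x GPU requirement:

import torch
from torchao.quantization import PerRow
from torchao.quantization.quantize_.workflows import (
    Float8Sparse2x4_2DData2DMetadataTensor,
)

# Hypothetical weight; per the class docstring, from_hp produces a float8
# 2:4-sparse weight packed as qdata + sparse_metadata.
weight = torch.randn(256, 128, dtype=torch.bfloat16, device="cuda")

qweight = Float8Sparse2x4_2DData2DMetadataTensor.from_hp(
    weight,
    float8_dtype=torch.float8_e4m3fn,
    granularity=PerRow(),
)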

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -8,8 +8,8 @@
     Float8Tensor,
     QuantizeTensorToFloat8Kwargs,
 )
-from .float8.sparse_2x4_cutlass_float8_tensor import (
-    Sparse2x4CUTLASSFloat8Tensor,
+from .float8.sparse_2x4_2d_data_2d_metadata_float8_tensor import (
+    Float8Sparse2x4_2DData2DMetadataTensor,
 )
 from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
 from .int4.int4_packing_format import Int4PackingFormat
@@ -39,6 +39,8 @@
 )
 from .nf4.nf4_tensor import NF4Tensor, to_nf4
 
+Sparse2x4CUTLASSFloat8Tensor = Float8Sparse2x4_2DData2DMetadataTensor
+
 __all__ = [
     "Int4Tensor",
     "Int4PreshuffledTensor",
@@ -47,6 +49,7 @@
     "Int8Tensor",
     "QuantizeTensorToInt8Kwargs",
     "Float8Tensor",
+    "Float8Sparse2x4_2DData2DMetadataTensor",
     "Sparse2x4CUTLASSFloat8Tensor",
     "Float8Sparse2x4_1DData1DMetadataTensor",
     "Float8PackingFormat",

torchao/quantization/quantize_/workflows/float8/float8_packing_format.py

Lines changed: 2 additions & 2 deletions
@@ -29,10 +29,10 @@ class Float8PackingFormat(str, Enum):
     """
     Sparse packing format for 2:4 sparsity + FP8 quantization
 
-    SPARSE_CUTLASS will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
+    SPARSE_2D_DATA_2D_METADATA will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
     This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
     """
-    SPARSE_CUTLASS = "sparse_cutlass"
+    SPARSE_2D_DATA_2D_METADATA = "sparse_2d_data_2d_metadata"
     """
     Sparse packing format for 2:4 sparsity + FP8 quantization using hipSPARSELt (ROCm/AMD only).
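Because Float8PackingFormat subclasses str, configs that serialize the raw string value round-trip through the enum constructor. A small check against the member added above (the commit summary says the old SPARSE_CUTLASS value is also kept for backward compatibility, though that alias falls outside this hunk):

from torchao.quantization.quantize_.workflows import Float8PackingFormat

# str-backed enum members compare equal to their raw value and can be
# reconstructed from it, which keeps serialized configs readable.
pf = Float8PackingFormat("sparse_2d_data_2d_metadata")
assert pf is Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA
assert pf == "sparse_2d_data_2d_metadata"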

torchao/quantization/quantize_/workflows/float8/sparse_2x4_cutlass_float8_tensor.py renamed to torchao/quantization/quantize_/workflows/float8/sparse_2x4_2d_data_2d_metadata_float8_tensor.py

Lines changed: 6 additions & 6 deletions
@@ -31,7 +31,7 @@
 )
 
 __all__ = [
-    "Sparse2x4CUTLASSFloat8Tensor",
+    "Float8Sparse2x4_2DData2DMetadataTensor",
 ]
 
 aten = torch.ops.aten
@@ -40,7 +40,7 @@
 from .float8_tensor import QuantizeTensorToFloat8Kwargs
 
 
-class Sparse2x4CUTLASSFloat8Tensor(TorchAOBaseTensor):
+class Float8Sparse2x4_2DData2DMetadataTensor(TorchAOBaseTensor):
     """
     Float8 Quantized + 2:4 sparse (weight) Tensor using CUTLASS kernels, with float8 dynamic quantization for activation.
 
@@ -176,7 +176,7 @@ def from_hp(
         # Use CUTLASS rowwise fp8 + 2:4 sparse mm kernel
         qdata, sparse_metadata = to_sparse_semi_structured_cutlass_sm9x_f8(data)
 
-        return Sparse2x4CUTLASSFloat8Tensor(
+        return Float8Sparse2x4_2DData2DMetadataTensor(
             qdata,
             sparse_metadata,
             scale,
@@ -186,8 +186,8 @@ def from_hp(
         )
 
 
-implements = Sparse2x4CUTLASSFloat8Tensor.implements
-implements_torch_function = Sparse2x4CUTLASSFloat8Tensor.implements_torch_function
+implements = Float8Sparse2x4_2DData2DMetadataTensor.implements
+implements_torch_function = Float8Sparse2x4_2DData2DMetadataTensor.implements_torch_function
 
 
 @implements(aten.linear.default)
@@ -251,4 +251,4 @@ def _(func, types, args, kwargs):
 
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([Sparse2x4CUTLASSFloat8Tensor])
+torch.serialization.add_safe_globals([Float8Sparse2x4_2DData2DMetadataTensor])
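The add_safe_globals registration is what lets checkpoints containing this tensor subclass pass torch.load's weights_only allowlist. A sketch of the round trip, assuming a model quantized as in the test diff above and an illustrative file path:

import torch

# Saving the quantized model's state_dict serializes the tensor subclass.
torch.save(model.state_dict(), "fp8_sparse_checkpoint.pt")

# Because the class is registered via add_safe_globals, weights_only
# loading does not reject it.
state_dict = torch.load("fp8_sparse_checkpoint.pt", weights_only=True)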
