Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Copy Markdown
Contributor

@jerryzh168 jerryzh168 Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you rename to test_float8_sparse_2x4_2d_data_2d_metadata_tensor.py

Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
not torch_version_at_least("2.10.0"),
"Need torch >= 2.10.0 for availability of ABI kernels",
)
class TestSparse2x4Float8Tensor(common_utils.TestCase):
class TestFloat8Sparse2x4Tensor(common_utils.TestCase):
@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [True, False])
Expand Down Expand Up @@ -68,7 +68,7 @@ def test_fp8_cutlass_sparse(self, compile):
model,
Float8DynamicActivationFloat8WeightConfig(
version=2,
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
granularity=PerRow(),
),
)
Expand All @@ -89,7 +89,7 @@ def test_fp8_cutlass_sparse_lowering_op_clone(self):
model,
Float8DynamicActivationFloat8WeightConfig(
version=2,
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
granularity=PerRow(),
),
)
Expand All @@ -114,7 +114,7 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
model,
Float8DynamicActivationFloat8WeightConfig(
version=2,
packing_format=Float8PackingFormat.SPARSE_CUTLASS,
packing_format=Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA,
granularity=PerRow(),
),
)
Expand All @@ -137,7 +137,7 @@ def test_fp8_cutlass_sparse_lowering_op_to(self):
)


common_utils.instantiate_parametrized_tests(TestSparse2x4Float8Tensor)
common_utils.instantiate_parametrized_tests(TestFloat8Sparse2x4Tensor)

if __name__ == "__main__":
unittest.main()
6 changes: 3 additions & 3 deletions torchao/quantization/quant_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from torchao.quantization.quantize_.workflows import (
Float8PackingFormat,
Float8Sparse2x4_1DData1DMetadataTensor,
Float8Sparse2x4_2DData2DMetadataTensor,
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4PackingFormat,
Expand All @@ -67,7 +68,6 @@
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
Sparse2x4CUTLASSFloat8Tensor,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
Expand Down Expand Up @@ -1258,11 +1258,11 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
act_quant_kwargs=act_quant_kwargs,
)
return quantized_weight
elif packing_format == Float8PackingFormat.SPARSE_CUTLASS:
elif packing_format == Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA:
assert isinstance(weight_granularity, PerRow), (
"Sparse packing format only supports per-row quantization"
)
quantized_weight = Sparse2x4CUTLASSFloat8Tensor.from_hp(
quantized_weight = Float8Sparse2x4_2DData2DMetadataTensor.from_hp(
weight,
float8_dtype=weight_dtype,
granularity=weight_granularity,
Expand Down
33 changes: 17 additions & 16 deletions torchao/quantization/quantize_/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from .float8.float8_sparse_2x4_1d_data_1d_metadata_tensor import (
Float8Sparse2x4_1DData1DMetadataTensor,
)
from .float8.float8_sparse_2x4_2d_data_2d_metadata_tensor import (
Float8Sparse2x4_2DData2DMetadataTensor,
)
from .float8.float8_tensor import (
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
from .float8.sparse_2x4_cutlass_float8_tensor import (
Sparse2x4CUTLASSFloat8Tensor,
)
from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
from .int4.int4_packing_format import Int4PackingFormat
from .int4.int4_plain_int32_tensor import (
Expand Down Expand Up @@ -39,26 +39,27 @@
)
from .nf4.nf4_tensor import NF4Tensor, to_nf4

Sparse2x4CUTLASSFloat8Tensor = Float8Sparse2x4_2DData2DMetadataTensor

__all__ = [
"Int4Tensor",
"Int4PreshuffledTensor",
"Int4PlainInt32Tensor",
"Int4TilePackedTo4dTensor",
"Int8Tensor",
"QuantizeTensorToInt8Kwargs",
"Float8Tensor",
"Sparse2x4CUTLASSFloat8Tensor",
"Float8Sparse2x4_1DData1DMetadataTensor",
"Float8PackingFormat",
"QuantizeTensorToFloat8Kwargs",
"Int8Tensor",
"QuantizeTensorToInt8Kwargs",
"Float8Sparse2x4_1DData1DMetadataTensor",
"Float8Sparse2x4_2DData2DMetadataTensor",
"Float8Tensor",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
"Int4PlainInt32Tensor",
"Int4PreshuffledTensor",
"Int4Tensor",
"Int4TilePackedTo4dTensor",
"Int8Tensor",
"IntxChooseQParamsAlgorithm",
"IntxOpaqueTensor",
"IntxPackingFormat",
"IntxUnpackedToInt8Tensor",
"IntxOpaqueTensor",
"NF4Tensor",
"QuantizeTensorToFloat8Kwargs",
"QuantizeTensorToInt8Kwargs",
"Sparse2x4CUTLASSFloat8Tensor",
"to_nf4",
Comment on lines -43 to 64
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is intentional. Sorted and removed duplicates.

]
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ class Float8PackingFormat(str, Enum):
"""
Sparse packing format for 2:4 sparsity + FP8 quantization

SPARSE_CUTLASS will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
SPARSE_2D_DATA_2D_METADATA will pack the quantized_data into two tensors, qdata and sparse_metadata, for the specified values and metadata respectively.
This packing format will dispatch to `rowwise_scaled_linear_sparse_cutlass_f8f8`, which will fuse the per-row scaling into the sparse matmul.
"""
SPARSE_CUTLASS = "sparse_cutlass"
SPARSE_2D_DATA_2D_METADATA = "sparse_2d_data_2d_metadata"
"""
Sparse packing format for 2:4 sparsity + FP8 quantization using hipSPARSELt (ROCm/AMD only).

Expand All @@ -52,4 +52,6 @@ class Float8PackingFormat(str, Enum):
SPARSE_1D_DATA_1D_METADATA = "sparse_1d_data_1d_metadata"
Comment thread
bbeckca marked this conversation as resolved.


Float8PackingFormat.SPARSE_CUTLASS = Float8PackingFormat.SPARSE_2D_DATA_2D_METADATA

torch.serialization.add_safe_globals([Float8PackingFormat])
Copy link
Copy Markdown
Contributor

@jerryzh168 jerryzh168 Apr 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

name should start with float8 I think

float8_sparse_2x4_2d_data_2d_metadata_tensor.py

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right. Update references to reflect prefix "float8".

Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
)

__all__ = [
"Sparse2x4CUTLASSFloat8Tensor",
"Float8Sparse2x4_2DData2DMetadataTensor",
]

aten = torch.ops.aten
Expand All @@ -40,7 +40,7 @@
from .float8_tensor import QuantizeTensorToFloat8Kwargs


class Sparse2x4CUTLASSFloat8Tensor(TorchAOBaseTensor):
class Float8Sparse2x4_2DData2DMetadataTensor(TorchAOBaseTensor):
"""
Float8 Quantized + 2:4 sparse (weight) Tensor using CUTLASS kernels, with float8 dynamic quantization for activation.

Expand Down Expand Up @@ -176,7 +176,7 @@ def from_hp(
# Use CUTLASS rowwise fp8 + 2:4 sparse mm kernel
qdata, sparse_metadata = to_sparse_semi_structured_cutlass_sm9x_f8(data)

return Sparse2x4CUTLASSFloat8Tensor(
return Float8Sparse2x4_2DData2DMetadataTensor(
qdata,
sparse_metadata,
scale,
Expand All @@ -186,8 +186,10 @@ def from_hp(
)


implements = Sparse2x4CUTLASSFloat8Tensor.implements
implements_torch_function = Sparse2x4CUTLASSFloat8Tensor.implements_torch_function
implements = Float8Sparse2x4_2DData2DMetadataTensor.implements
implements_torch_function = (
Float8Sparse2x4_2DData2DMetadataTensor.implements_torch_function
)


@implements(aten.linear.default)
Expand Down Expand Up @@ -251,4 +253,4 @@ def _(func, types, args, kwargs):


# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
torch.serialization.add_safe_globals([Sparse2x4CUTLASSFloat8Tensor])
torch.serialization.add_safe_globals([Float8Sparse2x4_2DData2DMetadataTensor])
Loading