|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +from typing import List, Optional |
| 9 | + |
| 10 | +import torch |
| 11 | + |
| 12 | +from torchao.float8.inference import ( |
| 13 | + FP8Granularity, |
| 14 | +) |
| 15 | +from torchao.ops import ( |
| 16 | + rowwise_scaled_linear_sparse_cutlass_f8f8, |
| 17 | + to_sparse_semi_structured_cutlass_sm9x_f8, |
| 18 | +) |
| 19 | +from torchao.quantization.granularity import PerRow |
| 20 | +from torchao.quantization.quant_primitives import ( |
| 21 | + _choose_scale_float8, |
| 22 | + _quantize_affine_float8, |
| 23 | +) |
| 24 | +from torchao.quantization.quantize_.common import ( |
| 25 | + _choose_quant_func_and_quantize_tensor, |
| 26 | +) |
| 27 | +from torchao.quantization.utils import get_block_size |
| 28 | +from torchao.utils import ( |
| 29 | + TorchAOBaseTensor, |
| 30 | + is_sm_at_least_90, |
| 31 | +) |
| 32 | + |
| 33 | +__all__ = [ |
| 34 | + "Float8Sparse2x4_2DData2DMetadataTensor", |
| 35 | +] |
| 36 | + |
| 37 | +aten = torch.ops.aten |
| 38 | + |
| 39 | + |
| 40 | +from .float8_tensor import QuantizeTensorToFloat8Kwargs |
| 41 | + |
| 42 | + |
| 43 | +class Float8Sparse2x4_2DData2DMetadataTensor(TorchAOBaseTensor): |
| 44 | + """ |
| 45 | + Float8 Quantized + 2:4 sparse (weight) Tensor using CUTLASS kernels, with float8 dynamic quantization for activation. |
| 46 | +
|
| 47 | + Tensor Attributes: |
| 48 | + qdata: float8 raw data |
| 49 | + sparse_metadata: metadata for 2:4 sparse tensor |
| 50 | + scale: the scale for float8 Tensor |
| 51 | +
|
| 52 | + Non-Tensor Attributes: |
| 53 | + block_size (List[int]): the block size for float8 quantization, meaning the shape of the elements |
| 54 | + sharing the same set of quantization parameters (scale), have the same rank as qdata or |
| 55 | + is an empty list (representing per tensor quantization) |
| 56 | + act_quant_kwargs (QuantizeTensorToFloat8Kwargs): the kwargs for Sparse2x4Float8Tensor.from_hp |
| 57 | + packing_format (Float8PackingFormat): the preference for quantize, mm etc. kernel to use, |
| 58 | + by default, this will be chosen for user based on hardware, library availabilities etc. |
| 59 | + dtype: Original Tensor dtype |
| 60 | + """ |
| 61 | + |
| 62 | + tensor_data_names = ["qdata", "sparse_metadata", "scale"] |
| 63 | + tensor_attribute_names = [] |
| 64 | + optional_tensor_attribute_names = [ |
| 65 | + "block_size", |
| 66 | + "act_quant_kwargs", |
| 67 | + "dtype", |
| 68 | + ] |
| 69 | + |
| 70 | + def __new__( |
| 71 | + cls, |
| 72 | + qdata: torch.Tensor, |
| 73 | + sparse_metadata: torch.Tensor, |
| 74 | + scale: torch.Tensor, |
| 75 | + block_size: Optional[List[int]] = None, |
| 76 | + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, |
| 77 | + dtype: Optional[torch.dtype] = None, |
| 78 | + ): |
| 79 | + shape = qdata.shape[0], 2 * qdata.shape[1] |
| 80 | + |
| 81 | + kwargs = {} |
| 82 | + kwargs["device"] = qdata.device |
| 83 | + kwargs["dtype"] = dtype |
| 84 | + kwargs["requires_grad"] = False |
| 85 | + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] |
| 86 | + |
| 87 | + def __init__( |
| 88 | + self, |
| 89 | + qdata: torch.Tensor, |
| 90 | + sparse_metadata: torch.Tensor, |
| 91 | + scale: torch.Tensor, |
| 92 | + block_size: Optional[List[int]] = None, |
| 93 | + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, |
| 94 | + dtype: Optional[torch.dtype] = None, |
| 95 | + ): |
| 96 | + super().__init__() |
| 97 | + self.qdata = qdata |
| 98 | + self.sparse_metadata = sparse_metadata |
| 99 | + self.scale = scale |
| 100 | + self.block_size = block_size |
| 101 | + self.act_quant_kwargs = act_quant_kwargs |
| 102 | + |
| 103 | + def __repr__(self): |
| 104 | + return ( |
| 105 | + f"{self.__class__.__name__}({self.act_quant_kwargs=}, {self.qdata=}, {self.sparse_metadata=}, {self.scale=}, " |
| 106 | + f"{self.block_size=}, " |
| 107 | + f"{self.shape=}, {self.device=}, {self.dtype=})" |
| 108 | + ) |
| 109 | + |
| 110 | + def _quantization_type(self): |
| 111 | + return f"{self.act_quant_kwargs=}, {self.block_size=}, {self.scale.shape=}" |
| 112 | + |
| 113 | + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: |
| 114 | + # No support in CUTLASS to convert back to dense from sparse |
| 115 | + # semi-structured format, so multiplying with identity matrix, |
| 116 | + # and using identity scale factors, for the conversion. |
| 117 | + cols = self.shape[1] |
| 118 | + input = torch.eye(cols, dtype=self.qdata.dtype, device=self.qdata.device) |
| 119 | + input_scale = torch.ones( |
| 120 | + (cols,), dtype=self.scale.dtype, device=self.qdata.device |
| 121 | + ) |
| 122 | + |
| 123 | + out_dtype = torch.bfloat16 |
| 124 | + dense = ( |
| 125 | + rowwise_scaled_linear_sparse_cutlass_f8f8( |
| 126 | + input, |
| 127 | + input_scale, |
| 128 | + self.qdata, |
| 129 | + self.sparse_metadata, |
| 130 | + self.scale, |
| 131 | + out_dtype=out_dtype, |
| 132 | + ) |
| 133 | + .to(output_dtype) |
| 134 | + .t() |
| 135 | + .contiguous() |
| 136 | + ) |
| 137 | + return dense |
| 138 | + |
| 139 | + @classmethod |
| 140 | + def from_hp( |
| 141 | + cls, |
| 142 | + hp_tensor: torch.Tensor, |
| 143 | + float8_dtype: torch.dtype = torch.float8_e4m3fn, |
| 144 | + granularity: FP8Granularity = PerRow(), |
| 145 | + hp_value_lb: Optional[float] = None, |
| 146 | + hp_value_ub: Optional[float] = None, |
| 147 | + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, |
| 148 | + ): |
| 149 | + block_size = get_block_size(hp_tensor.shape, granularity) |
| 150 | + block_size = list(block_size) |
| 151 | + scale = _choose_scale_float8( |
| 152 | + hp_tensor, |
| 153 | + float8_dtype=float8_dtype, |
| 154 | + block_size=block_size, |
| 155 | + hp_value_lb=hp_value_lb, |
| 156 | + hp_value_ub=hp_value_ub, |
| 157 | + ) |
| 158 | + data = _quantize_affine_float8(hp_tensor, scale, float8_dtype) |
| 159 | + hp_dtype = hp_tensor.dtype |
| 160 | + |
| 161 | + assert is_sm_at_least_90(), ( |
| 162 | + "CUTLASS sparse kernel requires hardware >= SM 9.0 (>= H100)" |
| 163 | + ) |
| 164 | + assert isinstance(granularity, PerRow), ( |
| 165 | + "CUTLASS sparse kernel only supports per-row quantization" |
| 166 | + ) |
| 167 | + # CUTLASS path only supports quantizing along the last dim |
| 168 | + assert granularity.dim in (-1, len(hp_tensor.shape) - 1), ( |
| 169 | + "CUTLASS sparse kernel only supports quantizing along the last dimension" |
| 170 | + ) |
| 171 | + assert float8_dtype == torch.float8_e4m3fn, ( |
| 172 | + "CUTLASS sparse kernel only supports float8_e4m3fn dtype" |
| 173 | + ) |
| 174 | + assert hp_value_lb is None, "CUTLASS sparse kernel does not support hp_value_lb" |
| 175 | + |
| 176 | + # Use CUTLASS rowwise fp8 + 2:4 sparse mm kernel |
| 177 | + qdata, sparse_metadata = to_sparse_semi_structured_cutlass_sm9x_f8(data) |
| 178 | + |
| 179 | + return Float8Sparse2x4_2DData2DMetadataTensor( |
| 180 | + qdata, |
| 181 | + sparse_metadata, |
| 182 | + scale, |
| 183 | + block_size=block_size, |
| 184 | + act_quant_kwargs=act_quant_kwargs, |
| 185 | + dtype=hp_dtype, |
| 186 | + ) |
| 187 | + |
| 188 | + |
| 189 | +implements = Float8Sparse2x4_2DData2DMetadataTensor.implements |
| 190 | +implements_torch_function = Float8Sparse2x4_2DData2DMetadataTensor.implements_torch_function |
| 191 | + |
| 192 | + |
| 193 | +@implements(aten.linear.default) |
| 194 | +@implements_torch_function(torch.nn.functional.linear) |
| 195 | +def _(func, types, args, kwargs): |
| 196 | + input_tensor = kwargs.get("input", args[0] if len(args) > 0 else None) |
| 197 | + weight_tensor = kwargs.get("weight", args[1] if len(args) > 1 else None) |
| 198 | + bias = kwargs.get("bias", args[2] if len(args) > 2 else None) |
| 199 | + |
| 200 | + assert input_tensor is not None, "input tensor must not be None" |
| 201 | + assert weight_tensor is not None, "weight tensor must not be None" |
| 202 | + |
| 203 | + act_quant_kwargs = weight_tensor.act_quant_kwargs |
| 204 | + # quantize activation, if `act_quant_kwargs` is specified |
| 205 | + if act_quant_kwargs is not None: |
| 206 | + assert not isinstance(input_tensor, TorchAOBaseTensor), ( |
| 207 | + "input tensor was already quantized" |
| 208 | + ) |
| 209 | + input_tensor = _choose_quant_func_and_quantize_tensor( |
| 210 | + input_tensor, act_quant_kwargs |
| 211 | + ) |
| 212 | + input = input_tensor.qdata |
| 213 | + input_scale = input_tensor.scale.squeeze(1) |
| 214 | + weight = weight_tensor.qdata |
| 215 | + weight_meta = weight_tensor.sparse_metadata |
| 216 | + weight_scale = weight_tensor.scale.squeeze(1) |
| 217 | + out_dtype = input_tensor.dtype |
| 218 | + |
| 219 | + out = rowwise_scaled_linear_sparse_cutlass_f8f8( |
| 220 | + input, input_scale, weight, weight_meta, weight_scale, bias, out_dtype |
| 221 | + ) |
| 222 | + return out |
| 223 | + |
| 224 | + |
| 225 | +@implements(aten.to.dtype_layout) |
| 226 | +def _(func, types, args, kwargs): |
| 227 | + return ( |
| 228 | + args[0] |
| 229 | + .dequantize() |
| 230 | + .to( |
| 231 | + *args[1:], |
| 232 | + dtype=kwargs.get("dtype", args[0].dtype), |
| 233 | + device=kwargs.get("device", args[0].device), |
| 234 | + ) |
| 235 | + ) |
| 236 | + |
| 237 | + |
| 238 | +# implement to.dtype for cases where dtype specified in args[1] |
| 239 | +@implements(aten.to.dtype) |
| 240 | +def _(func, types, args, kwargs): |
| 241 | + dtype = kwargs.get("dtype", args[1] if len(args) > 1 else None) |
| 242 | + assert dtype is not None, "dtype must not be None" |
| 243 | + |
| 244 | + return ( |
| 245 | + args[0] |
| 246 | + .dequantize() |
| 247 | + .to( |
| 248 | + dtype=dtype, |
| 249 | + ) |
| 250 | + ) |
| 251 | + |
| 252 | + |
| 253 | +# Allow a model with Float8Tensor weights to be loaded with `weights_only=True` |
| 254 | +torch.serialization.add_safe_globals([Float8Sparse2x4_2DData2DMetadataTensor]) |
0 commit comments