Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
)
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
check_cpu_version,
check_xpu_version,
get_current_accelerator_device,
is_fbcode,
is_sm_at_least_89,
Expand All @@ -36,11 +34,6 @@ def get_quantization_functions(
do_sparse: bool, do_int4: bool, device: str = "cuda", int4_zp_int: bool = False
):
base_functions = []
if do_int4:
if check_cpu_version(device):
pass
elif check_xpu_version(device):
pass

if is_sm_at_least_89():
base_functions.append(Float8WeightOnlyConfig())
Expand Down
14 changes: 6 additions & 8 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@
groupwise_affine_quantize_tensor_from_qparams,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
get_current_accelerator_device,
is_fbcode,
is_on_device,
not_on_device,
)

_SEED = 1234
Expand Down Expand Up @@ -152,9 +152,9 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.reshape_as(w)
)

if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
if not_on_device(w, ["cpu", "xpu"]):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if check_xpu_version(w.device):
if is_on_device(w, "xpu"):
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8
Expand Down Expand Up @@ -708,11 +708,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
input_tmp = input
if (not (check_cpu_version(input.device))) and (
not (check_xpu_version(input.device))
):
if not_on_device(input, ["cpu", "xpu"]):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if check_xpu_version(input.device):
if is_on_device(input, "xpu"):
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
Expand Down
6 changes: 1 addition & 5 deletions torchao/dtypes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple, Union
from typing import Optional, Tuple

import torch

Expand Down Expand Up @@ -77,10 +77,6 @@ def __post_init__(self):
)


def is_device(target_device_str: str, device: Union[str, torch.device]):
    """Return True when ``device`` resolves to the device type ``target_device_str``.

    ``device`` may be given either as a string (e.g. ``"cuda:0"``) or as a
    ``torch.device``; it is normalized through ``torch.device`` before the
    type comparison.
    """
    resolved = torch.device(device)
    return resolved.type == target_device_str


def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]:
"""Returns the unflattened shape of the input tensor.
Args:
Expand Down
4 changes: 2 additions & 2 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.utils import check_cpu_version, torch_version_at_least
from torchao.utils import is_on_device, torch_version_at_least

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -192,7 +192,7 @@ def int_scaled_matmul(
assert 1 == scales1.size(1)
assert scales1.is_contiguous()

if check_cpu_version(scales1.device):
if is_on_device(scales1, "cpu"):
return _int_scaled_matmul_cpu(a, b, scales1)

scales1 = scales1.expand((M, N))
Expand Down
9 changes: 4 additions & 5 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
from hqq.core.utils import * # noqa: F401, F403
from torch import Tensor, nn

from torchao.dtypes.utils import is_device
from torchao.utils import check_cpu_version
from torchao.utils import is_on_device


class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
Expand Down Expand Up @@ -167,7 +166,7 @@ def process_hqq_quants(self, W_q, meta):
W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
)
if check_cpu_version(W_q.device):
if is_on_device(W_q, "cpu"):
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
W_q_torch, self.inner_k_tiles
)
Expand Down Expand Up @@ -209,7 +208,7 @@ def hqq_quants_to_torch_quants(
.reshape(shape)
.contiguous()
)
if not is_device(W_q.device.type, "cpu"):
if not is_on_device(W_q, "cpu"):
W_q = (W_q[::, ::2] << 4 | W_q[::, 1::2]).to(torch.uint8)

# group_dequantize_tensor_from_qparams
Expand Down Expand Up @@ -242,7 +241,7 @@ def pack_scales_and_zeros(self, scales, zeros):
def matmul(self, x):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
if check_cpu_version(x.device):
if is_on_device(x, "cpu"):
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
)
Expand Down
5 changes: 2 additions & 3 deletions torchao/prototype/tensor_conversion/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import torch
import torch.nn as nn

# TODO: move the function to torchao.utils
from torchao.dtypes.utils import is_device
from torchao.quantization import (
Int4PreshuffledTensor,
Int4Tensor,
Expand All @@ -17,6 +15,7 @@
from torchao.utils import (
TorchAOBaseTensor,
_is_mslk_available,
is_on_device,
is_sm_at_least_90,
)

Expand Down Expand Up @@ -189,7 +188,7 @@ def convert_to_packed_tensor_based_on_current_hardware(tensor: TorchAOBaseTensor
"""
if (
isinstance(tensor, Int4Tensor)
and is_device("cuda", tensor.device)
and is_on_device(tensor, "cuda")
and _is_mslk_available()
and is_sm_at_least_90()
):
Expand Down
9 changes: 4 additions & 5 deletions torchao/quantization/linear_quant_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
import torch.nn as nn
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
from torchao.utils import find_multiple
from torchao.utils import find_multiple, is_on_device

from .quant_primitives import (
MappingType,
Expand Down Expand Up @@ -57,7 +56,7 @@ def linear_forward_int4(
):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
if is_device(x.device.type, "cpu"):
if is_on_device(x, "cpu"):
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x.to(precision),
weight_int4pack,
Expand Down Expand Up @@ -117,7 +116,7 @@ def __init__(
assert in_features % (inner_k_tiles * 16) == 0, (
"require in_features % (innerKTiles * 16) == 0"
)
if is_device(device.type, "cpu"):
if device.type == "cpu":
self.register_buffer(
"weight",
torch.zeros(
Expand Down Expand Up @@ -296,7 +295,7 @@ def _create_quantized_state_dict(
self.precision, # dtype for scales_and_zeros
)
# TODO: just get the device from mod.weight.device?
if is_device(w_int4x8.device.type, "cpu"):
if is_on_device(w_int4x8, "cpu"):
weight_int4pack = (
torch.ops.aten._convert_weight_to_int4pack_for_cpu(
w_int4x8.to(self.device), self.inner_k_tiles
Expand Down
4 changes: 2 additions & 2 deletions torchao/quantization/qat/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import torch
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.linear_quant_modules import (
Int8DynActInt4WeightLinear,
Expand All @@ -25,6 +24,7 @@
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.utils import is_on_device

from .fake_quantize_config import (
FakeQuantizeConfigBase,
Expand Down Expand Up @@ -478,7 +478,7 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
n_bit,
config.group_size,
)
if is_device(q_weight.device.type, "cpu"):
if is_on_device(q_weight, "cpu"):
q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
q_weight.to(child.weight.device),
child.inner_k_tiles,
Expand Down
17 changes: 6 additions & 11 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
dequantize_affine,
quantize_affine,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
)
from torchao.utils import is_on_device, not_on_device

from .granularity import (
Granularity,
Expand Down Expand Up @@ -465,11 +462,9 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_max,
)
if w.shape[-1] > 1:
if (not (check_cpu_version(int_data.device))) and (
not (check_xpu_version(int_data.device))
):
if not_on_device(int_data, ["cpu", "xpu"]):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I don't think we need special util just for having a list of args, I think it's better to just use is_device

int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if check_xpu_version(int_data.device):
if is_on_device(int_data, "xpu"):
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
return int_data

Expand All @@ -485,8 +480,8 @@ def groupwise_affine_dequantize_tensor_from_qparams(
assert groupsize > 1
assert w_int4x8.dim() == 2
# need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not (
check_cpu_version(w_int4x8.device)
if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not_on_device(
w_int4x8, "cpu"
):
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
Expand All @@ -496,7 +491,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
dtype=torch.int32,
device=w_int4x8.device,
)
if not (check_xpu_version(w_int4x8.device)):
if not_on_device(w_int4x8, "xpu"):
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
Expand Down
26 changes: 13 additions & 13 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from functools import reduce
from importlib.metadata import version
from math import gcd
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.utils.parametrize as parametrize
Expand Down Expand Up @@ -1239,18 +1239,6 @@ def is_cuda_version_at_least(major: int, minor: int) -> bool:
return (cuda_major, cuda_minor) >= (major, minor)


def check_cpu_version(device, version="2.6.0"):
    """Return True if ``device`` is a CPU device and ``torch_version_at_least(version)`` holds.

    Args:
        device: a ``torch.device`` or a device-type string (e.g. ``"cpu"``).
        version: version string forwarded to ``torch_version_at_least``
            (presumably the minimum torch version — helper defined elsewhere).
    """
    if isinstance(device, torch.device):
        device = device.type
    return device == "cpu" and torch_version_at_least(version)


def check_xpu_version(device, version="2.8.0"):
    """Return True if ``device`` is an XPU device and ``torch_version_at_least(version)`` holds.

    Args:
        device: a ``torch.device`` or a device-type string (e.g. ``"xpu"``).
        version: version string forwarded to ``torch_version_at_least``
            (presumably the minimum torch version — helper defined elsewhere).
    """
    if isinstance(device, torch.device):
        device = device.type
    return device == "xpu" and torch_version_at_least(version)


def ceil_div(a, b):
    """Integer ceiling division: smallest integer >= a / b (intended for positive b)."""
    # Bias the numerator by b - 1 so floor division rounds upward.
    biased = a + b - 1
    return biased // b

Expand Down Expand Up @@ -1296,6 +1284,18 @@ def _is_flashinfer_available():
) or is_fbcode()


def is_on_device(tensor: torch.Tensor, device_str: str) -> bool:
    """Return True when ``tensor`` resides on a device whose type equals ``device_str``."""
    device_type = tensor.device.type
    return device_type == device_str


def not_on_device(
    tensor: torch.Tensor, devices: Union[str, list[str], tuple[str, ...]]
) -> bool:
    """Return True when ``tensor`` is on none of the given device types.

    Args:
        tensor: the tensor whose device type is inspected.
        devices: a single device-type string or a collection of them
            (e.g. ``"cpu"`` or ``["cpu", "xpu"]``).
    """
    candidates = [devices] if isinstance(devices, str) else devices
    device_type = tensor.device.type
    return all(device_type != candidate for candidate in candidates)


class DummyModule(torch.nn.Module):
"""This is used because the TorchAO quantization functions tend to operate on modules so to apply the transform to a tensor, we can load a
DummyModule with the target tensor and then apply the transformation to the module and then extract the transformed tensor.
Expand Down
Loading