Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@
groupwise_affine_quantize_tensor_from_qparams,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
_is_device,
get_current_accelerator_device,
is_fbcode,
)
Expand Down Expand Up @@ -152,9 +151,9 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.reshape_as(w)
)

if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
if (not (_is_device("cpu", w.device))) and (not (_is_device("xpu", w.device))):
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if check_xpu_version(w.device):
if _is_device("xpu", w.device):
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8
Expand Down Expand Up @@ -708,11 +707,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
input_tmp = input
if (not (check_cpu_version(input.device))) and (
not (check_xpu_version(input.device))
if (not (_is_device("cpu", input.device))) and (
not (_is_device("xpu", input.device))
):
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if check_xpu_version(input.device):
if _is_device("xpu", input.device):
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
Expand Down
4 changes: 2 additions & 2 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.utils import check_cpu_version, torch_version_at_least
from torchao.utils import _is_device, torch_version_at_least

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
Expand Down Expand Up @@ -192,7 +192,7 @@ def int_scaled_matmul(
assert 1 == scales1.size(1)
assert scales1.is_contiguous()

if check_cpu_version(scales1.device):
if _is_device("cpu", scales1.device):
return _int_scaled_matmul_cpu(a, b, scales1)

scales1 = scales1.expand((M, N))
Expand Down
6 changes: 3 additions & 3 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from hqq.core.utils import * # noqa: F401, F403
from torch import Tensor, nn

from torchao.utils import _is_device, check_cpu_version
from torchao.utils import _is_device


class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
Expand Down Expand Up @@ -166,7 +166,7 @@ def process_hqq_quants(self, W_q, meta):
W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
)
if check_cpu_version(W_q.device):
if _is_device("cpu", W_q.device):
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
W_q_torch, self.inner_k_tiles
)
Expand Down Expand Up @@ -241,7 +241,7 @@ def pack_scales_and_zeros(self, scales, zeros):
def matmul(self, x):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
if check_cpu_version(x.device):
if _is_device("cpu", x.device):
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
)
Expand Down
15 changes: 6 additions & 9 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
dequantize_affine,
quantize_affine,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
)
from torchao.utils import _is_device

from .granularity import (
Granularity,
Expand Down Expand Up @@ -462,11 +459,11 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_max,
)
if w.shape[-1] > 1:
if (not (check_cpu_version(int_data.device))) and (
not (check_xpu_version(int_data.device))
if (not (_is_device("cpu", int_data.device))) and (
not (_is_device("xpu", int_data.device))
):
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if check_xpu_version(int_data.device):
if _is_device("xpu", int_data.device):
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
return int_data

Expand All @@ -483,7 +480,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
assert w_int4x8.dim() == 2
# need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not (
check_cpu_version(w_int4x8.device)
_is_device("cpu", w_int4x8.device)
):
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
Expand All @@ -493,7 +490,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
dtype=torch.int32,
device=w_int4x8.device,
)
if not (check_xpu_version(w_int4x8.device)):
if not (_is_device("xpu", w_int4x8.device)):
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
Expand Down
12 changes: 0 additions & 12 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,18 +1169,6 @@ def is_cuda_version_at_least(major: int, minor: int) -> bool:
return (cuda_major, cuda_minor) >= (major, minor)


def check_cpu_version(device, version="2.6.0"):
if isinstance(device, torch.device):
device = device.type
return device == "cpu" and torch_version_at_least(version)


def check_xpu_version(device, version="2.8.0"):
if isinstance(device, torch.device):
device = device.type
return device == "xpu" and torch_version_at_least(version)


def ceil_div(a, b):
return (a + b - 1) // b

Expand Down
Loading