diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 868b59e54f..963253eae5 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -35,8 +35,7 @@ groupwise_affine_quantize_tensor_from_qparams, ) from torchao.utils import ( - check_cpu_version, - check_xpu_version, + _is_device, get_current_accelerator_device, is_fbcode, ) @@ -152,9 +151,9 @@ def _groupwise_affine_quantize_tensor_from_qparams( .reshape_as(w) ) - if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))): + if (not (_is_device("cpu", w.device))) and (not (_is_device("xpu", w.device))): w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8) - if check_xpu_version(w.device): + if _is_device("xpu", w.device): w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8) return w_int4x8 @@ -708,11 +707,11 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self): if zero_point_domain == ZeroPointDomain.INT: zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32) input_tmp = input - if (not (check_cpu_version(input.device))) and ( - not (check_xpu_version(input.device)) + if (not (_is_device("cpu", input.device))) and ( + not (_is_device("xpu", input.device)) ): input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8) - if check_xpu_version(input.device): + if _is_device("xpu", input.device): input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8) w_bf16 = groupwise_affine_dequantize_tensor_from_qparams( input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index f159aa1fc7..20efb54178 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -10,7 +10,7 @@ from torch._dynamo import is_compiling as dynamo_is_compiling from torch._higher_order_ops.out_dtype import out_dtype -from torchao.utils import check_cpu_version, torch_version_at_least +from torchao.utils import _is_device, torch_version_at_least logger = logging.getLogger(__name__) logger.addHandler(logging.NullHandler()) @@ -192,7 +192,7 @@ def int_scaled_matmul( assert 1 == scales1.size(1) assert scales1.is_contiguous() - if check_cpu_version(scales1.device): + if _is_device("cpu", scales1.device): return _int_scaled_matmul_cpu(a, b, scales1) scales1 = scales1.expand((M, N)) diff --git a/torchao/prototype/hqq/hqq_tinygemm_linear.py b/torchao/prototype/hqq/hqq_tinygemm_linear.py index 045ed5f6e4..ab2ce43321 100644 --- a/torchao/prototype/hqq/hqq_tinygemm_linear.py +++ b/torchao/prototype/hqq/hqq_tinygemm_linear.py @@ -16,7 +16,7 @@ from hqq.core.utils import * # noqa: F401, F403 from torch import Tensor, nn -from torchao.utils import _is_device, check_cpu_version +from torchao.utils import _is_device class HQQLinearTorchWeightOnlyInt4(torch.nn.Module): @@ -166,7 +166,7 @@ def process_hqq_quants(self, W_q, meta): W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants( W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits ) - if check_cpu_version(W_q.device): + if _is_device("cpu", W_q.device): self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu( W_q_torch, self.inner_k_tiles ) @@ -241,7 +241,7 @@ def pack_scales_and_zeros(self, scales, zeros): def matmul(self, x): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) - if check_cpu_version(x.device): + if _is_device("cpu", x.device): c = torch.ops.aten._weight_int4pack_mm_for_cpu( x, self.weight_int4pack, self.groupsize, self.scales_and_zeros ) diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3cde912f67..79ca6441e9 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -25,10 +25,7 @@ dequantize_affine, quantize_affine, ) -from torchao.utils import ( - check_cpu_version, - check_xpu_version, -) +from torchao.utils import _is_device from .granularity import ( Granularity, @@ -462,11 +459,11 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_max, ) if w.shape[-1] > 1: - if (not (check_cpu_version(int_data.device))) and ( - not (check_xpu_version(int_data.device)) + if (not (_is_device("cpu", int_data.device))) and ( + not (_is_device("xpu", int_data.device)) ): int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - if check_xpu_version(int_data.device): + if _is_device("xpu", int_data.device): int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8) return int_data @@ -483,7 +480,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( assert w_int4x8.dim() == 2 # need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not ( - check_cpu_version(w_int4x8.device) + _is_device("cpu", w_int4x8.device) ): data = w_int4x8.to(torch.int32) high_bits = data >> 4 @@ -493,7 +490,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( dtype=torch.int32, device=w_int4x8.device, ) - if not (check_xpu_version(w_int4x8.device)): + if not (_is_device("xpu", w_int4x8.device)): w_int32[::, ::2] = high_bits w_int32[::, 1::2] = low_bits else: diff --git a/torchao/utils.py b/torchao/utils.py index 219dd0e4f1..ca0dda262c 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -1169,18 +1169,6 @@ def is_cuda_version_at_least(major: int, minor: int) -> bool: return (cuda_major, cuda_minor) >= (major, minor) -def check_cpu_version(device, version="2.6.0"): - if isinstance(device, torch.device): - device = device.type - return device == "cpu" and torch_version_at_least(version) - - -def check_xpu_version(device, version="2.8.0"): - if isinstance(device, torch.device): - device = device.type - return device == "xpu" and torch_version_at_least(version) - - def ceil_div(a, b): return (a + b - 1) // b