Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 0 additions & 7 deletions test/dtypes/test_affine_quantized.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
from torchao.quantization.quant_primitives import MappingType
from torchao.testing.utils import skip_if_rocm
from torchao.utils import (
check_cpu_version,
check_xpu_version,
get_current_accelerator_device,
is_fbcode,
is_sm_at_least_89,
Expand All @@ -49,11 +47,6 @@ def get_quantization_functions(
Int8DynamicActivationInt8WeightConfig(),
Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC),
]
if do_int4:
if check_cpu_version(device):
pass
elif check_xpu_version(device):
pass

if is_sm_at_least_89():
base_functions.append(Float8WeightOnlyConfig())
Expand Down
12 changes: 4 additions & 8 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
groupwise_affine_quantize_tensor_from_qparams,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
get_current_accelerator_device,
is_fbcode,
)
Expand Down Expand Up @@ -152,9 +150,9 @@ def _groupwise_affine_quantize_tensor_from_qparams(
.reshape_as(w)
)

if (not (check_cpu_version(w.device))) and (not (check_xpu_version(w.device))):
if w.device != "cpu" and w.device != "xpu":
w_int4x8 = (w_int4x8[::, ::2] << 4 | w_int4x8[::, 1::2]).to(torch.uint8)
if check_xpu_version(w.device):
if w.device == "xpu":
w_int4x8 = (w_int4x8[::, 1::2] << 4 | w_int4x8[::, ::2]).to(torch.uint8)

return w_int4x8
Expand Down Expand Up @@ -708,11 +706,9 @@ def test_groupwise_affine_dequantize_tensor_from_qparams(self):
if zero_point_domain == ZeroPointDomain.INT:
zeros = torch.randint(0, 15, (10, 2), dtype=torch.int32)
input_tmp = input
if (not (check_cpu_version(input.device))) and (
not (check_xpu_version(input.device))
):
if input.device != "cpu" and input.device != "xpu":
input_tmp = (input[::, ::2] << 4 | input[::, 1::2]).to(torch.uint8)
if check_xpu_version(input.device):
if input.device == "xpu":
input_tmp = (input[::, 1::2] << 4 | input[::, ::2]).to(torch.uint8)
w_bf16 = groupwise_affine_dequantize_tensor_from_qparams(
input_tmp, scales, zeros, n_bit, groupsize, zero_point_domain
Expand Down
4 changes: 1 addition & 3 deletions torchao/kernel/intmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from torch._dynamo import is_compiling as dynamo_is_compiling
from torch._higher_order_ops.out_dtype import out_dtype

from torchao.utils import check_cpu_version

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

Expand Down Expand Up @@ -190,7 +188,7 @@ def int_scaled_matmul(
assert 1 == scales1.size(1)
assert scales1.is_contiguous()

if check_cpu_version(scales1.device):
if scales1.device == "cpu":
return _int_scaled_matmul_cpu(a, b, scales1)

scales1 = scales1.expand((M, N))
Expand Down
5 changes: 2 additions & 3 deletions torchao/prototype/hqq/hqq_tinygemm_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from torch import Tensor, nn

from torchao.dtypes.utils import is_device
from torchao.utils import check_cpu_version


class HQQLinearTorchWeightOnlyInt4(torch.nn.Module):
Expand Down Expand Up @@ -167,7 +166,7 @@ def process_hqq_quants(self, W_q, meta):
W_q_torch, scales_torch, zeros_torch = self.hqq_quants_to_torch_quants(
W_q=W_q, scales=scales, zeros=zeros, shape=shape, nbits=self.nbits
)
if check_cpu_version(W_q.device):
if W_q.device == "cpu":
self.weight_int4pack = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
W_q_torch, self.inner_k_tiles
)
Expand Down Expand Up @@ -242,7 +241,7 @@ def pack_scales_and_zeros(self, scales, zeros):
def matmul(self, x):
origin_x_size = x.size()
x = x.reshape(-1, origin_x_size[-1])
if check_cpu_version(x.device):
if x.device == "cpu":
c = torch.ops.aten._weight_int4pack_mm_for_cpu(
x, self.weight_int4pack, self.groupsize, self.scales_and_zeros
)
Expand Down
18 changes: 6 additions & 12 deletions torchao/quantization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,6 @@
dequantize_affine,
quantize_affine,
)
from torchao.utils import (
check_cpu_version,
check_xpu_version,
)

from .granularity import (
Granularity,
Expand Down Expand Up @@ -465,11 +461,9 @@ def groupwise_affine_quantize_tensor_from_qparams(
quant_max,
)
if w.shape[-1] > 1:
if (not (check_cpu_version(int_data.device))) and (
not (check_xpu_version(int_data.device))
):
if int_data.device != "cpu" and int_data.device != "xpu":
int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8)
if check_xpu_version(int_data.device):
if int_data.device == "xpu":
int_data = (int_data[::, 1::2] << 4 | int_data[::, ::2]).to(torch.uint8)
return int_data

Expand All @@ -485,9 +479,9 @@ def groupwise_affine_dequantize_tensor_from_qparams(
assert groupsize > 1
assert w_int4x8.dim() == 2
# need to handle single column case so check for dtype/size from groupwise_affine_quantize_tensor_from_qparams path
if (w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1) and not (
check_cpu_version(w_int4x8.device)
):
if (
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this condition seems complicated, might be good to simplify if possible

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I am not familiar with the code here, and I feel the condition cannot be simplified. However, we could probably add a comment here to clarify the condition.
CC the last author @yanbing-j

w_int4x8.dtype == torch.uint8 or w_int4x8.shape[-1] > 1
) and w_int4x8.device != "cpu":
data = w_int4x8.to(torch.int32)
high_bits = data >> 4
low_bits = data & 0x0F
Expand All @@ -496,7 +490,7 @@ def groupwise_affine_dequantize_tensor_from_qparams(
dtype=torch.int32,
device=w_int4x8.device,
)
if not (check_xpu_version(w_int4x8.device)):
if w_int4x8.device != "xpu":
w_int32[::, ::2] = high_bits
w_int32[::, 1::2] = low_bits
else:
Expand Down
12 changes: 0 additions & 12 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1239,18 +1239,6 @@ def is_cuda_version_at_least(major: int, minor: int) -> bool:
return (cuda_major, cuda_minor) >= (major, minor)


def check_cpu_version(device, version="2.6.0"):
    """Return True when *device* is CPU and torch is at least *version*.

    Accepts either a ``torch.device`` or a plain device-type string.
    """
    dev_type = device.type if isinstance(device, torch.device) else device
    return dev_type == "cpu" and torch_version_at_least(version)


def check_xpu_version(device, version="2.8.0"):
    """Return True when *device* is XPU and torch is at least *version*.

    Accepts either a ``torch.device`` or a plain device-type string.
    """
    dev_type = device.type if isinstance(device, torch.device) else device
    return dev_type == "xpu" and torch_version_at_least(version)


def ceil_div(a, b):
return (a + b - 1) // b

Expand Down
Loading