@@ -10,7 +10,7 @@
 import torch
 
 from torchao.prototype.dtypes.uintx.uintx_layout import to_uintx
-from torchao.quantization.quant_api import UIntXWeightOnlyConfig, quantize_
+from torchao.quantization.quant_api import quantize_  # noqa: F401
 from torchao.quantization.quant_primitives import (
     MappingType,
     choose_qparams_affine,
@@ -60,38 +60,6 @@ def forward(self, x):
         return self.net(x)
 
 
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_quant_on_cpu_then_move_to_cuda(dtype, group_size):
-    scale = 512
-    fp16_mod_on_cpu = Linear16(scale, "cpu")
-    device = get_current_accelerator_device()
-    quantize_(fp16_mod_on_cpu, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    test_input_on_cpu = torch.randn(scale * 2, dtype=torch.float16, device="cpu")
-    output_on_cpu = fp16_mod_on_cpu(test_input_on_cpu)
-    fp16_mod_on_cuda = fp16_mod_on_cpu.to(device)
-    test_input_on_cuda = test_input_on_cpu.to(device)
-    output_on_cuda = fp16_mod_on_cuda(test_input_on_cuda)
-    assert torch.allclose(output_on_cpu, output_on_cuda.cpu(), atol=1.0e-3), (
-        "The output of the model on CPU and CUDA should be close"
-    )
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.parametrize("group_size", group_sizes)
-@pytest.mark.parametrize("device", devices)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="GPU not available")
-def test_uintx_weight_only_model_quant(dtype, group_size, device):
-    scale = 512
-    fp16 = Linear16(scale, device)
-    quantize_(fp16, UIntXWeightOnlyConfig(dtype, group_size=group_size))
-    uintx = torch.compile(fp16, fullgraph=True)
-    test_input = torch.randn(scale * 2, dtype=torch.float16, device=device)
-    output = uintx.forward(test_input)
-    assert output is not None, "model quantization failed"
-
-
 @pytest.mark.parametrize("dtype", dtypes)
 @pytest.mark.parametrize("group_size", group_sizes)
 @pytest.mark.parametrize("device", devices)
@@ -128,55 +96,6 @@ def test_uintx_weight_only_quant(dtype, group_size, device):
     assert deqaunt is not None, "deqauntization failed"
 
 
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_target_dtype_compile(dtype):
-    device = get_current_accelerator_device()
-    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)
-    # make sure it runs
-    quantize_(linear, UIntXWeightOnlyConfig(dtype))
-    linear = torch.compile(linear)
-    linear(torch.randn(1, 128, dtype=torch.bfloat16, device=device))
-
-
-@pytest.mark.parametrize("dtype", dtypes)
-@pytest.mark.skipif(not torch.accelerator.is_available(), reason="Need GPU available")
-def test_uintx_model_size(dtype):
-    from torchao.utils import get_model_size_in_bytes
-
-    # scale size = 1/64 * 2 bytes = 1/32 bytes
-    # zero_point size = 1/64 * 4 bytes = 1/16 bytes
-    # dtype data size = 1 * bit_width/8 = bit_width/8 bytes
-    _dtype_to_ratio = {
-        torch.uint1: (1 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint2: (2 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint3: (3 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint4: (4 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint5: (5 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint6: (6 / 8 + 1 / 16 + 1 / 32) / 2,
-        torch.uint7: (7 / 8 + 1 / 16 + 1 / 32) / 2,
-    }
-    device = get_current_accelerator_device()
-    linear = torch.nn.Sequential(
-        torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device=device)
-    )
-    bf16_size = get_model_size_in_bytes(linear)
-    # make sure it runs
-    quantize_(linear[0], UIntXWeightOnlyConfig(dtype))
-    quantized_size = get_model_size_in_bytes(linear)
-    assert bf16_size * _dtype_to_ratio[dtype] == quantized_size
-
-
 def test_uintx_api_deprecation():
     """
     Test that deprecated uintx APIs trigger deprecation warnings on import.