Skip to content

Commit 56d174c

Browse files
authored
gptq: more obvious group var names (#4226)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 553a010 commit 56d174c

File tree

1 file changed

+8
-6
lines changed

1 file changed

+8
-6
lines changed

torchao/prototype/gptq/api.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,29 +274,31 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
274274

275275
# If we are doing per-row quantization, the group_size is equal to the number of columns and this will only run once.
276276
# Otherwise, if we do per-group quantization, we need to iterate through the block one group at a time.
277-
for group_start in range(k_block_start, k_block_end, group_size):
278-
group_end = min(group_start + group_size, k_block_end)
277+
for k_group_start in range(k_block_start, k_block_end, group_size):
278+
k_group_end = min(k_group_start + group_size, k_block_end)
279279

280280
# We only need to calculate initial qparams for the group once
281-
if group_start % group_size == 0:
281+
if k_group_start % group_size == 0:
282282
if isinstance(base_config, Int4WeightOnlyConfig):
283283
_, scale, zero_point = int4_row_quantize_zp(
284284
W_t_quantize_block[
285-
:, group_start - k_block_start : group_end - k_block_start
285+
:,
286+
k_group_start - k_block_start : k_group_end - k_block_start,
286287
],
287288
group_size,
288289
)
289290
group_qparams.append((scale, zero_point))
290291
elif isinstance(base_config, Int8WeightOnlyConfig):
291292
quantized_tensor = Int8Tensor.from_hp(
292293
W_t_quantize_block[
293-
:, group_start - k_block_start : group_end - k_block_start
294+
:,
295+
k_group_start - k_block_start : k_group_end - k_block_start,
294296
],
295297
base_config.granularity,
296298
)
297299

298300
# Quantize each column and propagate errors to subsequent columns
299-
for i in range(group_start - k_block_start, group_end - k_block_start):
301+
for i in range(k_group_start - k_block_start, k_group_end - k_block_start):
300302
w_t = W_t_quantize_block[:, i].unsqueeze(1)
301303
if isinstance(base_config, Int4WeightOnlyConfig):
302304
q = _int4_row_quantize_zp_precomputed_qparams(

0 commit comments

Comments
 (0)