Skip to content

Commit b49d8cb

Browse files
authored
add gptq benchmark, and speed up by ~3x with compile (#4310)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent 2c41725 commit b49d8cb

3 files changed

Lines changed: 123 additions & 2 deletions

File tree

benchmarks/benchmark_gptq.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import time
8+
9+
import fire
10+
import torch
11+
12+
from torchao.prototype.gptq import GPTQConfig, gptq_quantize
13+
from torchao.prototype.mx_formats.inference_workflow import (
14+
NVFP4DynamicActivationNVFP4WeightConfig,
15+
)
16+
17+
18+
def run(
    K: int = 2048,
    N: int = 4096,
    profile_fname: str | None = None,
):
    """Benchmark ``gptq_quantize`` on random CUDA data.

    Args:
        K: Input feature dimension; the Hessian proxy ``H`` is ``(K, K)``.
        N: Output feature dimension; the weight ``W_t`` is ``(N, K)``.
        profile_fname: If given, capture a ``torch.profiler`` trace of a
            single run and export it as a Chrome trace to this path.
            Otherwise, time several iterations and report the average.
    """
    print(f"K={K}, N={N}")

    # Symmetric PSD matrix H = A^T A, standing in for a real GPTQ Hessian.
    A = torch.randn(K, K, dtype=torch.float32, device="cuda")
    H = A.t() @ A

    W_t = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")

    config = GPTQConfig(
        step="convert",
        base_config=NVFP4DynamicActivationNVFP4WeightConfig(
            use_dynamic_per_tensor_scale=True,
            use_triton_kernel=True,
        ),
    )

    # Warmup: pay one-time compilation/autotuning cost outside the timed region.
    print("Warmup...")
    gptq_quantize(H.clone(), W_t.clone(), config)
    torch.cuda.synchronize()

    num_runs = 5
    if profile_fname is not None:
        print("Profiling run...")
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            record_shapes=True,
            with_stack=True,
        ) as prof:
            torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution; time.time() is
            # neither guaranteed for interval measurement.
            start = time.perf_counter()
            # Inputs are cloned per call — presumably gptq_quantize mutates
            # them in place (TODO confirm); cloning keeps each run independent.
            gptq_quantize(H.clone(), W_t.clone(), config)
            torch.cuda.synchronize()
            elapsed = time.perf_counter() - start
            print(f"gptq_quantize time: {elapsed:.3f}s")
        prof.export_chrome_trace(profile_fname)
        print(f"Saved: {profile_fname}")
    else:
        print(f"Timed run ({num_runs} iterations)...")
        times = []
        for _ in range(num_runs):
            torch.cuda.synchronize()
            start = time.perf_counter()
            gptq_quantize(H.clone(), W_t.clone(), config)
            torch.cuda.synchronize()
            times.append(time.perf_counter() - start)
        avg = sum(times) / len(times)
        print(f"gptq_quantize avg time: {avg:.3f}s")
75+
# CLI entry point: expose `run`'s parameters as command-line flags via fire.
if __name__ == "__main__":
    fire.Fire(run)

test/prototype/gptq/test_gptqv2.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@
1010
import torch
1111
import torch.nn.functional as F
1212

13+
from torchao.utils import torch_version_at_least

# Module-wide skip: the GPTQ prototype only works on sufficiently new PyTorch.
pytestmark = pytest.mark.skipif(
    not torch_version_at_least("2.11.0"),
    reason="GPTQ prototype requires PyTorch 2.11+",
)
19+
1320
from torchao.prototype.gptq import (
1421
GPTQConfig,
1522
gptq_quantize,

torchao/prototype/gptq/api.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,18 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
78
import types
9+
import warnings
810
from dataclasses import dataclass
911
from functools import partial
1012
from typing import Union
1113

1214
import torch
1315
import torch.nn as nn
1416

17+
from torchao.utils import torch_version_at_least
18+
1519
try:
1620
from mslk.quantize.shuffle import int4_row_quantize_zp, pack_int4
1721
except:
@@ -265,6 +269,40 @@ def _nvfp4_with_precalculated_scales_q(
265269
return data_lp_packed
266270

267271

272+
# Set to True to torch.compile the NVFP4 quantize/dequantize functions
# inside gptq_quantize. Gives ~3x speedup.
_use_torch_compile = True

if not _use_torch_compile:
    # Eager fallback: use the plain Python implementations directly.
    _nvfp4_qdq_fn = _nvfp4_with_precalculated_scales_qdq
    _nvfp4_q_fn = _nvfp4_with_precalculated_scales_q
else:
    _nvfp4_qdq_fn = torch.compile(_nvfp4_with_precalculated_scales_qdq)
    _nvfp4_q_fn = torch.compile(_nvfp4_with_precalculated_scales_q)

    if not torch_version_at_least("2.11.0"):
        # Older PyTorch lacks the inductor knob used below, so we can only warn.
        warnings.warn(
            "PyTorch < 2.11.0 detected. Upgrade to PyTorch 2.11.0+ for "
            "better GPTQ numerics with torch.compile (IEEE-compliant "
            "division rounding)."
        )
    else:
        # Triton's default f32 division uses approximate reciprocal which
        # introduces ~1 ULP error per division. In GPTQ's error propagation
        # loop this compounds across columns. IEEE-compliant division rounding
        # eliminates the drift.
        import torch._inductor.config as _inductor_config

        if os.environ.get("TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING") == "0":
            # User explicitly disabled it; respect the override but warn.
            warnings.warn(
                "TORCHINDUCTOR_EMULATE_DIVISION_ROUNDING=0 may cause numerical "
                "drift in GPTQ with torch.compile. "
                "Consider unsetting it or setting it to 1."
            )
        else:
            _inductor_config.eager_numerics.division_rounding = True
304+
305+
268306
def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
269307
"""
270308
This function implements the GPTQ algorithm described in this paper: https://arxiv.org/abs/2210.17323 (Algorithm 1)
@@ -472,7 +510,7 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
472510
)
473511
dq = q.dequantize(output_dtype=torch.float)
474512
elif isinstance(base_config, NVFP4DynamicActivationNVFP4WeightConfig):
475-
dq = _nvfp4_with_precalculated_scales_qdq(
513+
dq = _nvfp4_qdq_fn(
476514
w_t,
477515
nvfp4_global_scale,
478516
scale.squeeze(-1),
@@ -519,7 +557,7 @@ def gptq_quantize(H: torch.Tensor, W_t: torch.Tensor, config: GPTQConfig):
519557
combined_scale = (
520558
torch.cat(group_qparams, dim=0).reshape(K // group_size, N).t().contiguous()
521559
)
522-
qdata = _nvfp4_with_precalculated_scales_q(
560+
qdata = _nvfp4_q_fn(
523561
W_t,
524562
nvfp4_global_scale,
525563
combined_scale,

0 commit comments

Comments
 (0)