Merged
89 changes: 84 additions & 5 deletions scripts/prototype/test_nvfp4_moe.py
@@ -12,14 +12,68 @@
 from transformers.models.olmoe.modeling_olmoe import OlmoeExperts
 from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer

+from torchao.prototype.gptq.gptq_example import prepare_dataset
 from torchao.prototype.mx_formats.inference_workflow import (
     NVFP4DynamicActivationNVFP4WeightConfig,
 )
 from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor
 from torchao.quantization import FqnToConfig, quantize_


-def main(recipe: str = "bf16", run_lm_eval: bool = False):
+def install_expert_counters(model):
+    """Install forward pre-hooks on OlmoeExperts to count per-expert token routing.
+
+    Returns a dict mapping module FQN to a (num_experts,) int64 tensor of counts,
+    and a list of hook handles for removal.
+    """
+    expert_counts = {}
+    handles = []
+
+    for name, mod in model.named_modules():
+        if not isinstance(mod, OlmoeExperts):
+            continue
+        counts = torch.zeros(mod.num_experts, dtype=torch.int64)
+        expert_counts[name] = counts
+
+        def make_hook(c, n):
+            def hook(module, args, kwargs):
+                top_k_index = args[1]
+                c.add_(top_k_index.flatten().bincount(minlength=n).cpu())
+
+            return hook
+
+        handles.append(
+            mod.register_forward_pre_hook(
+                make_hook(counts, mod.num_experts), with_kwargs=True
+            )
+        )
+
+    return expert_counts, handles
+
+
+def print_expert_counts(expert_counts):
+    print("\n=== Per-expert token counts ===")
+    for name, counts in expert_counts.items():
+        total = counts.sum().item()
+        print(f"{name}: total={total}, per_expert={counts.tolist()}")
+
+    print("\n=== Global expert utilization summary ===")
+    all_counts = torch.cat([c for c in expert_counts.values()])
+    n = len(all_counts)
+    for threshold in range(129):
+        if threshold % 10 != 0:
+            continue
+        num = int((all_counts <= threshold).sum().item())
+        print(f"experts with <= {threshold} tokens: {num}/{n} ({num / n * 100:.1f}%)")
+
+
+def main(
+    recipe: str = "bf16",
+    run_lm_eval: bool = False,
+    calibrate_on_c4: bool = False,
+    num_calibration_samples: int = 128,
+    max_sequence_length: int = 2048,
+):
     print(f"{recipe=}")
     model_id = "allenai/OLMoE-1B-7B-0924"

@@ -54,15 +108,40 @@ def main(recipe: str = "bf16", run_lm_eval: bool = False):
f"{name}.{pname} is {type(param).__name__}, expected NVFP4Tensor"
)

# generate() switches to batched_mm for decoding, which doesn't support
# NVFP4Tensor (needs aten.index.Tensor). Override to keep grouped_mm.
# TODO(future PR): implement bmm for nvfp4 and remove this workaround
model._optimize_model_for_decode = nullcontext
elif recipe == "bf16":
pass
else:
raise ValueError(f"Unknown recipe: {recipe}")

# generate() switches to batched_mm for decoding, which doesn't support
# NVFP4Tensor (needs aten.index.Tensor). Override to keep grouped_mm.
# TODO(future PR): implement bmm for nvfp4 and remove this workaround
model._optimize_model_for_decode = nullcontext

if calibrate_on_c4:
assert recipe == "bf16", (
"calibrate_on_c4 is only supported with recipe=bf16 for now"
)

expert_counts, hooks = install_expert_counters(model)

dataset = prepare_dataset(
tokenizer,
max_sequence_length,
num_calibration_samples=num_calibration_samples,
dataset_id="c4",
dataset_split="train",
)
print(f"Running calibration on {len(dataset)} C4 samples...")
with torch.no_grad():
for seq in dataset:
model(seq.to("cuda"))
print("Calibration complete.")

print_expert_counts(expert_counts)
for h in hooks:
h.remove()

prompt = "The capital of France is"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

Expand Down
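For reference, a minimal standalone sketch of the hook mechanism install_expert_counters relies on: register_forward_pre_hook(..., with_kwargs=True) hands the hook the positional args of each forward call, so it can read the routing tensor and accumulate per-expert counts. The toy Router module below is illustrative, not part of this PR; like the hook above, it assumes the top-k index tensor is the module's second positional argument.

import torch
import torch.nn as nn


class Router(nn.Module):
    """Toy stand-in for an experts module with signature forward(hidden, top_k_index)."""

    def __init__(self, num_experts: int):
        super().__init__()
        self.num_experts = num_experts

    def forward(self, hidden, top_k_index):
        return hidden


router = Router(num_experts=8)
counts = torch.zeros(router.num_experts, dtype=torch.int64)


def hook(module, args, kwargs):
    top_k_index = args[1]  # second positional arg, as in the OlmoeExperts hook
    counts.add_(top_k_index.flatten().bincount(minlength=module.num_experts))


handle = router.register_forward_pre_hook(hook, with_kwargs=True)
router(torch.randn(4, 16), torch.randint(0, 8, (4, 2)))
handle.remove()
print(counts.tolist())  # per-expert token counts for the single call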
94 changes: 47 additions & 47 deletions test/prototype/gptq/test_gptqv2.py
@@ -102,33 +102,13 @@ def test_observer_tensor_creation(self):
         # Check hp_data is stored correctly
         torch.testing.assert_close(observer.hp_data, weight)

-        # Check hessian is initialized as None
-        assert observer.hessian is None
+        # Check hessian is initialized as zeros
+        assert torch.equal(
+            observer.hessian, torch.zeros(64, 64, dtype=torch.float32, device="cuda")
+        )

         # Check total_batches is initialized as 0
-        assert observer.total_batches == 0
-
-    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
-    def test_observer_tensor_attributes(self):

[Review comment from the contributor (author): this test is not that useful, deleting instead of updating with new contents of the observer tensor]

-        """Test GPTQObserverTensor attributes are correctly set."""
-        weight = torch.randn(16, 32, dtype=torch.bfloat16, device="cuda")
-        observer = GPTQObserverTensor.from_hp(weight)
-
-        # Test hp_data attribute
-        assert hasattr(observer, "hp_data")
-        assert isinstance(observer.hp_data, torch.Tensor)
-
-        # Test hessian attribute
-        assert hasattr(observer, "hessian")
-        assert observer.hessian is None
-
-        # Test total_batches attribute
-        assert hasattr(observer, "total_batches")
-        assert observer.total_batches == 0
-
-        # Test update method exists
-        assert hasattr(observer, "update")
-        assert callable(observer.update)
+        assert (observer.total_batches == 0).all()
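For reference, a minimal sketch of the eager-allocation bookkeeping these assertions now expect: the Hessian starts as a (k, k) zeros tensor rather than None, and total_batches is tensor-valued. The update rule shown is the standard GPTQ accumulation H += 2 XᵀX; the real GPTQObserverTensor may normalize or average differently, so treat this as illustrative only.

import torch


class ToyObserver:
    """Illustrative sketch only; not the real GPTQObserverTensor."""

    def __init__(self, weight: torch.Tensor):
        k = weight.shape[-1]  # in_features
        self.hp_data = weight
        # statistics are allocated up front, so hessian is never None
        self.hessian = torch.zeros(k, k, dtype=torch.float32, device=weight.device)
        self.total_batches = torch.zeros(1, dtype=torch.int64)

    def update(self, x: torch.Tensor):
        x2 = x.reshape(-1, x.shape[-1]).to(torch.float32)
        self.hessian += 2.0 * (x2.T @ x2)  # GPTQ-style H accumulation (assumed)
        self.total_batches += 1


obs = ToyObserver(torch.randn(32, 64))
assert torch.equal(obs.hessian, torch.zeros(64, 64))
assert (obs.total_batches == 0).all()
obs.update(torch.randn(4, 64))
assert (obs.total_batches == 1).all()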

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_linear_operation_with_observer(self):
@@ -157,7 +137,7 @@ def test_linear_operation_with_observer(self):
         # Check that Hessian was initialized and updated
         assert observer_weight.hessian is not None
         assert observer_weight.hessian.shape == (in_features, in_features)
-        assert observer_weight.total_batches == 1
+        assert (observer_weight.total_batches == 1).all()

         # Verify output is correct
         expected_output = F.linear(input_tensor, weight)
@@ -191,35 +171,47 @@ def test_multiple_observations(self):
         assert observer_weight.hessian.shape == (in_features, in_features)

         # Check total_batches matches total samples
-        assert observer_weight.total_batches == total_samples
+        assert (observer_weight.total_batches == total_samples).all()

     @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
     def test_bmm_operation_with_observer(self):
         """Test torch.bmm with GPTQObserverTensor updates Hessian correctly."""
-        batch = 4
+        num_experts = 4
         m = 8
         n = 16
         k = 12
+        num_passes = 4

         # Create input and weight tensors
-        input_tensor = torch.randn(batch, m, k, dtype=torch.float32, device="cuda")
-        weight = torch.randn(batch, k, n, dtype=torch.float32, device="cuda")
-        observer_weight = GPTQObserverTensor.from_hp(weight)
-
-        # Perform bmm operation
-        output = torch.bmm(input_tensor, observer_weight)
+        weight = torch.randn(num_experts, n, k, dtype=torch.float32, device="cuda")

-        # Check output shape
-        assert output.shape == (batch, m, n)
+        inputs = [
+            torch.randn(num_experts, m, k, dtype=torch.float32, device="cuda")
+            for _ in range(num_passes)
+        ]

-        # Check Hessian was initialized and updated
-        assert observer_weight.hessian is not None
-        # For bmm with batch dimension, the Hessian is computed on the last dimension
-        assert observer_weight.total_batches == batch
+        # 3D path: single observer with bmm
+        observer_3d = GPTQObserverTensor.from_hp(weight)
+        for x in inputs:
+            torch.bmm(x, observer_3d.transpose(-2, -1))

-        # Verify output is correct
-        expected_output = torch.bmm(input_tensor, weight)
-        torch.testing.assert_close(output, expected_output)
+        # 2D path: per-expert observers with F.linear
+        observers_2d = [
+            GPTQObserverTensor.from_hp(weight[e]) for e in range(num_experts)
+        ]
+        for x in inputs:
+            for e in range(num_experts):
+                F.linear(x[e], observers_2d[e])
+
+        # Verify that the per-expert hessians from the 3D path match, bitwise,
+        # the hessians calculated for each expert individually
+        for e in range(num_experts):
+            assert torch.equal(observer_3d.hessian[e], observers_2d[e].hessian), (
+                f"Expert {e} hessian mismatch"
+            )
+            assert torch.equal(
+                observer_3d.total_batches[e : e + 1], observers_2d[e].total_batches
+            ), f"Expert {e} total_batches mismatch"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    @pytest.mark.parametrize(
@@ -250,17 +242,20 @@ def test_observer_config_transform(self, base_config):
         # Check hp_data matches original weight
         torch.testing.assert_close(linear.weight.hp_data, original_weight)

-        # Check hessian is None initially
-        assert linear.weight.hessian is None
-        assert linear.weight.total_batches == 0
+        # Check hessian is initialized as zeros
+        assert torch.equal(
+            linear.weight.hessian,
+            torch.zeros(64, 64, dtype=torch.float32, device="cuda"),
+        )
+        assert (linear.weight.total_batches == 0).all()

         # Perform a forward pass
         input_tensor = torch.randn(4, 64, dtype=torch.float32, device="cuda")
         output = linear(input_tensor)

         # Check Hessian was initialized after forward pass
         assert linear.weight.hessian is not None
-        assert linear.weight.total_batches == 1
+        assert (linear.weight.total_batches == 1).all()

         # Check output shape
         assert output.shape == (4, 32)
@@ -376,6 +371,9 @@ def test_unified_config_two_phase(self, base_config):
    )
    def test_gptq_quantize_function(self, base_config):
        """Test gptq_quantize function with synthetic Hessian and weights."""
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")
+
        torch.manual_seed(42)

        # Create synthetic weight matrix
@@ -548,6 +546,8 @@ def test_gptq_sqnr(self, base_config):
            and not is_sm_at_least_100()
        ):
            pytest.skip("CUDA capability >= 10.0 required for nvfp4")
+        if isinstance(base_config, Int4WeightOnlyConfig) and is_sm_at_least_100():
+            pytest.skip("int4 kernels do not work on sm100")

        torch.manual_seed(43)

4 changes: 2 additions & 2 deletions torchao/prototype/gptq/api.py
@@ -141,10 +141,10 @@ def _gptq_config_transform(
    )

    # Validate that observations were recorded
-    if tensor.hessian is None:
+    if (tensor.total_batches == 0).any():
        raise ValueError(
            f"No observations recorded for {parameter_name}. "
-            f"Hessian is None. Did you run forward passes during the observe step?"
+            f"total_batches is 0. Did you run forward passes during the observe step?"
        )

    # Use pre-computed Hessian directly
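With total_batches tensor-valued (one entry per expert in the MoE case), the .any() check catches a model where even a single expert saw no calibration tokens, which the old hessian-is-None check could not distinguish. For example:

import torch

total_batches = torch.tensor([3, 0, 5, 2])  # expert 1 received no tokens
assert (total_batches == 0).any()  # this condition triggers the ValueError above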
2 changes: 1 addition & 1 deletion torchao/prototype/gptq/gptq_example.py
@@ -450,7 +450,7 @@ def skip_lm_head_o_proj(module, fqn):
    import inspect

    source = inspect.getsource(TorchAoHfQuantizer.get_weight_conversions)
-    if "_weight_per_tensor_scale" not in source:
+    if "per_tensor_scale" not in source:
        raise RuntimeError(
            "Your version of `transformers` does not support NVFP4 serialization. "
            "Please install a version that includes "