Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ def __init__(self, *args, **kwargs):
"--rotation_type",
default=None,
type=str,
choices=["hadamard", "random_hadamard"],
choices=["hadamard", "random_hadamard", "quarot_hadamard"],
help="Research feature: applies a rotation (e.g., Hadamard) to reduce activation/weight outliers",
Comment thread
wenhuach21 marked this conversation as resolved.
)
gguf = self.add_argument_group("Double Quant Arguments")
Expand Down
243 changes: 159 additions & 84 deletions auto_round/experimental/rotation_inplace/apply_rotation_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
_resolve,
infer_mapping_from_model,
)
from auto_round.experimental.rotation_inplace.special_model_handler import apply_special_overrides
from auto_round.experimental.rotation_inplace.utils import (
CrossHeadOnlineHadamardHook,
FullOnlineHadamardHook,
Expand All @@ -41,6 +42,25 @@
# ---------------------------------------------------------------------------


def _resolve_head_dim(mapping, config, hidden_size, num_heads):
"""Resolve the per-head attention dimension.

Resolution order:
1. ``mapping.attn_head_dim`` (explicit override on the RotationMapping).
2. ``config.head_dim`` if present (Qwen-3 and other models declare an
explicit ``head_dim`` that does not necessarily equal
``hidden_size // num_heads``; e.g. Qwen3-32B has hidden=5120,
heads=64, head_dim=128 → o_proj.in_features = 8192, not 5120).
3. ``hidden_size // num_heads`` as a last-resort default.
"""
if mapping.attn_head_dim:
return mapping.attn_head_dim
cfg_head_dim = getattr(config, "head_dim", None)
if isinstance(cfg_head_dim, int) and cfg_head_dim > 0:
return cfg_head_dim
return hidden_size // num_heads


def _fuse_ln_linear(
layernorm: torch.nn.Module,
linear_layers: typing.Iterable[torch.nn.Linear],
Expand Down Expand Up @@ -69,32 +89,79 @@ def _reset_ln_params(layernorm: torch.nn.Module) -> None:
layernorm.bias.data.fill_(0.0)


def _rotate_weight_chunked(
weight: torch.Tensor,
Q: torch.Tensor,
side: str,
compute_device,
chunk: int = 4096,
) -> torch.Tensor:
"""Compute the rotated weight without ever materialising the full fp64 copy.

* ``side == 'input'`` → returns ``W @ Q`` (chunked over rows of ``W``).
* ``side == 'output'`` → returns ``Q^T @ W`` (chunked over columns of ``W``).

The output is pre-allocated in the **original** dtype on the **original**
device of ``weight``. At any moment only a single chunk lives in fp64 on
``compute_device``, so peak transient memory is roughly
``chunk * other_dim * 8`` bytes instead of ``W.numel() * 8``.

Embedding/lm_head on Qwen3-14B (151936 × 5120) drops from ~12 GB to a few
hundred MB transient.
"""
dtype = weight.dtype
dev = weight.device
out = torch.empty_like(weight)
Q_ = Q.to(device=compute_device, dtype=torch.float64)
try:
if side == "input":
# (R, C) @ (C, C) → (R, C); chunk over R.
R = weight.shape[0]
for i in range(0, R, chunk):
j = min(i + chunk, R)
blk = weight.data[i:j].to(device=compute_device, dtype=torch.float64)
rotated = (blk @ Q_).to(device=dev, dtype=dtype)
out[i:j].copy_(rotated)
del blk, rotated
elif side == "output":
# Q^T @ (R, C) → (R, C); chunk over C so each block is (R, chunk).
C = weight.shape[1]
Q_T = Q_.T.contiguous()
for i in range(0, C, chunk):
j = min(i + chunk, C)
blk = weight.data[:, i:j].to(device=compute_device, dtype=torch.float64)
rotated = (Q_T @ blk).to(device=dev, dtype=dtype)
out[:, i:j].copy_(rotated)
del blk, rotated
del Q_T
else:
raise ValueError(f"side must be 'input' or 'output', got {side!r}")
finally:
del Q_
return out


def _rotate_linear_by_Q(module: torch.nn.Linear, Q: torch.Tensor, side: str, compute_device=None) -> None:
"""Apply rotation *Q* to a Linear layer's weight (and bias if present).

Memory-efficient: never materialises the full fp64 weight at once.

Args:
side: ``'input'`` → W = W @ Q (rotate input side)
``'output'`` → W = Q^T @ W (rotate output side)
compute_device: Device to run computation on. If None, auto-detects GPU.
"""
dtype = module.weight.data.dtype
dev = module.weight.data.device
cdev = _resolve_compute_device(compute_device)
W_ = module.weight.data.to(device=cdev, dtype=torch.float64)
Q_ = Q.to(device=cdev)
if side == "input":
new_W = torch.matmul(W_, Q_).to(device=dev, dtype=dtype)
else:
new_W = torch.matmul(Q_.T, W_).to(device=dev, dtype=dtype)
# Release fp64 copy before assigning back so peak memory ≈ 1× weight + 1× rotated.
del W_
module.weight.data = new_W
module.weight.data = _rotate_weight_chunked(module.weight.data, Q, side, cdev)
if side == "output" and module.bias is not None:
dtype = module.bias.data.dtype
dev = module.bias.data.device
# Bias is a 1-D vector → small; safe to do in one shot.
b = module.bias.data.to(device=cdev, dtype=torch.float64)
Q_ = Q.to(device=cdev, dtype=torch.float64)
new_b = torch.matmul(Q_.T, b).to(device=dev, dtype=dtype)
del b
del b, Q_
module.bias.data = new_b
del Q_


def _untie_word_embeddings(model, mapping: RotationMapping) -> None:
Expand Down Expand Up @@ -255,7 +322,7 @@ def _rotate_weights(
hidden_size = getattr(config, mapping.hidden_size_attr)
intermediate_size = getattr(config, mapping.intermediate_size_attr)
num_heads = getattr(config, mapping.num_heads_attr)
head_dim = mapping.attn_head_dim or (hidden_size // num_heads)
head_dim = _resolve_head_dim(mapping, config, hidden_size, num_heads)

is_grouped = group_size is not None and group_size > 0
desc = f"Rotating (group_size={group_size})" if is_grouped else "Rotating"
Expand Down Expand Up @@ -308,13 +375,11 @@ def _online_had(dim):
embedding, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix
)
else:
dtype = embedding.weight.data.dtype
dev = embedding.weight.data.device
cdev = compute_device
W_ = embedding.weight.data.to(device=cdev, dtype=torch.float64)
new_W = torch.matmul(W_, Q.to(cdev)).to(device=dev, dtype=dtype)
del W_
embedding.weight.data = new_W
# Chunked: avoids a full fp64 copy of the (vocab, hidden) embedding,
# which on Qwen3-14B is ~6 GB on its own.
embedding.weight.data = _rotate_weight_chunked(
embedding.weight.data, Q, side="input", compute_device=compute_device
)

if mapping.positional_embedding is not None:
pos_emb = _resolve(model, mapping.positional_embedding)
Expand All @@ -323,13 +388,9 @@ def _online_had(dim):
pos_emb, group_size, use_fast_had=fused_fast, compute_device=compute_device, had_matrix=had_matrix
)
else:
pos_dtype = pos_emb.weight.data.dtype
pos_dev = pos_emb.weight.data.device
cdev = compute_device
P_ = pos_emb.weight.data.to(device=cdev, dtype=torch.float64)
new_P = torch.matmul(P_, Q.to(cdev)).to(device=pos_dev, dtype=pos_dtype)
del P_
pos_emb.weight.data = new_P
pos_emb.weight.data = _rotate_weight_chunked(
pos_emb.weight.data, Q, side="input", compute_device=compute_device
)

# ---- Top-level: lm_head ----
lm_head = _resolve(model, mapping.lm_head)
Expand Down Expand Up @@ -432,7 +493,18 @@ def _online_had(dim):
had_matrix=_online_had(intermediate_size),
)

# OV projection: v_proj per-head output + o_proj full/cross-head input
# OV projection: v_proj per-head output + o_proj decomposed input
#
# The online hook on o_proj applies (H_cross ⊗ I_head)⁻¹ at
# runtime, so the weight-side rotation must equal exactly
# (H_cross ⊗ I_head)(I_heads ⊗ H_head) = H_cross ⊗ H_head.
#
# IMPORTANT: we must NOT use a single full-dimension Hadamard
# (``had_dim=-1``) on o_proj, because the butterfly construction
# ``matmul_hadU(hidden_size)`` does NOT satisfy the Kronecker
# decomposition ``H_hidden = H_num_heads ⊗ H_head_dim`` when
# ``num_heads`` is not a power of 2 (e.g. Qwen3-14B, num_heads=40).
# Instead we always apply per-head + cross-head separately.
v_proj = _resolve(layer, mapping.attn_v)
o_proj = _resolve(layer, mapping.attn_o)
if is_grouped:
Expand All @@ -447,31 +519,22 @@ def _online_had(dim):
compute_device=compute_device,
had_matrix=online_head_had,
)
if preset == "random_hadamard":
apply_exact_had_to_linear(
o_proj,
had_dim=head_dim,
output=False,
use_fast_had=online_fast,
compute_device=compute_device,
had_matrix=online_head_had,
)
apply_cross_head_had_to_linear(
o_proj,
num_heads,
head_dim,
use_fast_had=online_fast,
compute_device=compute_device,
had_matrix=_online_had(num_heads),
)
else:
apply_exact_had_to_linear(
o_proj,
had_dim=-1,
output=False,
use_fast_had=online_fast,
compute_device=compute_device,
)
apply_exact_had_to_linear(
o_proj,
had_dim=head_dim,
output=False,
use_fast_had=online_fast,
compute_device=compute_device,
had_matrix=online_head_had,
)
apply_cross_head_had_to_linear(
o_proj,
num_heads,
head_dim,
use_fast_had=online_fast,
compute_device=compute_device,
had_matrix=_online_had(num_heads),
)

else:
# ---- unfused mode: no residual rotation, only input-side Had ----
Expand Down Expand Up @@ -608,7 +671,7 @@ def _register_online_hooks(
num_heads = getattr(config, mapping.num_heads_attr)
hidden_size = getattr(config, mapping.hidden_size_attr)
intermediate_size = getattr(config, mapping.intermediate_size_attr)
head_dim = mapping.attn_head_dim or (hidden_size // num_heads)
head_dim = _resolve_head_dim(mapping, config, hidden_size, num_heads)

is_grouped = group_size is not None and group_size > 0

Expand Down Expand Up @@ -767,6 +830,24 @@ def apply_rotation_transform(
fuse_online_to_weight = True
else:
fuse_online_to_weight = False

# ---- Model-specific overrides ----
# Some models require a specific rotation configuration to preserve
# accuracy or to run correctly. The mapping lives in
# ``special_model_handler.SPECIAL_MODEL_REGISTRY`` so that adding a new
# special-cased model is a one-liner there instead of a code change here.
_override_kwargs = {
"rotation_matrix": rotation_matrix,
"fuse_online_to_weight": fuse_online_to_weight,
"group_size": group_size,
"allow_online_rotation": allow_online_rotation,
}
apply_special_overrides(model, _override_kwargs)
rotation_matrix = _override_kwargs["rotation_matrix"]
fuse_online_to_weight = _override_kwargs["fuse_online_to_weight"]
group_size = _override_kwargs["group_size"]
allow_online_rotation = _override_kwargs["allow_online_rotation"]

had_dict, use_fast_had, preset = _normalize_rotation_matrix(rotation_matrix, group_size)
compute_device = _resolve_compute_device(compute_device)

Expand Down Expand Up @@ -836,42 +917,36 @@ def apply_rotation_transform(
if __name__ == "__main__":
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "/models/opt-125m"
model_name = "/models/Qwen3-14B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
model.to("cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")
apply_rotation_transform(
model, group_size=-1, allow_online_rotation=True, rotation_matrix="random_hadamard", fuse_online_to_weight=False
model, group_size=128, allow_online_rotation=True, rotation_matrix="hadamard", fuse_online_to_weight=True
)
Comment thread
wenhuach21 marked this conversation as resolved.
model.to("cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

model_name = "/models/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
apply_rotation_transform(model, group_size=-1, allow_online_rotation=True, fuse_online_to_weight=True)
model.to("cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "/models/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
apply_rotation_transform(model, fuse_online_to_weight=True, group_size=32)
model.to("cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
#
# model_name = "/models/Qwen3-8B"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
# apply_rotation_transform(model, group_size=-1, allow_online_rotation=True, fuse_online_to_weight=True)
# model.to("cuda")
# text = "There is a girl who likes adventure,"
# inputs = tokenizer(text, return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
#
# from transformers import AutoModelForCausalLM, AutoTokenizer
#
# model_name = "/models/Meta-Llama-3.1-8B-Instruct"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
# apply_rotation_transform(model, fuse_online_to_weight=True, group_size=32)
# model.to("cuda")
# text = "There is a girl who likes adventure,"
# inputs = tokenizer(text, return_tensors="pt").to(model.device)
# print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
#
# model_name = "/models/Llama-2-7b-chat-hf"
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
Expand Down
Loading
Loading