Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions auto_round/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ def __init__(self, *args, **kwargs):
"Higher values may speed up tuning but require more memory. "
"Recommended to keep at 1 for stability with large models.",
)
tuning.add_argument(
"--nblocks_overlap",
default=0,
type=int,
help="Number of overlapping blocks between adjacent nblocks windows. "
"For CBQ-style CBD, use --nblocks 2 --nblocks_overlap 1.",
)
tuning.add_argument(
"--scale_dtype",
default=None,
Expand Down Expand Up @@ -703,6 +710,7 @@ def tune(args):
lr=args.lr,
minmax_lr=args.minmax_lr,
nblocks=args.nblocks,
nblocks_overlap=args.nblocks_overlap,
to_quant_block_names=args.to_quant_block_names,
scale_dtype=args.scale_dtype,
)
Expand Down
10 changes: 10 additions & 0 deletions auto_round/compressors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def __init__(
kwargs.pop("vlm", None)
amp = kwargs.pop("amp", True)
nblocks = kwargs.pop("nblocks", 1)
nblocks_overlap = kwargs.pop("nblocks_overlap", 0)
disable_deterministic_algorithms = kwargs.pop("disable_deterministic_algorithms", True)
enable_deterministic_algorithms = kwargs.pop("enable_deterministic_algorithms", False)

Expand Down Expand Up @@ -270,6 +271,15 @@ def __init__(
set_seed(self.seed)

self.nblocks = nblocks
self.nblocks_overlap = nblocks_overlap
if self.nblocks <= 0:
raise ValueError("`nblocks` must be positive")
if self.nblocks == 1:
self.nblocks_overlap = 0
if self.nblocks_overlap < 0:
raise ValueError("`nblocks_overlap` must be non-negative")
if self.nblocks_overlap >= self.nblocks:
raise ValueError("`nblocks_overlap` must be smaller than `nblocks`")

self.enable_torch_compile = enable_torch_compile

Expand Down
3 changes: 3 additions & 0 deletions auto_round/compressors/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
lr_scheduler: Callable = None,
minmax_lr: float = None,
nblocks: int = 1,
nblocks_overlap: int = 0,
to_quant_block_names: Union[str, list, None] = None,
scale_dtype: str = "fp16",
Comment on lines 41 to 47
# scheme
Expand Down Expand Up @@ -122,6 +123,7 @@ def __init__(
lr_scheduler=lr_scheduler,
minmax_lr=minmax_lr,
nblocks=nblocks,
nblocks_overlap=nblocks_overlap,
to_quant_block_names=to_quant_block_names,
scale_dtype=scale_dtype,
)
Expand Down Expand Up @@ -257,6 +259,7 @@ class TuningExtraConfig(BaseExtraConfig):
lr_scheduler: Callable = None
minmax_lr: float = None
nblocks: int = 1
nblocks_overlap: int = 0
to_quant_block_names: Union[str, list, None] = None
scale_dtype: str = "fp16"

Expand Down
119 changes: 110 additions & 9 deletions auto_round/compressors/data_driven.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,77 @@ def quantize_block(
finally:
self.model_context.is_mllm = orig_is_mllm

@staticmethod
def _clone_overlap_tail_value(value):
    """Return an independent copy of ``value`` for an overlap-tail snapshot.

    Args:
        value: Any attribute value read from a module.

    Returns:
        For tensors, a detached clone that lives on the CPU. For every other
        object, a ``copy.deepcopy`` of it.
    """
    if not torch.is_tensor(value):
        return copy.deepcopy(value)
    # Detach first so the copy carries no autograd history.
    detached = value.detach()
    return detached.cpu().clone()

def _snapshot_overlap_tail_state(self, block: torch.nn.Module, advance_blocks: int):
    """Record the state of the layers of ``block`` that lie past ``advance_blocks``.

    Args:
        block: The module being processed. Only a ``WrapperMultiblock`` is
            snapshotted; its ``layers[advance_blocks:]`` form the tail.
        advance_blocks: Number of leading layers that are not part of the tail.

    Returns:
        ``None`` when ``advance_blocks`` is not positive, ``block`` is not a
        ``WrapperMultiblock``, or the tail is empty. Otherwise a list with one
        ``(module, state, attrs)`` tuple per tail layer, where ``state`` maps
        each ``state_dict`` key to a detached CPU clone and ``attrs`` maps each
        sub-module name to ``{attr_name: (existed, copied_value)}`` for every
        tracked attribute.
    """
    if advance_blocks <= 0:
        return None
    if not isinstance(block, WrapperMultiblock):
        return None
    tail_layers = list(block.layers[advance_blocks:])
    if not tail_layers:
        return None

    tracked_attrs = (
        "imatrix",
        "imatrix_cnt",
        "scale",
        "zp",
        "w_scale",
        "w_zp",
        "w_wmin",
        "w_d_scale",
        "w_d_wmin",
        "weight_global_scale",
    )

    def capture(owner):
        # Absent attributes are recorded as (False, None) so that a restore
        # can tell "was missing" apart from "was set to None".
        record = {}
        for name in tracked_attrs:
            if hasattr(owner, name):
                record[name] = (True, self._clone_overlap_tail_value(getattr(owner, name)))
            else:
                record[name] = (False, None)
        return record

    snapshots = []
    for layer in tail_layers:
        tensors = {}
        for key, tensor in layer.state_dict().items():
            tensors[key] = tensor.detach().cpu().clone()
        per_module = {path: capture(sub) for path, sub in layer.named_modules()}
        snapshots.append((layer, tensors, per_module))
    return snapshots

def _restore_overlap_tail_state(self, snapshots):
    """Write a previously recorded tail snapshot back onto its modules.

    Args:
        snapshots: ``None`` (nothing to do) or an iterable of
            ``(module, state, attrs)`` tuples. ``state`` is loaded with
            ``load_state_dict(strict=False)``; ``attrs`` maps sub-module names
            to ``{attr_name: (existed, value)}``.

    For each recorded attribute: if it existed at snapshot time it is set again
    from a fresh copy of the saved value; if it did not exist but is present
    now, it is deleted. Sub-module names that no longer resolve are skipped.
    """
    if snapshots is None:
        return
    for layer, saved_state, saved_attrs in snapshots:
        layer.load_state_dict(saved_state, strict=False)
        lookup = dict(layer.named_modules())
        for path, record in saved_attrs.items():
            target = lookup.get(path)
            if target is None:
                continue
            for name, (was_present, saved_value) in record.items():
                if was_present:
                    # Copy again so the snapshot itself stays untouched.
                    setattr(target, name, self._clone_overlap_tail_value(saved_value))
                    continue
                if hasattr(target, name):
                    delattr(target, name)

def _get_overlap_advance_block(self, block: torch.nn.Module, advance_blocks: int) -> Optional[torch.nn.Module]:
    """Return the module made of the first ``advance_blocks`` layers of ``block``.

    Args:
        block: A ``WrapperMultiblock`` or any other module.
        advance_blocks: Number of leading layers to select.

    Returns:
        ``None`` when ``advance_blocks`` is not positive. For a
        ``WrapperMultiblock``: the single layer itself when the slice
        ``layers[:advance_blocks]`` has exactly one element, otherwise a new
        ``WrapperMultiblock`` over that slice. For any other module: ``block``
        when ``advance_blocks`` is 1, else ``None``.
    """
    if advance_blocks <= 0:
        return None
    if not isinstance(block, WrapperMultiblock):
        return block if advance_blocks == 1 else None
    head = list(block.layers[:advance_blocks])
    # A lone layer is returned unwrapped.
    return head[0] if len(head) == 1 else WrapperMultiblock(head)

def _quantize_blocks(
self,
model: torch.nn.Module,
Expand Down Expand Up @@ -457,15 +528,18 @@ def _quantize_blocks(
for k in extra_keys:
input_others[k] = input_ids.pop(k)

overlap = self.nblocks_overlap if nblocks > 1 else 0
stride = nblocks - overlap
block_starts = self._get_block_window_starts(block_names, nblocks)
if pbar is None:
pbar = tqdm(range(0, len(block_names), nblocks))
pbar = tqdm(block_starts)

for i in range(0, len(block_names), nblocks):
for step_idx, i in enumerate(block_starts):
if input_others_extra_blocks and block_names[i] in input_others_extra_blocks:
input_others = input_others_extra_blocks[block_names[i]]
_, input_others = self._preprocess_block_inputs(input_others)
input_others_extra_blocks.pop(block_names[i])
if i != 0:
if step_idx != 0:
pbar.update(1)
if nblocks == 1:
n = block_names[i]
Expand All @@ -476,6 +550,7 @@ def _quantize_blocks(
pbar.set_description(f"Quantizing [{i + 1}-{min(i + nblocks, len(block_names))}]/{len(block_names)}")
modules = [get_module(model, n) for n in names]
m = WrapperMultiblock(modules)
overlap_advance = stride if overlap > 0 and step_idx + 1 < len(block_starts) else 0

if self.compress_context.low_cpu_mem_usage:
if nblocks == 1:
Expand Down Expand Up @@ -537,6 +612,19 @@ def _quantize_blocks(
for h in hook_handles:
h.remove()

overlap_input = None
overlap_block = self._get_overlap_advance_block(m, overlap_advance)
if overlap_block is not None:
overlap_input = self.quantizer._get_block_outputs(
overlap_block,
input_ids,
input_others,
bs,
device_override=loss_device,
)

overlap_tail_state = self._snapshot_overlap_tail_state(m, overlap_advance)

# ── Infrastructure: swap q_input ──────────────────────────────────
if q_input is not None:
if input_ids is not q_input:
Expand All @@ -555,14 +643,16 @@ def _quantize_blocks(
loss_device=loss_device,
mid_iter_mem_check=mid_iter_mem_check,
)
self._restore_overlap_tail_state(overlap_tail_state)

# ── MoE scale alignment for FP8 dispatch efficiency ────────────────
if is_nv_fp(self.quantizer.act_data_type) or is_static_wfp8afp8(self.quantizer):
set_amax_for_all_moe_layers(m, attr_name="act_max")

# ── Infrastructure: collect q_outputs if needed ───────────────────
if self.quantizer.enable_quanted_input:
q_input = self.quantizer._get_block_outputs(m, input_ids, input_others, bs)
q_output_block = self._get_overlap_advance_block(m, overlap_advance) or m
q_input = self.quantizer._get_block_outputs(q_output_block, input_ids, input_others, bs)
else:
q_input = None

Expand All @@ -577,7 +667,7 @@ def _quantize_blocks(
# from the current block's reference output, while q_input (when
# enabled) is only used as the quantized-input companion for the
# next block.
next_input_ids = reference_output
next_input_ids = overlap_input if overlap_input is not None else reference_output
clear_memory(
input_ids if input_ids is not next_input_ids else None, device_list=self.compress_context.device_list
)
Expand Down Expand Up @@ -618,6 +708,19 @@ def _quantize_blocks(

clear_memory(device_list=self.compress_context.device_list)

def _get_block_window_starts(self, block_names: list, nblocks: int) -> list[int]:
    """Compute the start index of every ``nblocks``-wide window over ``block_names``.

    Consecutive windows are ``nblocks - overlap`` apart, where ``overlap`` is
    ``self.nblocks_overlap`` when ``nblocks > 1`` and 0 otherwise. After the
    first window, a start is dropped once the blocks remaining from it number
    ``overlap`` or fewer.

    Args:
        block_names: Ordered block names; only its length is used.
        nblocks: Window width in blocks.

    Returns:
        The window start indices in ascending order (empty for empty input).

    Raises:
        ValueError: If ``nblocks`` is not positive, or the overlap is not
            smaller than ``nblocks``. Either case gives a non-positive stride,
            which previously made the loop run forever.
    """
    if nblocks <= 0:
        raise ValueError("`nblocks` must be positive")
    overlap = self.nblocks_overlap if nblocks > 1 else 0
    stride = nblocks - overlap
    if stride <= 0:
        raise ValueError(
            f"`nblocks_overlap` ({overlap}) must be smaller than `nblocks` ({nblocks})"
        )

    total = len(block_names)
    # The remaining count only shrinks as the start grows, so filtering the
    # strided range gives the same result as stopping at the first such start.
    return [
        start
        for start in range(0, total, stride)
        if start == 0 or overlap <= 0 or total - start > overlap
    ]
Comment on lines +765 to +776

def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
"""Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound.
Returns:
Expand Down Expand Up @@ -695,10 +798,8 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
self.compress_context.low_cpu_mem_usage = False
else:
self.compress_context.low_cpu_mem_usage = False
if len(all_blocks) > 1:
pbar = tqdm(range(0, sum([len(i) for i in all_blocks]), self.nblocks))
else:
pbar = tqdm(range(0, len(all_blocks[0]), self.nblocks)) # move the alg warning outside pbar
block_window_count = sum(len(self._get_block_window_starts(block, self.nblocks)) for block in all_blocks)
pbar = tqdm(range(block_window_count)) # move the alg warning outside pbar

start_time = time.time()
for block_names in all_blocks:
Expand Down