diff --git a/auto_round/auto_scheme/delta_loss.py b/auto_round/auto_scheme/delta_loss.py index a7e2bdf90..935c78299 100644 --- a/auto_round/auto_scheme/delta_loss.py +++ b/auto_round/auto_scheme/delta_loss.py @@ -59,7 +59,7 @@ to_dtype, ) from auto_round.utils.device import MemoryMonitor -from auto_round.utils.offload import OffloadManager +from auto_round.utils.offload import OffloadManager, load_model_meta_skeleton, materialize_non_block_layers from auto_round.wrapper import WrapperLinear __all__ = ["gen_layer_config"] @@ -480,8 +480,8 @@ def backward_pre_hook(module, grad_input): for block_name in reversed(block_names): - # Retrieve stored inputs for the block - block_input_info = block_inputs.get(block_name, {}) + # Retrieve stored inputs for the block (pop to free memory immediately) + block_input_info = block_inputs.pop(block_name, {}) block_input_args = to_device(block_input_info.get("args", []), major_device) block_input_kwargs = to_device(block_input_info.get("kwargs", {}), major_device) @@ -1021,6 +1021,15 @@ def _gen_layer_config( force_mllm = is_vlm and any(any(marker in n.lower() for marker in vision_markers) for n in quant_layer_names) block_name = get_block_names(model)[0] # TODO need change to support vlm + + # When model was loaded as meta skeleton, materialize non-block layers + # from checkpoint now. Block weights stay as empty tensors and will be + # loaded on demand by OffloadManager hooks. + if offload_context is not None and model_name is not None: + _is_meta_skeleton = any(p.is_meta or p.numel() == 0 for p in model.parameters()) + if _is_meta_skeleton: + materialize_non_block_layers(model, model_name, block_name) + for name in block_name: module = get_module(model, name) module.in_block = True @@ -1312,6 +1321,7 @@ def gen_layer_config( ): model_name = None is_vlm = False + meta_skeleton_loaded = False if isinstance(model, str): model_name = model # Detect VLM (Qwen-VL / Qwen3-VL / LLaVA / etc.) 
and load via the MLLM @@ -1325,6 +1335,12 @@ def gen_layer_config( device="cpu", use_auto_mapping=False, ) + elif low_gpu_mem_usage and auto_scheme.low_cpu_mem_usage: + # Load model as meta skeleton (no real weights) to minimize peak RAM. + # Non-block layers will be materialized from checkpoint below; + # block weights are loaded on demand by OffloadManager hooks. + model, tokenizer, _ = load_model_meta_skeleton(model_name) + meta_skeleton_loaded = True else: # Load model on CPU only; do not apply automatic device map or tuning-aware placement at load time. model, tokenizer, _ = llm_load_model(model_name, device_map="cpu") @@ -1363,11 +1379,17 @@ def gen_layer_config( else: model = dispatch_model_by_all_available_devices(model, device_map) else: - model.to("cpu") - if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1: - import accelerate - - accelerate.hooks.remove_hook_from_submodules(model) + # Skip model.to("cpu") when model was loaded as meta skeleton -- + # non-block layers are already on CPU and block weights are empty tensors. 
+ _is_meta_loaded = meta_skeleton_loaded + if not _is_meta_loaded: + model.to("cpu") + if hasattr(model, "hf_device_map"): + if _is_meta_loaded or len(model.hf_device_map) > 1: + import accelerate + + accelerate.hooks.remove_hook_from_submodules(model) + delattr(model, "hf_device_map") if (isinstance(device_map, str) and "," in device_map) or device_map == "auto": set_avg_auto_device_map(model, device_map) else: diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index c60bc2849..01cf15375 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -740,6 +740,17 @@ def _is_scoreable_layer(name: str) -> bool: if not self.enable_torch_compile and self.super_bits is None and not self.orig_scheme.low_gpu_mem_usage: logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") + + # When low_cpu_mem_usage is enabled, pass the model path (string) to + # AutoScheme so it can load a meta skeleton instead of keeping the full + # model in RAM. The loaded model is freed here and reloaded afterward. + _need_reload = False + _model_path = None + if self.orig_scheme.low_cpu_mem_usage and self.orig_scheme.low_gpu_mem_usage: + _model_path = getattr(self.model.config, "_name_or_path", None) + if _model_path is not None and os.path.isdir(_model_path): + _need_reload = True + self.scheme_generator = GenScheme( self.orig_scheme, self.model, @@ -751,7 +762,44 @@ def _is_scoreable_layer(name: str) -> bool: enable_torch_compile=self.enable_torch_compile, processor=getattr(self, "processor", None), ) + + if _need_reload: + # GenScheme.__init__ has computed avg bit ranges using the model. + # Now swap the model reference with the path string so that + # gen_layer_config will load a meta skeleton instead. 
+ self.scheme_generator.model = _model_path + del self.model + self.model = None + import gc + + gc.collect() + clear_memory(device_list=self.device_list) + logger.info("Released loaded model before AutoScheme (will reload after)") + layer_config = self.scheme_generator.get_layer_config() + + if _need_reload: + logger.info("Reloading model after AutoScheme") + self.model, self.tokenizer = llm_load_model( + _model_path, + device="cpu", + trust_remote_code=self.trust_remote_code, + ) + self.model = self.model.eval() + check_and_mark_quantized_module(self.model) + # Re-apply module structure updates that quantize() applied before AutoScheme + formats = self.formats if hasattr(self, "formats") else None + if not self.diffusion and formats is not None: + self.model = update_module( + self.model, + formats=formats, + trust_remote_code=self.trust_remote_code, + cleanup_original=False, + ) + for n, m in self.model.named_modules(): + m.global_name = n + self.shared_cache_keys = get_shared_keys(self.model) + # Re-attach vision/audio-tower layers we peeled off earlier so the # downstream quantization pipeline sees the complete layer map. # NOTE: ``nontext_skipped_layers`` was populated *after* the first diff --git a/auto_round/utils/offload.py b/auto_round/utils/offload.py index 5e2d6aefd..d8eeb316b 100644 --- a/auto_round/utils/offload.py +++ b/auto_round/utils/offload.py @@ -57,7 +57,11 @@ from auto_round.logger import logger from auto_round.utils.model import get_module -__all__ = ["OffloadManager"] +__all__ = [ + "OffloadManager", + "load_model_meta_skeleton", + "materialize_non_block_layers", +] # ===================================================================== # Low-level helpers @@ -109,11 +113,15 @@ def _clear_module_weights( if module is None: return if hasattr(module, "orig_layer"): + # Delegate to the original layer inside the wrapper so that the + # real weight tensors get cleared. 
def _load_layers_from_model_files(model_dir: str, layer_names: list[str], model: torch.nn.Module) -> None:
    """Load the weights of specific layers from the original model checkpoint.

    Every checkpoint tensor whose name equals one of *layer_names*, or lives
    underneath one of them (``<layer_name>.<...>``), is read and installed on
    *model*.  Tensors are grouped by shard file so each shard is opened once.

    Args:
        model_dir: Path to the model directory (passed through
            ``_resolve_model_dir`` before use).
        layer_names: Full dotted module names (e.g. ``["model.embed_tokens", "model.norm"]``).
            Exact tensor names are accepted as well.
        model: The root ``nn.Module`` that receives the tensors.
    """
    # Nothing requested: skip resolving the directory and parsing the index.
    if not layer_names:
        return

    model_dir = _resolve_model_dir(model_dir)
    weight_map = _build_weight_map(model_dir)

    # One set lookup plus one C-level ``startswith`` over a tuple of prefixes
    # replaces a Python-level loop over every (tensor, layer) pair.
    exact_names = set(layer_names)
    prefixes = tuple(layer_name + "." for layer_name in exact_names)

    shard_to_tensors: dict[str, list[str]] = defaultdict(list)
    for tensor_name, shard_file in weight_map.items():
        if tensor_name in exact_names or tensor_name.startswith(prefixes):
            shard_to_tensors[shard_file].append(tensor_name)
    if not shard_to_tensors:
        return

    for shard_file, tensor_names in shard_to_tensors.items():
        shard_path = os.path.join(model_dir, shard_file)
        if shard_file.endswith(".safetensors"):
            from safetensors import safe_open

            # safetensors reads only the requested tensors from the shard.
            with safe_open(shard_path, framework="pt", device="cpu") as f:
                for name in tensor_names:
                    _set_tensor_in_model(model, name, f.get_tensor(name))
        else:
            # NOTE(review): ``torch.load`` unpickles the whole shard; only use
            # with trusted checkpoints (consider ``weights_only=True`` where
            # the supported torch versions allow it).
            full_state = torch.load(shard_path, map_location="cpu")
            for name in tensor_names:
                if name in full_state:
                    _set_tensor_in_model(model, name, full_state[name])
            # Drop the full shard before the next one is read.
            del full_state


def _set_tensor_in_model(model: torch.nn.Module, full_name: str, tensor: torch.Tensor) -> None:
    """Set a single tensor into the model by its full dotted name.

    If the attribute currently holds an ``nn.Parameter``, the new tensor is
    cast to that parameter's dtype and wrapped in a fresh ``nn.Parameter``
    with the same ``requires_grad`` flag.  Otherwise the tensor is assigned
    as-is (``nn.Module.__setattr__`` keeps a registered buffer registered).

    Args:
        model: The root ``nn.Module``.
        full_name: Dotted path such as ``"model.norm.weight"``.
        tensor: The tensor to install.

    Raises:
        AttributeError: If an intermediate component of *full_name* does not
            exist on the model.
    """
    owner_path, _, attr_name = full_name.rpartition(".")
    owner = model
    if owner_path:
        for part in owner_path.split("."):
            owner = getattr(owner, part)

    current = getattr(owner, attr_name, None)
    if isinstance(current, torch.nn.Parameter):
        tensor = torch.nn.Parameter(tensor.to(dtype=current.dtype), requires_grad=current.requires_grad)
    setattr(owner, attr_name, tensor)
Selectively loads only tensors belonging to *block_name* without loading @@ -223,6 +293,158 @@ def load_block_from_model_files(model_dir: str, block_name: str, block: torch.nn _load_state_dict_into_module(state_dict, block) +# ===================================================================== +# Meta-skeleton helpers (for low-CPU-memory AutoScheme) +# ===================================================================== + + +def load_model_meta_skeleton(model_name: str): + """Load a model as a meta-device skeleton (no real weight data in RAM). + + Uses HuggingFace ``from_pretrained`` with ``low_cpu_mem_usage=True`` and + ``device_map="meta"`` so that all parameters live on the meta device. + The model structure is fully intact but occupies near-zero CPU RAM. + + Returns: + (model, tokenizer, None) -- same signature as ``llm_load_model``. + """ + import re + + from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer + + is_glm = bool(re.search("chatglm", model_name.lower())) + model_cls = AutoModel if is_glm else AutoModelForCausalLM + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) + model = model_cls.from_pretrained( + model_name, + torch_dtype="auto", + trust_remote_code=True, + low_cpu_mem_usage=True, + device_map="meta", + ) + model = model.eval() + logger.info("Loaded model as meta skeleton for low-memory AutoScheme") + return model, tokenizer, None + + +def materialize_non_block_layers(model, model_name, block_names): + """Materialize non-block layers from checkpoint into CPU memory. + + Block weights stay as cleared (empty) tensors and will be loaded + on demand by OffloadManager. Non-block modules (embeddings, norms, + lm_head, etc.) are loaded from the original checkpoint files and + moved to CPU. + + Computed buffers (e.g. ``inv_freq`` in rotary embeddings) that are not + stored in the checkpoint are re-created by re-initializing their parent + modules on CPU. 
+ + Args: + model: The model with meta/empty tensors. + model_name: Path to the model checkpoint directory. + block_names: List of block name prefixes (e.g. ``["model.layers.0", ...]``). + """ + block_prefixes = tuple(bn + "." for bn in block_names) + + # Build weight map to distinguish checkpoint tensors from computed ones + model_dir = _resolve_model_dir(model_name) + weight_map = _build_weight_map(model_dir) + + with torch.no_grad(): + # Replace all meta PARAMETERS with empty CPU tensors. + # Block params will be loaded on demand by OffloadManager. + # Non-block params will be loaded from checkpoint below. + for name, param in list(model.named_parameters()): + if param.is_meta: + parts = name.split(".") + target = model + for part in parts[:-1]: + target = getattr(target, part) + param_attr = parts[-1] + # Cache numel/shape for block params so compute_layer_bits works + if any(name.startswith(bp) for bp in block_prefixes) and param_attr == "weight": + target._cached_weight_numel = param.numel() + target._cached_weight_shape = tuple(param.shape) + setattr( + target, + param_attr, + torch.nn.Parameter( + torch.empty(0, dtype=param.dtype, device="cpu"), + requires_grad=param.requires_grad, + ), + ) + + # Replace meta BUFFERS that ARE in the checkpoint with empty CPU tensors. + # Leave non-checkpoint meta buffers (computed in __init__) untouched for now; + # they will be re-created by reinit_computed_buffers below. + for name, buf in list(model.named_buffers()): + if buf.is_meta and name in weight_map: + parts = name.split(".") + target = model + for part in parts[:-1]: + target = getattr(target, part) + target.register_buffer(parts[-1], torch.empty(0, dtype=buf.dtype, device="cpu")) + + # Load non-block layers from checkpoint + non_block_layer_names = [] + for name in list(weight_map.keys()): + if not any(name.startswith(bp) for bp in block_prefixes): + module_name = ".".join(name.split(".")[:-1]) if "." 
in name else name + if module_name not in non_block_layer_names: + non_block_layer_names.append(module_name) + + if non_block_layer_names: + _load_layers_from_model_files(model_name, non_block_layer_names, model) + + # Re-establish tied weights + if hasattr(model, "tie_weights"): + model.tie_weights() + + # Re-create computed buffers (non-checkpoint meta buffers like inv_freq) + _reinit_computed_buffers(model) + + n_loaded = len(non_block_layer_names) + logger.info(f"Materialized {n_loaded} non-block layers from checkpoint " f"(block weights stay offloaded)") + + +def _reinit_computed_buffers(model): + """Re-create non-checkpoint meta buffers on CPU with correct values. + + These buffers were computed during ``__init__`` on meta device and + aren't stored in the checkpoint. For parameter-free modules (e.g. + rotary embeddings), we re-run ``__init__`` to recompute the buffers + on CPU. For others, we fall back to zeros with a warning. + """ + config = getattr(model, "config", None) + + for module_name, module in model.named_modules(): + meta_bufs = [(n, b) for n, b in module.named_buffers(recurse=False) if b.is_meta] + if not meta_bufs: + continue + + # If the module has no learnable parameters, it's safe to re-init + has_params = any(True for _ in module.parameters(recurse=False)) + if not has_params and config is not None: + try: + module.__class__.__init__(module, config=config) + # Verify all buffers are no longer meta + still_meta = any(b.is_meta for _, b in module.named_buffers(recurse=False)) + if not still_meta: + continue + except Exception: + pass + + # Fallback: replace remaining meta buffers with zeros + for buf_name, buf in meta_bufs: + if buf.is_meta: + logger.warning( + f"Could not recompute buffer {module_name}.{buf_name}, " + "using zeros (may slightly affect AutoScheme scoring accuracy)" + ) + module.register_buffer(buf_name, torch.zeros(buf.shape, dtype=buf.dtype, device="cpu")) + + # 
===================================================================== # OffloadManager -- the offload manager class # ===================================================================== @@ -505,7 +727,7 @@ def _reload(self, model: torch.nn.Module, name: str) -> None: if self.model_dir is None: logger.warning("OffloadManager: model_dir is required for clean mode") return - load_block_from_model_files(self.model_dir, name, module) + _load_block_from_model_files(self.model_dir, name, module) # ------------------------------------------------------------------ # Hook-based transparent offloading @@ -813,12 +1035,12 @@ def _clear(self, module: torch.nn.Module, block_name: str | None = None) -> None @staticmethod def _needs_loading(module: torch.nn.Module) -> bool: - """Return *True* if any parameter in *module* has been cleared.""" + """Return *True* if any parameter in *module* has been cleared or is on meta device.""" for submodule in module.modules(): if hasattr(submodule, "orig_layer"): submodule = submodule.orig_layer for param in submodule.parameters(recurse=False): - if param is not None and param.numel() == 0: + if param is not None and (param.numel() == 0 or param.is_meta): return True return False diff --git a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py index deb06aafc..74b22c3ff 100644 --- a/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py +++ b/test/test_cpu/schemes/test_auto_scheme_low_cpu_mem.py @@ -299,7 +299,7 @@ def test_clear_and_load_model_block(self, tiny_opt_model_path): """Test clearing and reloading an actual model block from checkpoint files.""" from transformers import AutoModelForCausalLM - from auto_round.utils.offload import load_block_from_model_files + from auto_round.utils.offload import _load_block_from_model_files model = AutoModelForCausalLM.from_pretrained(tiny_opt_model_path, torch_dtype=torch.float32) block_names = get_block_names(model)[0] @@ -322,7 +322,7 @@ 
def test_clear_and_load_model_block(self, tiny_opt_model_path): assert current_params < original_params # Load back from model files - load_block_from_model_files(tiny_opt_model_path, block_name, block) + _load_block_from_model_files(tiny_opt_model_path, block_name, block) restored_params = sum(p.numel() for p in block.parameters()) assert restored_params == original_params