-
Notifications
You must be signed in to change notification settings - Fork 135
Continuously optimize AutoScheme RAM consumption #1703
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
ee19523
f19224e
f0d183c
fe0a541
3a9575f
68396c0
cae2d80
7f65d03
59f3639
dd52cd2
8073fa7
318b3b3
507f3ef
d8d332a
107485d
69cae58
1643ce1
26c7574
8bced5f
4c2238f
145847b
cc66be7
a4f9bf9
38ef946
c369070
d9e0f6a
9324bdf
4d99174
74594eb
66ed80d
f518956
8573308
2b47583
75325d2
a97e334
1295774
4c77a98
a7d01a2
82a7b99
bd935e4
330bd78
976f90d
a15b825
410a4e4
07e784b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -633,6 +633,17 @@ def _gen_auto_scheme(self) -> dict[str, dict]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.enable_torch_compile and self.super_bits is None and not self.orig_scheme.low_gpu_mem_usage: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.warning("we strongly recommend to set `enable_torch_compile` to True for AutoScheme to save VRAM") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # When low_cpu_mem_usage is enabled, pass the model path (string) to | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # AutoScheme so it can load a meta skeleton instead of keeping the full | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # model in RAM. The loaded model is freed here and reloaded afterward. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _need_reload = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _model_path = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if self.orig_scheme.low_cpu_mem_usage and self.orig_scheme.low_gpu_mem_usage: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _model_path = getattr(self.model.config, "_name_or_path", None) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _model_path is not None and os.path.isdir(_model_path): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _need_reload = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.scheme_generator = GenScheme( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.orig_scheme, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -643,7 +654,41 @@ def _gen_auto_scheme(self) -> dict[str, dict]: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| tokenizer=self.tokenizer, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| enable_torch_compile=self.enable_torch_compile, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _need_reload: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # GenScheme.__init__ has computed avg bit ranges using the model. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Now swap the model reference with the path string so that | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # gen_layer_config will load a meta skeleton instead. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.scheme_generator.model = _model_path | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| del self.model | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| import gc | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| gc.collect() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| clear_memory(device_list=self.device_list) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("Released loaded model before AutoScheme (will reload after)") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer_config = self.scheme_generator.get_layer_config() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if _need_reload: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| logger.info("Reloading model after AutoScheme") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model, self.tokenizer = llm_load_model( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| _model_path, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| device="cpu", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| trust_remote_code=self.trust_remote_code, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = self.model.eval() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| check_and_mark_quantized_module(self.model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Re-apply module structure updates that quantize() applied before AutoScheme | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| formats = self.formats if hasattr(self, "formats") else None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not self.diffusion and formats is not None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model = update_module( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for n, m in self.model.named_modules(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| m.global_name = n | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.shared_cache_keys = get_shared_keys(self.model) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
779
to
+801
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| layer_config = self.scheme_generator.get_layer_config() | |
| if _need_reload: | |
| logger.info("Reloading model after AutoScheme") | |
| self.model, self.tokenizer = llm_load_model( | |
| _model_path, | |
| device="cpu", | |
| trust_remote_code=self.trust_remote_code, | |
| ) | |
| self.model = self.model.eval() | |
| check_and_mark_quantized_module(self.model) | |
| # Re-apply module structure updates that quantize() applied before AutoScheme | |
| formats = self.formats if hasattr(self, "formats") else None | |
| if not self.diffusion and formats is not None: | |
| self.model = update_module( | |
| self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False | |
| ) | |
| for n, m in self.model.named_modules(): | |
| m.global_name = n | |
| self.shared_cache_keys = get_shared_keys(self.model) | |
| try: | |
| layer_config = self.scheme_generator.get_layer_config() | |
| finally: | |
| if _need_reload: | |
| logger.info("Reloading model after AutoScheme") | |
| self.model, self.tokenizer = llm_load_model( | |
| _model_path, | |
| device="cpu", | |
| trust_remote_code=self.trust_remote_code, | |
| ) | |
| self.model = self.model.eval() | |
| check_and_mark_quantized_module(self.model) | |
| # Re-apply module structure updates that quantize() applied before AutoScheme | |
| formats = self.formats if hasattr(self, "formats") else None | |
| if not self.diffusion and formats is not None: | |
| self.model = update_module( | |
| self.model, formats=formats, trust_remote_code=self.trust_remote_code, cleanup_original=False | |
| ) | |
| for n, m in self.model.named_modules(): | |
| m.global_name = n | |
| self.shared_cache_keys = get_shared_keys(self.model) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The low-CPU-memory reload path is gated on
os.path.isdir(_model_path). For many Hugging Face loads,config._name_or_pathis a repo id (not a directory), so this optimization will silently not activate. Consider resolving repo ids to a local snapshot directory (e.g., viahuggingface_hub.snapshot_download(local_files_only=True)or an existing helper) rather than requiring_name_or_pathto already be a local dir.