This is the source code to reproduce the experiments for "Localizing Task Information for Improved Model Merging and Compression" by Ke Wang*, Nikolaos Dimitriadis*, Guillermo Ortiz-Jimenez, Francois Fleuret, and Pascal Frossard.
Our paper identifies that task-specific knowledge is preserved after merging, and proposes a method named TALL mask to localize it. Based on TALL mask, we propose:
- a compression scheme which utilizes TALL mask to recover single-task fine-tuned performance for each task
- a merging algorithm which removes catastrophic and selfish weights to improve model merging performance
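As a toy illustration of the idea, here is a minimal NumPy sketch with flat vectors standing in for model weights. This is not the repo's implementation (which operates on model checkpoints via a `TaskVector` class), and the sparse synthetic task vectors are an assumption chosen to mimic the paper's observation that task information is localized to few weights:

```python
import numpy as np

rng = np.random.default_rng(0)
d, num_tasks = 10_000, 3

# Toy "task vectors" (fine-tuned minus pre-trained weights): sparse, with
# mostly non-overlapping supports, mimicking localized task information.
task_vectors = []
for _ in range(num_tasks):
    tv = np.zeros(d)
    idx = rng.choice(d, size=d // 20, replace=False)  # 5% of weights per task
    tv[idx] = rng.standard_normal(idx.size)
    task_vectors.append(tv)

multi_task_vector = np.sum(task_vectors, axis=0)
lam = 1.0  # threshold hyperparameter (lambda in the paper)

# TALL mask for task t: keep weights where task t's contribution dominates
# the combined contribution of all other tasks.
tall_masks = [np.abs(tv) > lam * np.abs(multi_task_vector - tv)
              for tv in task_vectors]

# Reconstruct task 0's vector by masking the merged multi-task vector.
reconstructed = multi_task_vector * tall_masks[0]
rel_err = (np.linalg.norm(reconstructed - task_vectors[0])
           / np.linalg.norm(task_vectors[0]))
print(f"kept {tall_masks[0].mean():.1%} of weights, relative error {rel_err:.3f}")
```

Because the supports barely overlap, masking the merged vector recovers each task vector almost exactly; this is the mechanism behind both the compression scheme and the merging algorithm above.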
You can also find more information on the project website.
Our new work "LiNeS: Post-training Layer Scaling Prevents Forgetting and Enhances Model Merging" has been accepted at ICLR 2025. Check out the code in the LiNeS repo!
To run the code, please install all its dependencies:
```bash
conda env create
conda activate tall-masks
```

We provide the checkpoints, as well as the generated task-specific masks used in the paper, at this link. Alternatively, you can download the checkpoints and masks by running the following script:
```bash
# model options --model {ViT-B-32,ViT-L-14}
# kind options --kind {checkpoints,tall_masks}
# use python download_checkpoints.py --help for more information
python download_checkpoints.py --model='ViT-B-32' --kind=checkpoints
```

The script downloads all the checkpoints for one model, corresponding to 40 files (a fine-tuned checkpoint and a classification head for each of 20 tasks). The script uses the gdown package to download the files. If you encounter any issues, please refer to the gdown documentation. A common issue is that the download quota is exceeded, in which case you can download the files manually from the Google Drive folder or modify your local cookies file as described in the gdown documentation.
Alternatively, the checkpoints can be downloaded from the HuggingFace repo nik-dim/tall_masks. See the snapshot_download documentation for more details.
```python
from huggingface_hub import snapshot_download

# download the ViT-B-32 checkpoints, including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*32*")

# download the ViT-B-16 checkpoints, including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*16*")

# download the ViT-L-14 checkpoints, including backbone, classification heads and tall masks
snapshot_download(repo_id="nik-dim/tall_masks", allow_patterns="*14*")

# download everything
snapshot_download(repo_id="nik-dim/tall_masks")
```

Most datasets should be downloaded automatically with torchvision or huggingface. For the datasets requiring manual preparation, please follow the instructions in this issue. Depending on the torchvision version, some issues might arise when downloading specific datasets, as reported here and here. In that case, using a different torchvision version might solve the issue.
Below is an example in pseudo-code of how to use TALL masks to localize the information in a multi-task vector and reconstruct the individual checkpoints.
To create a task vector, you will need a pre-trained checkpoint and a fine-tuned checkpoint:
```python
from task_vectors import TaskVector

task_vector_A = TaskVector(pretrained_checkpoint, finetuned_checkpoint_A)
```

Create a multi-task vector:
```python
multi_task_vector = task_vector_A + task_vector_B + task_vector_C
```

Construct a TALL mask:
```python
tall_mask_A = task_vector_A.abs() > (multi_task_vector - task_vector_A).abs() * lambda
```

Reconstruct the fine-tuned model with the TALL mask:
```python
# the reconstructed finetuned_checkpoint_A has nearly the same performance as the original finetuned_checkpoint_A
reconstructed_finetuned_checkpoint_A = pretrained_checkpoint + multi_task_vector * tall_mask_A
```

The script finetune.py can be used to reproduce the training protocol we used to fine-tune our models on all downstream tasks.
```bash
# Finetune on 2 GPUs
python finetune.py --model=ViT-B-32 --world-size=2
```

Evaluation is performed with Hydra; please modify model_location and data_location in config/config.yaml before evaluation.
```bash
# Evaluate with Task Arithmetic
python main.py model=ViT-B-32 method="sum"

# Evaluate with Ties-merging
python main.py model=ViT-B-32 method="ties" method.k=20

# Evaluate with TALL mask + Task Arithmetic (load TALL masks from storage)
python main.py model=ViT-B-32 method="tall_mask" method.load_mask=True

# Evaluate with TALL mask + Task Arithmetic (construct TALL masks from scratch)
python main.py model=ViT-B-32 method="tall_mask"

# Evaluate with TALL mask + Ties-merging (load TALL masks from storage)
python main.py model=ViT-B-32 method="tall_mask" method.use_ties=True method.load_mask=True

# Evaluate with TALL mask + Ties-merging (construct TALL masks from scratch)
python main.py model=ViT-B-32 method="tall_mask" method.use_ties=True

# Evaluate with Consensus Task Arithmetic
python main.py model=ViT-B-32 method="consensus" method.prun_thre_k=2

# Evaluate with Consensus Ties-merging
python main.py model=ViT-B-32 method="consensus" method.prun_thre_k=2 method.use_ties=True
```

Note that you can set a different number of tasks with num_tasks; the first num_tasks tasks are then selected from the list defined in src/utils/variables_and_paths.py. Alternatively, you can directly specify the tasks as a list of strings (e.g. DATASETS=[MNIST,Cars]). The results of the paper can be retrieved by setting num_tasks to 8, 14 and 20 for the corresponding experiments.
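For intuition on the Consensus options above: with threshold prun_thre_k, only weights activated by at least that many per-task TALL masks are kept, which removes "selfish" weights (activated by exactly one task) and "catastrophic" weights (activated by none). This is a toy NumPy sketch of that selection rule, not the repo's implementation:

```python
import numpy as np

# Toy per-task TALL masks over 8 weights (rows: tasks, columns: weights).
tall_masks = np.array([
    [1, 1, 0, 0, 1, 0, 1, 0],
    [1, 0, 1, 0, 1, 0, 0, 0],
    [1, 0, 0, 1, 1, 0, 0, 0],
], dtype=bool)

k = 2  # corresponds to method.prun_thre_k=2

# Consensus mask: keep weights activated by at least k task masks, pruning
# "selfish" (exactly one task) and "catastrophic" (no task) weights.
consensus_mask = tall_masks.sum(axis=0) >= k
print(consensus_mask.astype(int))  # -> [1 0 0 0 1 0 0 0]

# The merged multi-task vector is then pruned with the consensus mask.
multi_task_vector = np.arange(8, dtype=float)
pruned = multi_task_vector * consensus_mask
```

Here only the weights shared by at least two tasks (columns 0 and 4) survive the pruning.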
You can evaluate the performance of the fine-tuned weights on each single task by running
```bash
# Evaluate pre-trained models.
python eval_single_task.py --model=ViT-B-32 --finetuning-mode=none

# Evaluate non-linearly fine-tuned models.
python eval_single_task.py --model=ViT-B-32 --finetuning-mode=standard
```

The results are saved in the results/ folder.
If you find this code useful, please cite the following paper:
```bibtex
@inproceedings{wang2024localizing,
  title={Localizing Task Information for Improved Model Merging and Compression},
  author={Wang, Ke and Dimitriadis, Nikolaos and Ortiz{-}Jim{\'{e}}nez, Guillermo and Fleuret, Fran\c{c}ois and Frossard, Pascal},
  booktitle={International Conference on Machine Learning},
  year={2024}
}
```