-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
771 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# This is a basic workflow to help you get started with Actions | ||
name: gh-deploy | ||
|
||
# Controls when the workflow will run | ||
on: | ||
# Triggers the workflow on push or pull request events but only for the master branch | ||
push: | ||
branches: [ main ] | ||
|
||
# Allows you to run this workflow manually from the Actions tab | ||
workflow_dispatch: | ||
|
||
# A workflow run is made up of one or more jobs that can run sequentially or in parallel | ||
jobs: | ||
# This workflow contains a single job called "build" | ||
build: | ||
# The type of runner that the job will run on | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v3 | ||
with: | ||
fetch-depth: 0 | ||
- uses: actions/setup-python@v4 | ||
with: | ||
python-version: "3.10" | ||
- run: pip install -r requirements.txt | ||
- run: pip install mkdocs mkdocs-material 'mkdocstrings[python]' | ||
- run: mkdocs gh-deploy --force --no-history --remote-branch gh-pages |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Simple Averaging | ||
|
||
Simple averaging is known in the literature as ModelSoups, aims to yield a more robust and generalizable model. | ||
|
||
In the context of full fine-tuned models, the weights are averaged directly. Concretely, this means that if we have $n$ models with their respective weights $\theta_i$, the weights of the final model $\theta$ are computed as: | ||
|
||
$$ \theta = \frac{1}{n} \sum_{i=1}^{n} \theta_i $$ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Task Arithmetic | ||
|
||
In the rapidly advancing field of machine learning, multi-task learning has emerged as a powerful paradigm, allowing models to leverage information from multiple tasks to improve performance and generalization. One intriguing method in this domain is Task Arithmetic, which involves the combination of task-specific vectors derived from model parameters. | ||
|
||
**Task Vector**. A task vector is used to encapsulate the adjustments needed by a model to specialize in a specific task. | ||
It is derived from the differences between a pre-trained model's parameters and those fine-tuned for a particular task. | ||
Formally, if $\theta_i$ represents the model parameters fine-tuned for the i-th task and $\theta_0$ denotes the parameters of the pre-trained model, the task vector for the i-th task is defined as: | ||
|
||
$$\tau_i = \theta_i - \theta_0$$ | ||
|
||
This representation is crucial for methods like Task Arithmetic, where multiple task vectors are aggregated and scaled to form a comprehensive multi-task model. | ||
|
||
**Task Arithmetic** begins by computing a task vector $\tau_i$ for each individual task, using the set of model parameters $\theta_0 \cup \{\theta_i\}_i$ where $\theta_0$ is the pre-trained model and $\theta_i$ are the fine-tuned parameters for i-th task. | ||
These task vectors are then aggregated to form a multi-task vector. | ||
Subsequently, the multi-task vector is combined with the pre-trained model parameters to obtain the final multi-task model. | ||
This process involves scaling the combined vector element-wise by a scaling coefficient (denoted as $\lambda$), before adding it to the initial pre-trained model parameters. | ||
The resulting formulation for obtaining a multi-task model is expressed as | ||
|
||
$$ \theta = \theta_0 + \lambda \sum_{i} \tau_i. $$ | ||
|
||
The choice of the scaling coefficient $\lambda$ plays a crucial role in the final model performance. Typically, $\lambda$ is chosen based on validation set performance. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# fusion_bench | ||
|
||
`fusion_bench` is the command line interface for running the benchmark. | ||
It takes a configuration file as input, which specifies the models, fusion method to be used, and the datasets to be evaluated. | ||
|
||
``` | ||
fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \ | ||
OPTION_1=VALUE_1 OPTION_2=VALUE_2 ... | ||
``` | ||
|
||
`fusion_bench` has the following options: | ||
|
||
| **Option** | **Default** | **Description** | | ||
| ------------ | ------------------------- | -------------------------------------------------- | | ||
| method | `simple_average` | The fusion method to be used. | | ||
| modelpool | `huggingface_clip_vision` | The pool of models to be fused. | | ||
| print_config | `true` | Whether to print the configuration to the console. | | ||
|
||
## Basic Examples | ||
|
||
merge multiple CLIP models using simple averaging: | ||
|
||
```bash | ||
fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml | ||
``` | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
/* Indentation. */ | ||
div.doc-contents:not(.first) { | ||
padding-left: 25px; | ||
border-left: .05rem solid var(--md-typeset-table-color); | ||
} | ||
|
||
/* Mark external links as such. */ | ||
a.external::after, | ||
a.autorefs-external::after { | ||
/* https://primer.style/octicons/arrow-up-right-24 */ | ||
mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>'); | ||
-webkit-mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>'); | ||
content: ' '; | ||
|
||
display: inline-block; | ||
vertical-align: middle; | ||
position: relative; | ||
|
||
height: 1em; | ||
width: 1em; | ||
background-color: var(--md-typeset-a-color); | ||
} | ||
|
||
a.external:hover::after, | ||
a.autorefs-external:hover::after { | ||
background-color: var(--md-accent-fg-color); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# FusionBench: A Comprehensive Benchmark of Deep Model Fusion | ||
|
||
!!! note | ||
|
||
Stay tuned. Working in progress. | ||
|
||
!!! note | ||
|
||
- Any questions or comments can be directed to the [GitHub Issues](https://github.com/tanganke/fusion_bench/issues) page for this project. | ||
- Any contributions or pull requests are welcome. | ||
|
||
|
||
|
||
|
||
## Introduction to Deep Model Fusion | ||
|
||
Deep model fusion is a technique that merges, ensemble, or fuse multiple deep neural networks to obtain a unified model. | ||
It can be used to improve the performance and rubustness of model or to combine the strengths of different models, such as fuse multiple task-specific models to create a multi-task model. | ||
For a more detailed introduction to deep model fusion, you can refer to [W. Li, 2023, 'Deep Model Fusion: A Survey'](http://arxiv.org/abs/2303.16203). | ||
In this benchmark, we evaluate the performance of different fusion methods on a variety of datasets and tasks. | ||
|
||
## Getting Started | ||
|
||
### Installation | ||
|
||
```bash | ||
# install from github repository | ||
git clone https://github.com/tanganke/fusion_bench.git | ||
cd fusion_bench | ||
|
||
pip install -e . # install the package in editable mode | ||
``` | ||
|
||
### Command Line Interface | ||
|
||
`fusion_bench` is the command line interface for running the benchmark. | ||
It takes a configuration file as input, which specifies the models, fusion method to be used, and the datasets to be evaluated. | ||
|
||
``` | ||
fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \ | ||
OPTION_1=VALUE_1 OPTION_2=VALUE_2 ... | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
window.MathJax = { | ||
tex: { | ||
inlineMath: [["\\(", "\\)"]], | ||
displayMath: [["\\[", "\\]"]], | ||
processEscapes: true, | ||
processEnvironments: true | ||
}, | ||
options: { | ||
ignoreHtmlClass: ".*|", | ||
processHtmlClass: "arithmatex" | ||
} | ||
}; | ||
|
||
document$.subscribe(() => { | ||
MathJax.typesetPromise() | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# CLIP-ViT Models for Open Vocabulary Image Classification | ||
|
||
Here we provides a list of CLIP-ViT models that are trained for open vocabulary image classification. | ||
|
||
Here is a simple footnote[^1]. With some additional text after it. | ||
|
||
[^1]: My reference, with further explanation and a [supporting link](https://website.com). | ||
|
||
Here is another footnote[^2]. | ||
|
||
[^2]: Another reference. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
# ModelPool | ||
|
||
A modelpool is a collection of models that are utilized in the process of model fusion. | ||
In the context of straightforward model fusion techniques, like averaging, only models with the same architecture are used. | ||
While for more complex methods, such as AdaMerging [^1], each model is paired with a unique set of unlabeled test data. This data is used during the test-time adaptation phase. | ||
|
||
A modelpool is specified by a `yaml` configuration file, which often contains the following fields: | ||
|
||
- `model_type`: The name of the modelpool. | ||
- `models`: A list of models, each model is dict with the following fields: | ||
- `name`: The name of the model. | ||
- `path`: The path to the model file. | ||
- `type`: The type of the model. If this field is not specified, the type is inferred from the `model_type`. | ||
|
||
For more complex model fusion techniques that requires data, the modelpool configuration file may also contain the following fields: | ||
|
||
- `dataset_type`: The type of the dataset used for training the models in the modelpool. | ||
- `datasets`: A list of datasets, each dataset is dict with the following fields: | ||
- `name`: The name of the dataset. | ||
- `path`: The path to the dataset file. | ||
- `type`: The type of the dataset. If this field is not specified, the type is inferred from the `dataset_type`. | ||
|
||
We provide a list of modelpools that contain models trained on different datasets and with different architectures. | ||
Each modelpool is described in a separate document. | ||
|
||
[^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from omegaconf import DictConfig | ||
from .simple_average import SimpleAverageAlgorithm | ||
|
||
|
||
def load_algorithm(method_config: DictConfig): | ||
if method_config.name == "simple_average": | ||
return SimpleAverageAlgorithm(method_config) | ||
else: | ||
raise ValueError(f"Unknown algorithm: {method_config.name}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class ModelFusionAlgorithm(ABC): | ||
@abstractmethod | ||
def fuse(self, modelpool): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from copy import deepcopy | ||
from typing import List, Mapping, Union | ||
|
||
import torch | ||
from torch import Tensor, nn | ||
|
||
from ..utils.state_dict_arithmetic import state_dict_avg | ||
from ..utils.type import _StateDict | ||
from .base_algorithm import ModelFusionAlgorithm | ||
|
||
|
||
def simple_average(modules: List[Union[nn.Module, _StateDict]]): | ||
""" | ||
Averages the parameters of a list of PyTorch modules or state dictionaries. | ||
This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters. | ||
Args: | ||
modules (List[Union[nn.Module, _StateDict]]): A list of PyTorch modules or state dictionaries. | ||
Returns: | ||
module_or_state_dict (Union[nn.Module, _StateDict]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters. | ||
Examples: | ||
>>> import torch.nn as nn | ||
>>> model1 = nn.Linear(10, 10) | ||
>>> model2 = nn.Linear(10, 10) | ||
>>> averaged_model = simple_averageing([model1, model2]) | ||
>>> state_dict1 = model1.state_dict() | ||
>>> state_dict2 = model2.state_dict() | ||
>>> averaged_state_dict = simple_averageing([state_dict1, state_dict2]) | ||
""" | ||
if isinstance(modules[0], nn.Module): | ||
new_module = deepcopy(modules[0]) | ||
state_dict = state_dict_avg([module.state_dict() for module in modules]) | ||
new_module.load_state_dict(state_dict) | ||
return new_module | ||
elif isinstance(modules[0], Mapping): | ||
return state_dict_avg(modules) | ||
|
||
|
||
class SimpleAverageAlgorithm(ModelFusionAlgorithm): | ||
|
||
def fuse(self, modelpool): | ||
|
4 changes: 3 additions & 1 deletion
4
fusion_bench/model_pool/__init__.py → fusion_bench/modelpool/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.