Skip to content

Commit

Permalink
Add documentation for model fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed May 15, 2024
1 parent d62a0d4 commit 5cd11e4
Show file tree
Hide file tree
Showing 24 changed files with 771 additions and 11 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/gh-deploy.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# This is a basic workflow to help you get started with Actions
name: gh-deploy

# Controls when the workflow will run
on:
# Triggers the workflow on push or pull request events but only for the master branch
push:
branches: [ main ]

# Allows you to run this workflow manually from the Actions tab
workflow_dispatch:

# A workflow run is made up of one or more jobs that can run sequentially or in parallel
jobs:
# This workflow contains a single job called "build"
build:
# The type of runner that the job will run on
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
with:
fetch-depth: 0
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- run: pip install -r requirements.txt
- run: pip install mkdocs mkdocs-material 'mkdocstrings[python]'
- run: mkdocs gh-deploy --force --no-history --remote-branch gh-pages
11 changes: 7 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# FusionBench: A Comprehensive Benchmark of Deep Model Fusion

> Stay tuned. Working in progress.
## Installation

```bash
Expand Down Expand Up @@ -30,8 +32,9 @@ fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \

`fusion_bench` has the following options:

| **Option** | **Default** | **Description** |
| ---------- | ------------------------- | ------------------------------- |
| method | `simple_average` | The fusion method to be used. |
| model_pool | `huggingface_clip_vision` | The pool of models to be fused. |
| **Option** | **Default** | **Description** |
| ------------ | ------------------------- | -------------------------------------------------- |
| method | `simple_average` | The fusion method to be used. |
| modelpool | `huggingface_clip_vision` | The pool of models to be fused. |
| print_config | `true` | Whether to print the configuration to the console. |

4 changes: 2 additions & 2 deletions config/modelpool/clip-vit-base-patch32_TA8.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
type: huggingface_clip_vision
model_type: huggingface_clip_vision
models:
- name: pretrained
- name: _pretrained_
path: tanganke/clip-vit-base-patch32
- name: sun397
path: tanganke/clip-vit-base-patch32_sun397
Expand Down
8 changes: 8 additions & 0 deletions docs/algorithms/simple_averaging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Simple Averaging

Simple averaging is known in the literature as ModelSoups, aims to yield a more robust and generalizable model.

In the context of full fine-tuned models, the weights are averaged directly. Concretely, this means that if we have $n$ models with their respective weights $\theta_i$, the weights of the final model $\theta$ are computed as:

$$ \theta = \frac{1}{n} \sum_{i=1}^{n} \theta_i $$

22 changes: 22 additions & 0 deletions docs/algorithms/task_arithmetic.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Task Arithmetic

In the rapidly advancing field of machine learning, multi-task learning has emerged as a powerful paradigm, allowing models to leverage information from multiple tasks to improve performance and generalization. One intriguing method in this domain is Task Arithmetic, which involves the combination of task-specific vectors derived from model parameters.

**Task Vector**. A task vector is used to encapsulate the adjustments needed by a model to specialize in a specific task.
It is derived from the differences between a pre-trained model's parameters and those fine-tuned for a particular task.
Formally, if $\theta_i$ represents the model parameters fine-tuned for the i-th task and $\theta_0$ denotes the parameters of the pre-trained model, the task vector for the i-th task is defined as:

$$\tau_i = \theta_i - \theta_0$$

This representation is crucial for methods like Task Arithmetic, where multiple task vectors are aggregated and scaled to form a comprehensive multi-task model.

**Task Arithmetic** begins by computing a task vector $\tau_i$ for each individual task, using the set of model parameters $\theta_0 \cup \{\theta_i\}_i$ where $\theta_0$ is the pre-trained model and $\theta_i$ are the fine-tuned parameters for i-th task.
These task vectors are then aggregated to form a multi-task vector.
Subsequently, the multi-task vector is combined with the pre-trained model parameters to obtain the final multi-task model.
This process involves scaling the combined vector element-wise by a scaling coefficient (denoted as $\lambda$), before adding it to the initial pre-trained model parameters.
The resulting formulation for obtaining a multi-task model is expressed as

$$ \theta = \theta_0 + \lambda \sum_{i} \tau_i. $$

The choice of the scaling coefficient $\lambda$ plays a crucial role in the final model performance. Typically, $\lambda$ is chosen based on validation set performance.

27 changes: 27 additions & 0 deletions docs/cli/fusion_bench.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# fusion_bench

`fusion_bench` is the command line interface for running the benchmark.
It takes a configuration file as input, which specifies the models, fusion method to be used, and the datasets to be evaluated.

```
fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \
OPTION_1=VALUE_1 OPTION_2=VALUE_2 ...
```

`fusion_bench` has the following options:

| **Option** | **Default** | **Description** |
| ------------ | ------------------------- | -------------------------------------------------- |
| method | `simple_average` | The fusion method to be used. |
| modelpool | `huggingface_clip_vision` | The pool of models to be fused. |
| print_config | `true` | Whether to print the configuration to the console. |

## Basic Examples

merge multiple CLIP models using simple averaging:

```bash
fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml
```


27 changes: 27 additions & 0 deletions docs/css/mkdocstrings.css
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
/* Indentation. */
div.doc-contents:not(.first) {
padding-left: 25px;
border-left: .05rem solid var(--md-typeset-table-color);
}

/* Mark external links as such. */
a.external::after,
a.autorefs-external::after {
/* https://primer.style/octicons/arrow-up-right-24 */
mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>');
-webkit-mask-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M18.25 15.5a.75.75 0 00.75-.75v-9a.75.75 0 00-.75-.75h-9a.75.75 0 000 1.5h7.19L6.22 16.72a.75.75 0 101.06 1.06L17.5 7.56v7.19c0 .414.336.75.75.75z"></path></svg>');
content: ' ';

display: inline-block;
vertical-align: middle;
position: relative;

height: 1em;
width: 1em;
background-color: var(--md-typeset-a-color);
}

a.external:hover::after,
a.autorefs-external:hover::after {
background-color: var(--md-accent-fg-color);
}
43 changes: 43 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# FusionBench: A Comprehensive Benchmark of Deep Model Fusion

!!! note

Stay tuned. Working in progress.

!!! note

- Any questions or comments can be directed to the [GitHub Issues](https://github.com/tanganke/fusion_bench/issues) page for this project.
- Any contributions or pull requests are welcome.




## Introduction to Deep Model Fusion

Deep model fusion is a technique that merges, ensemble, or fuse multiple deep neural networks to obtain a unified model.
It can be used to improve the performance and rubustness of model or to combine the strengths of different models, such as fuse multiple task-specific models to create a multi-task model.
For a more detailed introduction to deep model fusion, you can refer to [W. Li, 2023, 'Deep Model Fusion: A Survey'](http://arxiv.org/abs/2303.16203).
In this benchmark, we evaluate the performance of different fusion methods on a variety of datasets and tasks.

## Getting Started

### Installation

```bash
# install from github repository
git clone https://github.com/tanganke/fusion_bench.git
cd fusion_bench

pip install -e . # install the package in editable mode
```

### Command Line Interface

`fusion_bench` is the command line interface for running the benchmark.
It takes a configuration file as input, which specifies the models, fusion method to be used, and the datasets to be evaluated.

```
fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \
OPTION_1=VALUE_1 OPTION_2=VALUE_2 ...
```

16 changes: 16 additions & 0 deletions docs/javascripts/mathjax.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
window.MathJax = {
tex: {
inlineMath: [["\\(", "\\)"]],
displayMath: [["\\[", "\\]"]],
processEscapes: true,
processEnvironments: true
},
options: {
ignoreHtmlClass: ".*|",
processHtmlClass: "arithmatex"
}
};

document$.subscribe(() => {
MathJax.typesetPromise()
})
11 changes: 11 additions & 0 deletions docs/modelpool/clip_vit.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# CLIP-ViT Models for Open Vocabulary Image Classification

Here we provides a list of CLIP-ViT models that are trained for open vocabulary image classification.

Here is a simple footnote[^1]. With some additional text after it.

[^1]: My reference, with further explanation and a [supporting link](https://website.com).

Here is another footnote[^2].

[^2]: Another reference.
26 changes: 26 additions & 0 deletions docs/modelpool/index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# ModelPool

A modelpool is a collection of models that are utilized in the process of model fusion.
In the context of straightforward model fusion techniques, like averaging, only models with the same architecture are used.
While for more complex methods, such as AdaMerging [^1], each model is paired with a unique set of unlabeled test data. This data is used during the test-time adaptation phase.

A modelpool is specified by a `yaml` configuration file, which often contains the following fields:

- `model_type`: The name of the modelpool.
- `models`: A list of models, each model is dict with the following fields:
- `name`: The name of the model.
- `path`: The path to the model file.
- `type`: The type of the model. If this field is not specified, the type is inferred from the `model_type`.

For more complex model fusion techniques that requires data, the modelpool configuration file may also contain the following fields:

- `dataset_type`: The type of the dataset used for training the models in the modelpool.
- `datasets`: A list of datasets, each dataset is dict with the following fields:
- `name`: The name of the dataset.
- `path`: The path to the dataset file.
- `type`: The type of the dataset. If this field is not specified, the type is inferred from the `dataset_type`.

We provide a list of modelpools that contain models trained on different datasets and with different architectures.
Each modelpool is described in a separate document.

[^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
9 changes: 9 additions & 0 deletions fusion_bench/method/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from omegaconf import DictConfig
from .simple_average import SimpleAverageAlgorithm


def load_algorithm(method_config: DictConfig):
if method_config.name == "simple_average":
return SimpleAverageAlgorithm(method_config)
else:
raise ValueError(f"Unknown algorithm: {method_config.name}")
7 changes: 7 additions & 0 deletions fusion_bench/method/base_algorithm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from abc import ABC, abstractmethod


class ModelFusionAlgorithm(ABC):
@abstractmethod
def fuse(self, modelpool):
pass
46 changes: 46 additions & 0 deletions fusion_bench/method/simple_average.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from copy import deepcopy
from typing import List, Mapping, Union

import torch
from torch import Tensor, nn

from ..utils.state_dict_arithmetic import state_dict_avg
from ..utils.type import _StateDict
from .base_algorithm import ModelFusionAlgorithm


def simple_average(modules: List[Union[nn.Module, _StateDict]]):
"""
Averages the parameters of a list of PyTorch modules or state dictionaries.
This function takes a list of PyTorch modules or state dictionaries and returns a new module with the averaged parameters, or a new state dictionary with the averaged parameters.
Args:
modules (List[Union[nn.Module, _StateDict]]): A list of PyTorch modules or state dictionaries.
Returns:
module_or_state_dict (Union[nn.Module, _StateDict]): A new PyTorch module with the averaged parameters, or a new state dictionary with the averaged parameters.
Examples:
>>> import torch.nn as nn
>>> model1 = nn.Linear(10, 10)
>>> model2 = nn.Linear(10, 10)
>>> averaged_model = simple_averageing([model1, model2])
>>> state_dict1 = model1.state_dict()
>>> state_dict2 = model2.state_dict()
>>> averaged_state_dict = simple_averageing([state_dict1, state_dict2])
"""
if isinstance(modules[0], nn.Module):
new_module = deepcopy(modules[0])
state_dict = state_dict_avg([module.state_dict() for module in modules])
new_module.load_state_dict(state_dict)
return new_module
elif isinstance(modules[0], Mapping):
return state_dict_avg(modules)


class SimpleAverageAlgorithm(ModelFusionAlgorithm):

def fuse(self, modelpool):

Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from omegaconf import DictConfig

from .base_pool import ModelPool
from .huggingface_clip_vision import HuggingFaceClipVisionPool


def load_model_pool(config: DictConfig):
def load_modelpool(config: DictConfig):
if hasattr(config, "type"):
if config.type == "huggingface_clip_vision":
return HuggingFaceClipVisionPool(config)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from omegaconf import DictConfig


class BasePool(ABC):
class ModelPool(ABC):
models = {}

def __init__(self, modelpool_config: DictConfig):
Expand All @@ -15,6 +15,15 @@ def __init__(self, modelpool_config: DictConfig):
assert len(model_names) == len(set(model_names))
self.model_names = model_names

@property
def model_names(self):
names = [
model["name"]
for model in self.config["models"]
if model["name"][0] != "_" and model["name"][-1] != "_"
]
return names

def get_model_config(self, model_name: str):
"""
Retrieves the configuration for a specific model from the model pool.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from omegaconf import DictConfig
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

from .base_pool import BasePool
from .base_pool import ModelPool


class HuggingFaceClipVisionPool(BasePool):
class HuggingFaceClipVisionPool(ModelPool):
def __init__(self, modelpool_config: DictConfig):
super().__init__(modelpool_config)

Expand Down
10 changes: 9 additions & 1 deletion fusion_bench/scripts/cli.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import importlib
import importlib.resources
import os

import hydra
from omegaconf import DictConfig, OmegaConf
from rich import print as rich_print
from rich.syntax import Syntax
import importlib

from ..modelpool import load_modelpool
from ..method import load_algorithm


@hydra.main(
Expand All @@ -25,6 +29,10 @@ def main(cfg: DictConfig) -> None:
)
)

modelpool = load_modelpool(cfg.modelpool)
algorithm = load_algorithm(cfg.method)
merged_model = algorithm.fuse(modelpool)


if __name__ == "__main__":
main()
Loading

0 comments on commit 5cd11e4

Please sign in to comment.