Update fusion_bench with new features and improvements
- Updated README.md with new instructions
- Added new taskpool configuration for clip-vit-classification
- Updated documentation for fusion_bench, modelpool, and taskpool
- Added new dataset initialization file
- Modified dummy method in fusion_bench
- Updated base_pool in modelpool and taskpool
- Modified huggingface_clip_vision in modelpool
- Added new model file for hf_clip
- Updated CLI script for fusion_bench
- Added new task initialization file and base_task file
- Added new image_classification task
- Updated utils initialization file
- Added new devices utility file
- Updated parameters utility file
- Added new timer utility file
tanganke committed May 15, 2024
1 parent 8241e1b commit 18c50bd
Showing 22 changed files with 536 additions and 24 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -2,6 +2,8 @@

> Stay tuned. Working in progress.
Documentation is available at [tanganke.github.io/fusion_bench/](https://tanganke.github.io/fusion_bench/).

## Installation

```bash
37 changes: 37 additions & 0 deletions config/taskpool/clip-vit-classification_TA8.yaml
@@ -0,0 +1,37 @@
type: clip_vit_classification
dataset_type: huggingface_image_classification
tasks:
  - name: sun397
    dataset:
      name: tanganke/sun397
      split: test
  - name: stanford_cars
    dataset:
      name: tanganke/stanford-cars
      split: test
  - name: resisc45
    dataset:
      name: tanganke/resisc45
      split: test
  - name: eurosat
    dataset:
      name: tanganke/eurosat
      split: test
  - name: svhn
    dataset:
      name: svhn
      split: test
  - name: gtsrb
    dataset:
      name: tanganke/gtsrb
      split: test
  - name: mnist
    dataset:
      name: mnist
      split: test
  - name: dtd
    dataset:
      name: tanganke/dtd
      split: test

clip_model: openai/clip-vit-base-patch32
12 changes: 7 additions & 5 deletions docs/cli/fusion_bench.md
@@ -10,11 +10,11 @@ fusion_bench [--config-path CONFIG_PATH] [--config-name CONFIG_NAME] \

`fusion_bench` has the following options:

-| **Option** | **Default** | **Description** |
-| ------------ | ------------------------- | -------------------------------------------------- |
-| method | `simple_average` | The fusion method to be used. |
-| modelpool | `huggingface_clip_vision` | The pool of models to be fused. |
-| print_config | `true` | Whether to print the configuration to the console. |
+| **Option** | **Default** | **Description** |
+| ------------ | ------------------------- | ---------------------------------------------------------------------------------- |
+| method | `simple_average` | The fusion method to be used. |
+| modelpool | `huggingface_clip_vision` | The pool of models to be fused. See [modelpool](/modelpool/) for more information. |
+| print_config | `true` | Whether to print the configuration to the console. |

## Basic Examples

@@ -23,3 +23,5 @@ merge multiple CLIP models using simple averaging:
```bash
fusion_bench method=simple_average modelpool=clip-vit-base-patch32_TA8.yaml
```

::: fusion_bench.scripts.cli.run_model_fusion
13 changes: 13 additions & 0 deletions docs/modelpool/index.md
@@ -23,4 +23,17 @@ For more complex model fusion techniques that requires data, the modelpool confi
We provide a list of modelpools that contain models trained on different datasets and with different architectures.
Each modelpool is described in a separate document.

## Basic Usage

Models are not loaded when a modelpool is initialized. To load one, call the `load_model` method:

```python
model = modelpool.load_model('model_name')
```
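
To illustrate the lazy-loading behaviour described above, here is a minimal, standard-library-only sketch. `LazyPool`, its `load_calls` list, and the string "weights" are hypothetical stand-ins invented for this example; they are not part of `fusion_bench`, whose real pools build models from a configuration.

```python
from typing import Callable, Dict, List


class LazyPool:
    """Hypothetical pool that stores loaders and builds a model only on request."""

    def __init__(self, loaders: Dict[str, Callable[[], object]]):
        self._loaders = loaders
        self.load_calls: List[str] = []  # records which models were actually built

    def load_model(self, model_name: str) -> object:
        if model_name not in self._loaders:
            raise KeyError(f"Unknown model: {model_name}")
        self.load_calls.append(model_name)
        return self._loaders[model_name]()


pool = LazyPool({"sun397": lambda: "sun397-weights", "dtd": lambda: "dtd-weights"})
assert pool.load_calls == []  # nothing is built at construction time
model = pool.load_model("dtd")
print(model, pool.load_calls)
```

Only the requested model is constructed, which keeps memory use low when a pool lists many checkpoints.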


## References

::: fusion_bench.modelpool.ModelPool

[^1]: AdaMerging: Adaptive Model Merging for Multi-Task Learning. http://arxiv.org/abs/2310.02575
4 changes: 4 additions & 0 deletions docs/taskpool/index.md
@@ -6,7 +6,11 @@ Each task in the taskpool is defined by a dataset and a metric.
A taskpool is specified by a `yaml` configuration file, which often contains the following fields:

- `type`: The type of the taskpool.
- `dataset_type`: The type of the dataset used in the tasks.
- `tasks`: A list of tasks; each task is a dict with the following fields:
    - `name`: The name of the task.
    - `dataset`: The dataset used for the task.
    - `metric`: The metric used to evaluate the performance of the model on the task.
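
The field list above can be pictured as a nested dictionary. The sketch below is an assumption-laden illustration: `missing_fields` is a hypothetical helper, and treating `type`, `tasks`, `name`, and `dataset` as required (with `metric` optional) is a guess made for the example, not a rule stated by the documentation.

```python
from typing import Any, Dict, List

# Hypothetical in-memory equivalent of a two-task taskpool YAML file.
taskpool_config: Dict[str, Any] = {
    "type": "clip_vit_classification",
    "dataset_type": "huggingface_image_classification",
    "tasks": [
        {"name": "sun397", "dataset": {"name": "tanganke/sun397", "split": "test"}},
        {"name": "dtd", "dataset": {"name": "tanganke/dtd", "split": "test"}},
    ],
}


def missing_fields(config: Dict[str, Any]) -> List[str]:
    """Hypothetical check: list assumed-required fields that are absent."""
    problems = [key for key in ("type", "tasks") if key not in config]
    for index, task in enumerate(config.get("tasks", [])):
        for key in ("name", "dataset"):
            if key not in task:
                problems.append(f"tasks[{index}].{key}")
    return problems


complete = missing_fields(taskpool_config)
incomplete = missing_fields({"type": "dummy", "tasks": [{"name": "mnist"}]})
print(complete, incomplete)
```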


::: fusion_bench.taskpool.TaskPool
23 changes: 23 additions & 0 deletions fusion_bench/dataset/__init__.py
@@ -0,0 +1,23 @@
from datasets import load_dataset
from omegaconf import DictConfig, open_dict


def load_dataset_from_config(dataset_config: DictConfig):
    """
    Load the dataset from the configuration.
    """
    assert hasattr(dataset_config, "type"), "Dataset type not specified"
    if dataset_config.type == "huggingface_image_classification":
        if not hasattr(dataset_config, "path"):
            with open_dict(dataset_config):
                dataset_config.path = dataset_config.name

        dataset = load_dataset(
            dataset_config.path,
            **(dataset_config.kwargs if hasattr(dataset_config, "kwargs") else {}),
        )
        if hasattr(dataset_config, "split"):
            dataset = dataset[dataset_config.split]
        return dataset
    else:
        raise ValueError(f"Unknown dataset type: {dataset_config.type}")
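
The resolution rules in `load_dataset_from_config` (fall back from `path` to `name`, forward optional `kwargs`, select an optional `split`) can be traced without downloading anything. The sketch below is a hypothetical, standard-library-only mirror of that logic; `resolve_dataset_request` and the plain-dict config are assumptions for illustration and do not exist in `fusion_bench`.

```python
from typing import Any, Dict, Optional, Tuple


def resolve_dataset_request(
    dataset_config: Dict[str, Any],
) -> Tuple[str, Dict[str, Any], Optional[str]]:
    """Hypothetical helper: return (path, kwargs, split) as the loader would use them."""
    if "type" not in dataset_config:
        raise ValueError("Dataset type not specified")
    if dataset_config["type"] != "huggingface_image_classification":
        raise ValueError(f"Unknown dataset type: {dataset_config['type']}")

    # `path` is optional; the dataset `name` is used when it is missing.
    path = dataset_config.get("path", dataset_config["name"])
    # Extra keyword arguments are forwarded to the loader only if present.
    kwargs = dict(dataset_config.get("kwargs", {}))
    # `split` is optional; None means "keep every split".
    split = dataset_config.get("split")
    return path, kwargs, split


request = resolve_dataset_request(
    {
        "type": "huggingface_image_classification",
        "name": "tanganke/sun397",
        "split": "test",
    }
)
print(request)
```

Note that the task entries in `clip-vit-classification_TA8.yaml` carry no `type` key of their own, so the task pool presumably copies `dataset_type` into each dataset config before calling the loader; that step is not shown in this commit view.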
2 changes: 1 addition & 1 deletion fusion_bench/method/dummy.py
@@ -21,7 +21,7 @@ def fuse(self, modelpool: ModelPool):
        This method returns the pretrained model from the model pool.
        If the pretrained model is not available, it returns the first model from the model pool.
-        Raiases:
+        Raises:
            AssertionError: If the model is not found in the model pool.
        """
        if "_pretrained_" in modelpool._model_names:
7 changes: 7 additions & 0 deletions fusion_bench/modelpool/base_pool.py
@@ -5,6 +5,10 @@


class ModelPool(ABC):
    """
    This is the base class for all modelpools.
    """

    models = {}

    def __init__(self, modelpool_config: DictConfig):
@@ -20,6 +24,7 @@ def __init__(self, modelpool_config: DictConfig):
    def model_names(self) -> List[str]:
        """
        This property returns a list of model names from the configuration, excluding any names that start or end with an underscore.
        To obtain all model names, including those starting or ending with an underscore, use the `_model_names` attribute.
        Returns:
            list: A list of model names.
@@ -60,6 +65,8 @@ def get_model_config(self, model_name: str):
    @abstractmethod
    def load_model(self, model_config: Union[str, DictConfig]):
        """
        The models are loaded lazily, so this method should be implemented to load the model from the model pool.
        Load the model from the model pool.
        Args:
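
The `model_names` docstring above describes a simple filter: names that start or end with an underscore (such as `_pretrained_`) are hidden from the public list. A minimal sketch of that rule follows; `visible_model_names` is a hypothetical free function written for illustration, not the property's actual implementation.

```python
from typing import Iterable, List


def visible_model_names(all_names: Iterable[str]) -> List[str]:
    """Hypothetical helper: drop names that start or end with an underscore."""
    return [
        name
        for name in all_names
        if not (name.startswith("_") or name.endswith("_"))
    ]


names = ["_pretrained_", "sun397", "stanford_cars", "dtd"]
visible = visible_model_names(names)
print(visible)
```

An underscore in the middle of a name, as in `stanford_cars`, does not hide it; only leading or trailing underscores do.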
26 changes: 25 additions & 1 deletion fusion_bench/modelpool/huggingface_clip_vision.py
@@ -1,12 +1,24 @@
import logging
from functools import cached_property

from omegaconf import DictConfig
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel

from fusion_bench.utils import timeit_context

from .base_pool import ModelPool

log = logging.getLogger(__name__)


class HuggingFaceClipVisionPool(ModelPool):
    """
    A model pool for managing Hugging Face's CLIP Vision models.
    This class extends the base `ModelPool` class and overrides its methods to handle
    the specifics of the CLIP Vision models provided by the Hugging Face Transformers library.
    """

    def __init__(self, modelpool_config: DictConfig):
        super().__init__(modelpool_config)

@@ -18,8 +30,20 @@ def clip_processor(self):
        return self._clip_processor

    def load_model(self, model_config: str | DictConfig) -> CLIPVisionModel:
        """
        Load a CLIP Vision model from the given configuration.
        Args:
            model_config (str | DictConfig): The configuration for the model to load.
        Returns:
            CLIPVisionModel: The loaded CLIP Vision model.
        """
        if isinstance(model_config, str):
            model_config = self.get_model_config(model_config)

-        vision_model = CLIPVisionModel.from_pretrained(model_config.path)
+        with timeit_context(
+            f"Loading CLIP vision model: '{model_config.name}' from '{model_config.path}'."
+        ):
+            vision_model = CLIPVisionModel.from_pretrained(model_config.path)
        return vision_model
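
`load_model` now wraps the checkpoint load in `timeit_context`, which is imported from `fusion_bench.utils` but whose implementation is not shown in this view. As a rough idea of how such a timing context manager can be built, here is a hypothetical sketch; the name `timed`, its `sink` parameter, and the message format are assumptions and may differ from the real utility.

```python
import time
from contextlib import contextmanager
from typing import Iterator, List


@contextmanager
def timed(message: str, sink: List[str]) -> Iterator[None]:
    """Hypothetical timer: append the message and elapsed seconds to `sink` on exit."""
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        sink.append(f"{message} took {elapsed:.3f}s")


records: List[str] = []
with timed("Loading model 'demo'", records):
    total = sum(range(100_000))  # stand-in for an expensive load

print(records[0])
```

Using `try`/`finally` means the elapsed time is still recorded if the wrapped block raises.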
90 changes: 90 additions & 0 deletions fusion_bench/models/hf_clip.py
@@ -0,0 +1,90 @@
from typing import Callable, Iterable, List

import torch
from torch import Tensor, nn
from torch.types import _device
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, CLIPTextModel


default_templates = [
    lambda c: f"a photo of a {c}",
]


class HFCLIPClassifier(nn.Module):
    def __init__(
        self,
        clip_model: CLIPModel,
        processor: CLIPProcessor,
    ):
        super().__init__()
        # we only fine-tune the vision model
        clip_model.visual_projection.requires_grad_(False)
        clip_model.text_model.requires_grad_(False)
        clip_model.text_projection.requires_grad_(False)
        clip_model.logit_scale.requires_grad_(False)

        self.clip_model = clip_model
        self.processor = processor
        self.register_buffer(
            "zeroshot_weights",
            None,
            persistent=False,
        )

    @property
    def text_model(self):
        return self.clip_model.text_model

    @property
    def vision_model(self):
        return self.clip_model.vision_model

    def set_classification_task(
        self,
        classnames: List[str],
        templates: List[Callable[[str], str]] = default_templates,
    ):
        processor = self.processor

        self.classnames = classnames
        self.templates = templates

        with torch.no_grad():
            zeroshot_weights = []
            for classname in classnames:
                text = [template(classname) for template in templates]
                inputs = processor(text=text, return_tensors="pt", padding=True)

                embeddings = self.text_model(**inputs)[1]
                embeddings = self.clip_model.text_projection(embeddings)

                # normalize embeddings
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                embeddings = embeddings.mean(dim=0)
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                zeroshot_weights.append(embeddings)

            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

        self.zeroshot_weights = zeroshot_weights

    def forward(self, images):
        if self.zeroshot_weights is None:
            raise ValueError("Must set classification task before forward pass")
        text_embeds = self.zeroshot_weights

        image_embeds = self.vision_model(images)[1]
        image_embeds = self.clip_model.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
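
The arithmetic in `HFCLIPClassifier` (normalize each prompt embedding, average over templates, normalize again, then take scaled cosine similarity with the image embedding) can be followed on toy numbers. The sketch below deliberately uses plain Python lists instead of `torch` so it runs without dependencies; the 2-D embeddings are made up, the function names are hypothetical, and `scale` stands for the already-exponentiated `logit_scale`.

```python
import math
from typing import List

Vector = List[float]


def normalize(vector: Vector) -> Vector:
    """Scale a vector to unit L2 norm."""
    norm = math.sqrt(sum(x * x for x in vector))
    return [x / norm for x in vector]


def class_weight(template_embeddings: List[Vector]) -> Vector:
    """Normalize each template embedding, average them, and normalize again."""
    unit = [normalize(e) for e in template_embeddings]
    mean = [sum(column) / len(unit) for column in zip(*unit)]
    return normalize(mean)


def logits(image_embedding: Vector, weights: List[Vector], scale: float) -> List[float]:
    """Scaled cosine similarity between one image and every class weight."""
    image = normalize(image_embedding)
    return [scale * sum(a * b for a, b in zip(image, w)) for w in weights]


# Toy 2-D embeddings (made up): two classes, two prompt templates each.
weights = [
    class_weight([[1.0, 0.0], [2.0, 0.0]]),  # class 0 points along x
    class_weight([[0.0, 1.0], [0.0, 3.0]]),  # class 1 points along y
]
scores = logits([3.0, 4.0], weights, scale=100.0)
predicted = scores.index(max(scores))
print(scores, predicted)
```

Because every vector is unit length, each score is just the cosine of the angle between image and class direction multiplied by the scale, so the image `[3, 4]` is assigned to class 1.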
27 changes: 21 additions & 6 deletions fusion_bench/scripts/cli.py
@@ -17,6 +17,26 @@
from ..taskpool import load_taskpool


def run_model_fusion(cfg: DictConfig):
    """
    Run the model fusion process based on the provided configuration.
    1. This function loads a model pool and a model fusion algorithm based on the configuration.
    2. It then uses the algorithm to fuse the models in the model pool into a single model.
    3. If a task pool is specified in the configuration, it loads the task pool and uses it to evaluate the merged model.
    """
    modelpool = load_modelpool(cfg.modelpool)

    algorithm = load_algorithm(cfg.method)
    merged_model = algorithm.fuse(modelpool)

    if hasattr(cfg, "taskpool") and cfg.taskpool is not None:
        taskpool = load_taskpool(cfg.taskpool)
        taskpool.evaluate(merged_model)
    else:
        print("No task pool specified. Skipping evaluation.")


@hydra.main(
    config_path=os.path.join(
        importlib.import_module("fusion_bench").__path__[0], "../config"
@@ -35,12 +55,7 @@ def main(cfg: DictConfig) -> None:
        )
    )

-    modelpool = load_modelpool(cfg.modelpool)
-    algorithm = load_algorithm(cfg.method)
-    merged_model = algorithm.fuse(modelpool)
-
-    taskpool = load_taskpool(cfg.taskpool)
-    taskpool.evaluate(merged_model)
+    run_model_fusion(cfg)


if __name__ == "__main__":
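
The three-step flow of `run_model_fusion` (load pool and algorithm, fuse, optionally evaluate) can be exercised with stand-ins. Everything in the sketch below is hypothetical: `StubPool`, `AverageAlgorithm`, `StubTaskPool`, and `run` replace the real loaders with plain numbers, and `SimpleNamespace` stands in for the Hydra `DictConfig`.

```python
from types import SimpleNamespace
from typing import List, Optional


class StubPool:
    """Hypothetical model pool holding plain numbers instead of networks."""

    def __init__(self, values: List[float]):
        self.values = values


class AverageAlgorithm:
    """Hypothetical fusion method: the merged 'model' is the mean of the pool."""

    def fuse(self, pool: StubPool) -> float:
        return sum(pool.values) / len(pool.values)


class StubTaskPool:
    """Hypothetical task pool: 'evaluation' reports distance to a target value."""

    def __init__(self, target: float):
        self.target = target

    def evaluate(self, model: float) -> float:
        return abs(model - self.target)


def run(cfg: SimpleNamespace) -> Optional[float]:
    pool = StubPool(cfg.modelpool)
    merged = AverageAlgorithm().fuse(pool)
    # Evaluation is optional: skip it when no task pool is configured.
    if getattr(cfg, "taskpool", None) is not None:
        return StubTaskPool(cfg.taskpool).evaluate(merged)
    print("No task pool specified. Skipping evaluation.")
    return None


with_eval = run(SimpleNamespace(modelpool=[1.0, 2.0, 3.0], taskpool=2.5))
without_eval = run(SimpleNamespace(modelpool=[1.0, 2.0, 3.0]))
print(with_eval, without_eval)
```

Pulling this logic out of `main` into its own function, as the commit does, also makes it callable from tests or notebooks without going through the Hydra entry point.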
7 changes: 6 additions & 1 deletion fusion_bench/taskpool/__init__.py
@@ -1,11 +1,16 @@
from omegaconf import DictConfig

from .dummy import DummyTaskPool
from .base_pool import TaskPool
from .clip_image_classification import CLIPImageClassificationTaskPool


def load_taskpool(taskpool_config: DictConfig):
    if hasattr(taskpool_config, "type"):
-        if taskpool_config.type == 'dummy':
+        if taskpool_config.type == "dummy":
            return DummyTaskPool(taskpool_config)
        if taskpool_config.type == "clip_vit_classification":
            return CLIPImageClassificationTaskPool(taskpool_config)
        else:
            raise ValueError(f"Unknown task pool type: {taskpool_config.type}")
    else:
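
`load_taskpool` dispatches on the `type` string with a chain of `if` statements. As the number of task pools grows, one possible alternative is a lookup table. The sketch below is a design illustration only: `REGISTRY`, `load_pool`, and the two stand-in classes are hypothetical, and raising an error when `type` is missing is an assumption, since that branch is cut off in the hunk above.

```python
from typing import Callable, Dict


class DummyPool:
    """Hypothetical stand-in for a task pool class."""

    def __init__(self, config: dict):
        self.config = config


class ClassificationPool(DummyPool):
    """Hypothetical stand-in for a second task pool class."""


# Map each `type` string to the class that handles it.
REGISTRY: Dict[str, Callable[[dict], DummyPool]] = {
    "dummy": DummyPool,
    "clip_vit_classification": ClassificationPool,
}


def load_pool(config: dict) -> DummyPool:
    if "type" not in config:
        # Assumption: a missing type is treated as an error in this sketch.
        raise ValueError("Task pool type not specified")
    try:
        factory = REGISTRY[config["type"]]
    except KeyError:
        raise ValueError(f"Unknown task pool type: {config['type']}") from None
    return factory(config)


pool = load_pool({"type": "clip_vit_classification"})
print(type(pool).__name__)
```

With a table, adding a new pool type is a one-line registration rather than another branch.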