Update fusion_bench with new features and improvements
- Updated README.md with new instructions
- Added new taskpool configuration for clip-vit-classification
- Updated documentation for fusion_bench, modelpool, and taskpool
- Added new dataset initialization file
- Modified dummy method in fusion_bench
- Updated base_pool in modelpool and taskpool
- Modified huggingface_clip_vision in modelpool
- Added new model file for hf_clip
- Updated CLI script for fusion_bench
- Added new task initialization file and base_task file
- Added new image_classification task
- Updated utils initialization file
- Added new devices utility file
- Updated parameters utility file
- Added new timer utility file
Showing 22 changed files with 536 additions and 24 deletions.
@@ -0,0 +1,37 @@
type: clip_vit_classification
dataset_type: huggingface_image_classification
tasks:
  - name: sun397
    dataset:
      name: tanganke/sun397
      split: test
  - name: stanford_cars
    dataset:
      name: tanganke/stanford-cars
      split: test
  - name: resisc45
    dataset:
      name: tanganke/resisc45
      split: test
  - name: eurosat
    dataset:
      name: tanganke/eurosat
      split: test
  - name: svhn
    dataset:
      name: svhn
      split: test
  - name: gtsrb
    dataset:
      name: tanganke/gtsrb
      split: test
  - name: mnist
    dataset:
      name: mnist
      split: test
  - name: dtd
    dataset:
      name: tanganke/dtd
      split: test

clip_model: openai/clip-vit-base-patch32
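As a sketch of how a taskpool config like the one above could be consumed: the example below loads it with OmegaConf and merges the top-level dataset_type into each task's dataset entry, which is the shape the dataset loader added in this commit expects. The file path is an assumption for illustration; the actual location in the repository may differ.

from omegaconf import OmegaConf

# Hypothetical path to the config above; adjust to the repository layout.
config = OmegaConf.load("config/taskpool/clip-vit-classification.yaml")

for task in config.tasks:
    # Each task nests its own dataset entry; the shared `dataset_type`
    # becomes the `type` field that the dataset loader asserts on.
    dataset_config = OmegaConf.merge({"type": config.dataset_type}, task.dataset)
    print(f"{task.name}: {dataset_config.name} ({dataset_config.split} split)")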
@@ -0,0 +1,23 @@
from datasets import load_dataset
from omegaconf import DictConfig, open_dict


def load_dataset_from_config(dataset_config: DictConfig):
    """
    Load the dataset from the configuration.
    """
    assert hasattr(dataset_config, "type"), "Dataset type not specified"
    if dataset_config.type == "huggingface_image_classification":
        if not hasattr(dataset_config, "path"):
            with open_dict(dataset_config):
                dataset_config.path = dataset_config.name

        dataset = load_dataset(
            dataset_config.path,
            **(dataset_config.kwargs if hasattr(dataset_config, "kwargs") else {}),
        )
        if hasattr(dataset_config, "split"):
            dataset = dataset[dataset_config.split]
        return dataset
    else:
        raise ValueError(f"Unknown dataset type: {dataset_config.type}")
@@ -0,0 +1,90 @@
from typing import Callable, Iterable, List

import torch
from torch import Tensor, nn
from torch.types import _device
from transformers import CLIPModel, CLIPProcessor, CLIPVisionModel, CLIPTextModel


default_templates = [
    lambda c: f"a photo of a {c}",
]


class HFCLIPClassifier(nn.Module):
    def __init__(
        self,
        clip_model: CLIPModel,
        processor: CLIPProcessor,
    ):
        super().__init__()
        # we only fine-tune the vision model
        clip_model.visual_projection.requires_grad_(False)
        clip_model.text_model.requires_grad_(False)
        clip_model.text_projection.requires_grad_(False)
        clip_model.logit_scale.requires_grad_(False)

        self.clip_model = clip_model
        self.processor = processor
        self.register_buffer(
            "zeroshot_weights",
            None,
            persistent=False,
        )

    @property
    def text_model(self):
        return self.clip_model.text_model

    @property
    def vision_model(self):
        return self.clip_model.vision_model

    def set_classification_task(
        self,
        classnames: List[str],
        templates: List[Callable[[str], str]] = default_templates,
    ):
        processor = self.processor

        self.classnames = classnames
        self.templates = templates

        with torch.no_grad():
            zeroshot_weights = []
            for classname in classnames:
                text = [template(classname) for template in templates]
                inputs = processor(text=text, return_tensors="pt", padding=True)

                embeddings = self.text_model(**inputs)[1]
                embeddings = self.clip_model.text_projection(embeddings)

                # normalize embeddings
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                embeddings = embeddings.mean(dim=0)
                embeddings = embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)

                zeroshot_weights.append(embeddings)

            zeroshot_weights = torch.stack(zeroshot_weights, dim=0)

        self.zeroshot_weights = zeroshot_weights

    def forward(self, images):
        if self.zeroshot_weights is None:
            raise ValueError("Must set classification task before forward pass")
        text_embeds = self.zeroshot_weights

        image_embeds = self.vision_model(images)[1]
        image_embeds = self.clip_model.visual_projection(image_embeds)

        # normalize embeddings
        image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)

        # cosine similarity
        logit_scale = self.clip_model.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.t()

        return logits_per_image
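A short usage sketch for the classifier above, showing the call sequence. The checkpoint name is the one from the taskpool config in this commit; the class names are placeholders and the image batch is random, so this only illustrates the expected shapes, not a real evaluation:

import torch
from transformers import CLIPModel, CLIPProcessor

model_name = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

classifier = HFCLIPClassifier(clip_model, processor)
# Placeholder class names; a real run would use the task's label set.
classifier.set_classification_task(["cat", "dog", "car"])

# `forward` takes preprocessed pixel values (3x224x224 for ViT-B/32).
pixel_values = torch.randn(2, 3, 224, 224)
with torch.no_grad():
    logits = classifier(pixel_values)  # shape (2, 3): one logit per class
print(logits.argmax(dim=-1))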