Skip to content

Commit

Permalink
modified: fusion_bench/models/linearized/vision_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tanganke committed Nov 15, 2024
1 parent 3705fae commit fed3200
Showing 1 changed file with 36 additions and 7 deletions.
43 changes: 36 additions & 7 deletions fusion_bench/models/linearized/vision_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import Tuple
from typing import Tuple, Union

from huggingface_hub import hf_hub_download
from peft import LoraConfig, PeftModel, get_peft_model
Expand Down Expand Up @@ -44,14 +44,32 @@ def linearize_lora_model_(model):
return model


def load_fft_vision_model_hf(model_name: str) -> CLIPVisionTransformer:
return CLIPVisionModel.from_pretrained(model_name).vision_model
def load_fft_vision_model_hf(
    model_name: str, return_vison_model: bool = True
) -> Union[CLIPVisionTransformer, CLIPVisionModel]:
    """
    Load a CLIP vision model from Hugging Face.

    The checkpoint is loaded exactly once; both return modes are served from
    that single `CLIPVisionModel` instance.

    Args:
        model_name (str): The name of the CLIP vision model to load from Hugging Face.
        return_vison_model (bool, optional): If True, only the inner vision model
            (`CLIPVisionTransformer`, i.e. the `vision_model` attribute) is returned.
            If False, the full `CLIPVisionModel` is returned. Defaults to True.
            The parameter name's spelling is kept unchanged so existing keyword
            callers continue to work.

    Returns:
        Union[CLIPVisionTransformer, CLIPVisionModel]: The vision model.
    """
    model = CLIPVisionModel.from_pretrained(model_name)

    # Reuse the instance loaded above instead of calling `from_pretrained`
    # a second time just to read its `vision_model` attribute.
    if return_vison_model:
        return model.vision_model
    return model


def load_lora_vision_model_hf(
base_model_name: str,
peft_name: str,
merge_and_unload: bool = False,
return_vison_model=True,
):
"""
Load a LoRA (Low-Rank Adaptation) vision model from Hugging Face.
Expand All @@ -62,17 +80,28 @@ def load_lora_vision_model_hf(
base_model_name (str): The name of the base vision model to load from Hugging Face.
peft_name (str): The name of the LoRA adaptation to apply to the base model.
merge_and_unload (bool, optional): If True, the LoRA adaptation is merged into the base model and the LoRA layers are removed. Defaults to False.
return_vison_model (bool, optional): If False, the full CLIPVisionModel is returned. If True, only the vision model (`CLIPVisionTransformer`) is returned. Defaults to True.
Returns:
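The adapted vision model (a `PeftModel`, or the result of `merge_and_unload()` when `merge_and_unload` is True). If `return_vison_model` is False, the full `CLIPVisionModel` is returned instead, with its `vision_model` attribute replaced by that adapted model.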
"""
model = CLIPVisionModel.from_pretrained(base_model_name)

# Load the Peft model
# note that we apply lora on type `CLIPVisionTransformer` instead of `CLIPVisionModel`
model = CLIPVisionModel.from_pretrained(base_model_name).vision_model
peft_model = PeftModel.from_pretrained(model, peft_name, is_trainable=True)
vision_model = model.vision_model
peft_model = PeftModel.from_pretrained(vision_model, peft_name, is_trainable=True)
if merge_and_unload:
return peft_model.merge_and_unload()
vision_model = peft_model.merge_and_unload()
else:
vision_model = peft_model

# Return the vision model
if return_vison_model:
return vision_model
else:
return peft_model
model.vision_model = vision_model
return model


def load_l_lora_vision_model_hf(base_model_name: str, peft_name: str):
Expand Down

0 comments on commit fed3200

Please sign in to comment.