Commit 1e75320

Author: Anas Awadalla (committed)
remove deepspeed, some fixes, and llava
1 parent eb6b8aa

22 files changed, +203 -392 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -102,7 +102,7 @@ To instantiate an OpenFlamingo model with one of our released weights, initializ
 from huggingface_hub import hf_hub_download
 import torch
 
-checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-3B-vitl-mpt1b", "checkpoint.pt")
+checkpoint_path = hf_hub_download("openflamingo/OpenFlamingo-4B-vitl-rpj3b", "checkpoint.pt")
 model.load_state_dict(torch.load(checkpoint_path), strict=False)
 ```
 
open_flamingo/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .src.flamingo import Flamingo
 from .src.kosmos import Kosmos
 from .src.blip import BLIP
+from .src.llava import Llava
 from .src.factory import create_model_and_transforms, SUPPORTED_MODEL_FAMILIES

open_flamingo/eval/README.md

Lines changed: 0 additions & 2 deletions
@@ -30,8 +30,6 @@ To help standardize VLM evaluations, we have implemented EvalModel wrappers for
 ## Distributed evaluation
 Our codebase uses DistributedDataParallel to parallelize evaluation by default, so please make sure to set the `MASTER_ADDR` and `MASTER_PORT` environment variables or use `torchrun` (see sample scripts section below).
 
-We have also implemented distributed evaluation using Deepspeed, which additionally shards model parameters across GPUs for memory savings. To use Deepspeed instead of DDP, use the `--deepspeed` flag.
-
 We also support evaluating at a lower precision using the `--precision` flag. We find minimal difference between evaluating at full precision vs. amp_bf16.
 
 ## Sample scripts
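For reference, a minimal sketch of the process-group setup that the DDP evaluation described above depends on. The `MASTER_ADDR`/`MASTER_PORT` variable names come from the README text; the fallback values and everything else in the snippet are illustrative, not part of the repo:

import os
import torch
import torch.distributed as dist

# MASTER_ADDR / MASTER_PORT must be set before init (torchrun sets them for you).
os.environ.setdefault("MASTER_ADDR", "localhost")  # illustrative fallback
os.environ.setdefault("MASTER_PORT", "29500")      # illustrative fallback

# torchrun also exports RANK, LOCAL_RANK, and WORLD_SIZE for each process.
dist.init_process_group(backend="nccl" if torch.cuda.is_available() else "gloo")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)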

open_flamingo/eval/eval_models/eval_model.py

Lines changed: 4 additions & 29 deletions
@@ -31,7 +31,7 @@ def get_eval_model(name, *args, **kwargs):
 class BaseEvalModel(abc.ABC):
     """Base class encapsulating functionality needed to evaluate a model."""
 
-    def __init__(self, model_args: List[str], init_on_device=False):
+    def __init__(self, model_args: List[str]):
         """Initialize model.
 
         Args:
@@ -59,17 +59,6 @@ def __init__(self, model_args: List[str], init_on_device=False):
         self.autocast = get_autocast(self.precision)
         self.cast_dtype = get_cast_dtype(self.precision)
 
-        # initialization context
-        if init_on_device:
-            # for deepspeed, must init on device, or likely CPU OOM
-            import deepspeed
-
-            self.init_ctx = deepspeed.OnDevice(
-                dtype=self.cast_dtype, device=self.device
-            )
-        else:
-            self.init_ctx = suppress()
-
     @property
     def required_args(self):
         """Return list of required arguments to initialize model."""
@@ -83,23 +72,9 @@ def _check_init(self):
         assert hasattr(self, "tokenizer"), "Tokenizer has not been initialized"
         self.tokenizer.padding_side = "left"
 
-    def init_distributed(self, world_size=None, use_deepspeed=False):
-        """Wrap model as DDP or deepspeed."""
-        if use_deepspeed:
-            assert "amp" not in self.precision, "Deepspeed does not support amp"
-            import deepspeed
-
-            self.ds_engine = deepspeed.init_inference(
-                self.model,
-                mp_size=world_size,
-                dtype=self.cast_dtype,
-                checkpoint=None,
-                replace_with_kernel_inject=True,
-            )
-            self.model = self.ds_engine.module
-            self.autocast = get_autocast(None)
-        else:
-            self.model = DDP(self.model, device_ids=[self.device])
+    def init_distributed(self):
+        """Wrap model as DDP."""
+        self.model = DDP(self.model, device_ids=[self.device])
 
     def __call__(
         self,
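A minimal sketch of the DDP wrapping that the simplified `init_distributed` above now performs; the toy module and the device handling are assumptions, and a process group must already be initialized (e.g. as in the sketch after the eval README section):

import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# Toy stand-in for eval_model.model; any nn.Module is wrapped the same way.
model = nn.Linear(8, 8)

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

# Mirrors the new init_distributed(): parameters are replicated on every rank
# (no parameter sharding, unlike the removed deepspeed path).
model = DDP(model, device_ids=[device.index] if device.type == "cuda" else None)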

open_flamingo/eval/evaluate.py

Lines changed: 1 addition & 9 deletions
@@ -394,12 +394,6 @@
     action="store_true",
     help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
 )
-parser.add_argument(
-    "--deepspeed",
-    default=False,
-    action="store_true",
-    help="Whether to use deepspeed for distributed inference.",
-)
 
 
 def main():
@@ -414,11 +408,9 @@ def main():
     model_args["device"] = device_id
 
     # initialize model
-    eval_model = get_eval_model(args.model, model_args, init_on_device=args.deepspeed)
+    eval_model = get_eval_model(args.model, model_args, init_on_device=False)
     eval_model.init_distributed(
         local_rank=args.local_rank,
-        world_size=args.world_size,
-        use_deepspeed=args.deepspeed,
     )
 
     # Validate args

open_flamingo/scripts/run_eval_deepspeed.sh

Lines changed: 0 additions & 77 deletions
This file was deleted.

open_flamingo/scripts/run_train_deepspeed.sh

Lines changed: 0 additions & 41 deletions
This file was deleted.

open_flamingo/src/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+from .helpers import VLMOutputWithPast

open_flamingo/src/cross_attn_lm.py

Lines changed: 17 additions & 4 deletions
@@ -88,15 +88,16 @@ def init_cross_attention_layers(
         """
         Add gated cross attn layers to the decoder.
         """
-        self.old_decoder_blocks = self._get_decoder_layers()
+        old_decoder_blocks = self._get_decoder_layers()
+        self.decoder_block_class = old_decoder_blocks[0].__class__
         self.gated_cross_attn_layers = nn.ModuleList(
             [
                 GatedCrossAttentionBlock(
                     dim=lang_hidden_size, dim_visual=vis_hidden_size
                 )
                 if (layer_idx + 1) % cross_attn_every_n_layers == 0
                 else None
-                for layer_idx, _ in enumerate(self._get_decoder_layers())
+                for layer_idx, _ in enumerate(old_decoder_blocks)
             ]
         )
         self._set_decoder_layers(
@@ -106,7 +107,7 @@ def init_cross_attention_layers(
                         gated_cross_attn_layer, decoder_layer, gradient_checkpointing
                     )
                     for gated_cross_attn_layer, decoder_layer in zip(
-                        self.gated_cross_attn_layers, self.old_decoder_blocks
+                        self.gated_cross_attn_layers, old_decoder_blocks
                     )
                 ]
             )
@@ -119,11 +120,14 @@ def _condition_media_before_forward(
         vision_tokens: torch.Tensor = None,
         past_media_locations: torch.Tensor = None,
        past_vision_tokens: torch.Tensor = None,
+        num_beams: int = 1,
     ):
         """Each xattn layer needs to save the vision tokens and the locations of the media tokens in the language sequence"""
         assert (
             self.initialized_cross_attention
         ), "Cross attention layers have not been initialized. "
+
+        # concat with past
         if past_media_locations is not None and past_vision_tokens is not None:
             if vision_tokens is not None:
                 updated_vision_tokens = torch.cat(
@@ -146,6 +150,15 @@ def _condition_media_before_forward(
             updated_vision_tokens = vision_tokens
             updated_media_locations = input_ids == self.media_token_id
 
+        # repeat the vision tokens and media locations for each beam
+        updated_vision_tokens = updated_vision_tokens.repeat_interleave(
+            num_beams, dim=0
+        )
+        updated_media_locations = updated_media_locations.repeat_interleave(
+            num_beams, dim=0
+        )
+
+        # condition
         for layer in self._get_decoder_layers():
             layer.condition_vis_x(updated_vision_tokens)
             layer.condition_media_locations(updated_media_locations)
@@ -157,4 +170,4 @@ def is_conditioned(self) -> bool:
     def clear_conditioned_layers(self):
         for layer in self._get_decoder_layers():
             layer.condition_vis_x(None)
-            layer.condition_media_locations(None)
+            layer.condition_media_locations(None)
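A tiny sketch of what the new `repeat_interleave` calls do during beam search: each batch entry's vision tokens and media-location mask are duplicated `num_beams` times along the batch dimension so they stay aligned with the beam-expanded decoder inputs. The tensor shapes here are illustrative, not the model's real ones:

import torch

num_beams = 3
# Illustrative shapes: (batch, num_media, num_tokens, dim) and (batch, seq_len).
vision_tokens = torch.randn(2, 1, 64, 1024)
media_locations = torch.zeros(2, 10, dtype=torch.bool)
media_locations[:, 0] = True

# Same call pattern as _condition_media_before_forward: the batch dim goes
# from 2 to 2 * num_beams = 6, keeping each example's copies adjacent.
vision_tokens = vision_tokens.repeat_interleave(num_beams, dim=0)
media_locations = media_locations.repeat_interleave(num_beams, dim=0)

print(vision_tokens.shape)    # torch.Size([6, 1, 64, 1024])
print(media_locations.shape)  # torch.Size([6, 10])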

open_flamingo/src/factory.py

Lines changed: 21 additions & 42 deletions
@@ -1,16 +1,21 @@
 from typing import Optional
-import torch.nn as nn
 
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import open_clip
 
 from .flamingo import Flamingo
 from .kosmos import Kosmos
 from .blip import BLIP
+from .llava import Llava
 from .utils import hasattr_recursive, setattr_recursive
 
-SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip")
-
+SUPPORTED_MODEL_FAMILIES = ("flamingo", "kosmos", "blip", "llava")
+MODEL_FAMILY_TO_CLASS = {
+    "flamingo": Flamingo,
+    "kosmos": Kosmos,
+    "blip": BLIP,
+    "llava": Llava,
+}
 
 def create_model_and_transforms(
     clip_vision_encoder_path: str,
@@ -83,41 +88,16 @@ def create_model_and_transforms(
     if decoder_layers_attr_name is None:
         decoder_layers_attr_name = _infer_decoder_layers_attr_name(lang_model)
 
-    if model_family == "flamingo":
-        model = Flamingo(
-            vision_encoder=vision_encoder,
-            lang_model=lang_model,
-            vis_feature_dim=vis_hidden_dim,
-            initial_tokenizer_len=len(text_tokenizer),
-            gradient_checkpointing=gradient_checkpointing,
-            decoder_layers_attr_name=decoder_layers_attr_name,
-            pad_token_id=text_tokenizer.pad_token_id,
-            **model_kwargs,
-        )
-
-    elif model_family == "kosmos":
-        model = Kosmos(
-            vision_encoder=vision_encoder,
-            lang_model=lang_model,
-            vis_feature_dim=vis_hidden_dim,
-            initial_tokenizer_len=len(text_tokenizer),
-            gradient_checkpointing=gradient_checkpointing,
-            pad_token_id=text_tokenizer.pad_token_id,
-            decoder_layers_attr_name=decoder_layers_attr_name,
-            **model_kwargs,
-        )
-
-    elif model_family == "blip":
-        model = BLIP(
-            vision_encoder=vision_encoder,
-            lang_model=lang_model,
-            vis_feature_dim=vis_hidden_dim,
-            initial_tokenizer_len=len(text_tokenizer),
-            gradient_checkpointing=gradient_checkpointing,
-            pad_token_id=text_tokenizer.pad_token_id,
-            decoder_layers_attr_name=decoder_layers_attr_name,
-            **model_kwargs,
-        )
+    model = MODEL_FAMILY_TO_CLASS[model_family](
+        vision_encoder=vision_encoder,
+        lang_model=lang_model,
+        vis_feature_dim=vis_hidden_dim,
+        initial_tokenizer_len=len(text_tokenizer),
+        gradient_checkpointing=gradient_checkpointing,
+        decoder_layers_attr_name=decoder_layers_attr_name,
+        pad_token_id=text_tokenizer.pad_token_id,
+        **model_kwargs,
+    )
 
     # add special tokens to the tokenizer and language models
     text_tokenizer.add_special_tokens(
@@ -130,7 +110,6 @@ def create_model_and_transforms(
             for v in model.special_tokens.values()
         }
     )
-
     # freeze appropriate parameters
     model.set_trainable()
 
@@ -139,8 +118,8 @@ def create_model_and_transforms(
     print(
         f"{model_family} model initialized with {model.num_trainable_params:,} trainable parameters"
     )
-    print(f"========== Trainable Parameters\n{model.num_trainable_params_per_module}")
-    print(f"========== Total Parameters\n{model.num_params_per_module}\n==========")
+    print(f"==========Trainable Parameters\n{model.num_trainable_params_per_module}")
+    print(f"==========Total Parameters\n{model.num_params_per_module}\n==========")
     return model, image_processor, text_tokenizer
 
 
@@ -220,4 +199,4 @@ def has_fn(model, fn_name):
         getattr(model, fn_name)()
         return True
     except:
-        return False
+        return False
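The removed if/elif chain above collapses into a table lookup; a small sketch of the resulting dispatch, with an optional guard against `SUPPORTED_MODEL_FAMILIES` (the guard and the chosen family are illustrative, only the two module-level names come from the diff):

from open_flamingo.src.factory import MODEL_FAMILY_TO_CLASS, SUPPORTED_MODEL_FAMILIES

model_family = "llava"  # illustrative choice

if model_family not in SUPPORTED_MODEL_FAMILIES:
    raise ValueError(f"Unsupported model family: {model_family}")

# One construction path for every family, instead of per-family if/elif blocks.
model_cls = MODEL_FAMILY_TO_CLASS[model_family]
print(model_cls.__name__)  # Llava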
