
Commit feba465
fix for siglip, llava, and lr decay
1 parent 1e75320

6 files changed: +40 -75 lines changed

open_flamingo/src/factory.py

Lines changed: 10 additions & 9 deletions
@@ -61,9 +61,11 @@ def create_model_and_transforms(
     )
     vision_encoder.visual.output_tokens = True
     vision_encoder = vision_encoder.visual
-    vis_hidden_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-        "width"
-    ]
+    vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
+    if "SigLIP" in clip_vision_encoder_path: # SigLIP models have a different config format
+        vis_hidden_dim = vision_encoder_config["embed_dim"]
+    else:
+        vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]

     # load tokenizer and ensure there is a pad token
     text_tokenizer = AutoTokenizer.from_pretrained(
@@ -145,6 +147,9 @@ def _infer_decoder_layers_attr_name(model):
     "gptneoxforcausallm": "gpt_neox.layers",
     "mpt": "transformer.blocks",
     "mosaicgpt": "transformer.blocks",
+    "gemma": "model.layers",
+    "phi": "model.layers",
+    "minicpm": "model.layers",
 }


@@ -194,9 +199,5 @@ def check_embedding_fns(lang_model):


 def has_fn(model, fn_name):
-    """Try to call the fn_name function on the model"""
-    try:
-        getattr(model, fn_name)()
-        return True
-    except:
-        return False
+    """Check if model has a function fn_name"""
+    return callable(getattr(model, fn_name, None))

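For reference, a minimal standalone sketch of the two config formats the new branch handles. The model names are the old and new training defaults from this commit; the printed value for ViT-L-14 is the known vision width, the SigLIP one is left unstated.

import open_clip

# Standard OpenCLIP ViTs expose the vision tower width under "vision_cfg".
clip_cfg = open_clip.get_model_config("ViT-L-14")
print(clip_cfg["vision_cfg"]["width"])  # vision hidden dim, 1024 for ViT-L-14

# SigLIP configs are timm-backed, so their vision_cfg carries no "width" entry;
# the new branch reads the top-level "embed_dim" instead.
siglip_cfg = open_clip.get_model_config("ViT-SO400M-14-SigLIP-384")
print(siglip_cfg["embed_dim"])
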
open_flamingo/src/llava.py

Lines changed: 7 additions & 1 deletion
@@ -31,11 +31,17 @@ def __init__(
             "media_token": "<image>",
         }
         lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+
+        if vision_encoder.__class__.__name__ == "TimmModel":
+            grid_size = vision_encoder.trunk.patch_embed.grid_size
+        else:
+            grid_size = vision_encoder.grid_size
+

         super().__init__(
             vision_encoder=vision_encoder,
             vision_tokenizer=LinearPatchProjection(dim_visual=vis_feature_dim,
                                                    dim_out=lang_embedding_dim,
-                                                   num_patches=vision_encoder.grid_size[0] * vision_encoder.grid_size[1]),
+                                                   num_patches=grid_size[0] * grid_size[1]),
             lang_model=lang_model,
             initial_tokenizer_len=initial_tokenizer_len,
             gradient_checkpointing=gradient_checkpointing,

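A standalone sketch of the grid_size lookup, mirroring the branch above. It assumes a recent open_clip with the SigLIP configs and timm installed; the model is built without pretrained weights, since only the patch grid is inspected.

import open_clip

# Build the commit's new default vision encoder (randomly initialized, no download).
model, _, _ = open_clip.create_model_and_transforms("ViT-SO400M-14-SigLIP-384")
vision_encoder = model.visual

# timm-backed encoders such as SigLIP wrap the ViT in a `trunk`, so the patch
# grid lives one level deeper than on native OpenCLIP ViTs.
if vision_encoder.__class__.__name__ == "TimmModel":
    grid_size = vision_encoder.trunk.patch_embed.grid_size
else:
    grid_size = vision_encoder.grid_size

num_patches = grid_size[0] * grid_size[1]  # visual tokens fed to LinearPatchProjection
print(grid_size, num_patches)
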
open_flamingo/src/vlm.py

Lines changed: 5 additions & 2 deletions
@@ -184,10 +184,13 @@ def _encode_vision_x(self, vision_x: torch.Tensor):
         """
         assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
         b, T, F = vision_x.shape[:3]
-
+
         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
         with torch.no_grad():
-            vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
+            if self.vision_encoder.__class__.__name__ == "TimmModel":
+                vision_x = self.vision_encoder.trunk.forward_features(vision_x)
+            else:
+                vision_x = self.vision_encoder(vision_x)[1] # OpenCLIP returns tuples
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
         return vision_x

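For context, the shape flow through _encode_vision_x, runnable in isolation with a stand-in encoder. The token count and hidden dim below are made up; either branch above (timm forward_features or the OpenCLIP tuple output) is expected to return the same (batch, tokens, dim) layout.

import torch
from einops import rearrange

b, T, F_, C, H, W = 2, 1, 1, 3, 224, 224  # batch, images per sample, frames per image
vision_x = torch.randn(b, T, F_, C, H, W)

# Flatten images/frames into the batch dimension before the vision tower.
vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")

def fake_vision_tower(x):
    # Stand-in for the vision encoder: (batch, num_patch_tokens, hidden_dim),
    # with made-up sizes 256 and 1152.
    return torch.randn(x.shape[0], 256, 1152)

vision_x = fake_vision_tower(vision_x)

# Unflatten back to (b, T, F, v, d) for the vision tokenizer.
vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F_)
print(vision_x.shape)  # torch.Size([2, 1, 1, 256, 1152])
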
open_flamingo/train/data.py

Lines changed: 4 additions & 3 deletions
@@ -68,14 +68,15 @@ def preprocess_laion_image(sample, image_processor):
     return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1)


-def preprocess_laion_text(sample, tokenizer, max_tokens=32):
+def preprocess_laion_text(sample, tokenizer, max_tokens=256):
     """
     Preprocess text for LAION. Applied to a batch of captions.
-    Captions are truncated to 32 tokens by default.
+    Captions are truncated to 256 tokens by default.
     """
     tokenizer.padding_side = "right"
     sample = [
-        (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        # (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        (f"<image>{s.strip()}{tokenizer.eos_token}") for s in sample
     ]
     text = tokenizer(
         sample,

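To see what the new caption template produces, a small hedged example. The tokenizer id is the training default from train.py; the tokenizer kwargs after `sample,` are not shown in this hunk, so the ones below are illustrative only.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")
tokenizer.add_special_tokens({"additional_special_tokens": ["<image>"]})
tokenizer.padding_side = "right"

captions = ["a photo of a dog on the beach  "]
# New template: caption followed directly by EOS, no <|endofchunk|> marker.
sample = [f"<image>{s.strip()}{tokenizer.eos_token}" for s in captions]
print(sample[0])  # <image>a photo of a dog on the beach</s>

text = tokenizer(
    sample,
    max_length=256,    # matches the new max_tokens default
    padding="longest",
    truncation=True,
    return_tensors="pt",
)
print(text["input_ids"].shape)
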
open_flamingo/train/train.py

Lines changed: 13 additions & 59 deletions
@@ -3,7 +3,6 @@
 import os
 import torch
 import wandb
-import deepspeed
 import functools
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -16,16 +15,13 @@
     world_info_from_env,
     get_fsdp_config,
     get_fsdp_checkpoint_config,
-    get_deepspeed_config,
 )
 from open_flamingo.train.train_utils import (
     train_one_epoch,
     random_seed,
-    load_deepspeed_checkpoint,
     find_most_recent_checkpoint,
     load_checkpoint,
     save_checkpoint,
-    save_deepspeed_checkpoint,
 )
 from open_flamingo.train.losses import (
     SUPPORTED_LOSSES,
@@ -44,8 +40,8 @@ def main():
     parser.add_argument(
         "--model_family", default="flamingo", type=str, choices=SUPPORTED_MODEL_FAMILIES
     )
-    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
-    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+    parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
+    parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
     parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
     parser.add_argument(
         "--tokenizer_path",
@@ -73,7 +69,7 @@ def main():
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
-        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default. If using deepspeed this should be a directory, not a file.",
+        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
         default=None,
     )
     parser.add_argument(
@@ -187,20 +183,6 @@ def main():
         "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
     )

-    # deepspeed args
-    parser.add_argument(
-        "--deepspeed",
-        default=False,
-        action="store_true",
-        help="Use deepspeed for distributed training.",
-    )
-    parser.add_argument(
-        "--deepspeed_stage",
-        default=2,
-        type=int,
-        help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
-    )
-
     # wandb args
     parser.add_argument("--report_to_wandb", default=False, action="store_true")
     parser.add_argument(
@@ -251,16 +233,10 @@ def main():
     if args.save_checkpoints_to_wandb and not args.report_to_wandb:
         raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")

-    if args.fsdp and args.deepspeed:
-        raise ValueError("Select either FSDP or deepspeed for distributed training.")
-
     if args.fsdp:
-        print(
-            "Warning: FSDP is experimental and not fully tested. Preference should be given to Deepspeed."
-        )
         assert (
-            "dev" in torch.__version__ and torch.__version__ > "2.0.1"
-        ), "FSDP requires torch nightly > 2.0.1"
+            torch.__version__ > "2.0.1"
+        ), "FSDP requires torch > 2.0.1"

     # Set up distributed training
     args.local_rank, args.rank, args.world_size = world_info_from_env()
@@ -269,13 +245,7 @@ def main():
     if args.offline:
         os.environ["WANDB_MODE"] = "offline"
         os.environ["TRANSFORMERS_OFFLINE"] = "1"
-    if args.deepspeed:
-        torch.cuda.set_device(args.local_rank)
-        deepspeed.init_distributed()
-        ds_config = get_deepspeed_config(args)
-        device_id = args.local_rank
-    else:
-        device_id = init_distributed_device(args)
+    device_id = init_distributed_device(args)

     random_seed(args.seed)

@@ -316,8 +286,8 @@ def main():
         args.resume_from_checkpoint = find_most_recent_checkpoint(args)

     if (
-        args.resume_from_checkpoint is not None and not args.deepspeed
-    ):  # deepspeed handles checkpoint loading
+        args.resume_from_checkpoint is not None
+    ):
         resume_from_epoch, checkpoint = load_checkpoint(args, model)
     else:
         resume_from_epoch = 0
@@ -327,7 +297,6 @@ def main():
         model.init_gradient_checkpointing()

     # Initialize FSDP / DDP, and ensure the model is on GPU
-    # Deepspeed is initialized later
     if args.fsdp:
         auto_wrap_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
@@ -336,7 +305,7 @@ def main():
         distributed_model = FSDP(
             model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
         )
-    elif not args.deepspeed:
+    else:
         model = model.to(device_id)
         distributed_model = DDP(model, device_ids=[device_id])

@@ -351,7 +320,7 @@ def main():
     )

     # load optimizer checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         osd = checkpoint["optimizer_state_dict"]
         if args.fsdp:
             FSDP.set_state_dict_type(
@@ -370,7 +339,7 @@ def main():
     ]
     total_training_steps = (
         getattr(args, f"train_num_samples_{datasets_to_train_on[0]}")
-        // getattr(args, f"batch_size_{datasets_to_train_on[0]}")
+        // (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.gradient_accumulation_steps * args.world_size)
     ) * args.num_epochs

     if args.rank == 0:
@@ -395,21 +364,9 @@ def main():
     )

     # load lr scheduler checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])

-    if args.deepspeed:
-        distributed_model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            model=model,
-            optimizer=optimizer,
-            args=args,
-            config=ds_config,
-            lr_scheduler=lr_scheduler,
-            dist_init_required=True,
-        )
-        if args.resume_from_checkpoint is not None:
-            resume_from_epoch = load_deepspeed_checkpoint(args, distributed_model)
-
     # Initialize the loss fn
     loss_fn = get_loss_fn(args.loss)

@@ -435,10 +392,7 @@ def main():
                 wandb=wandb,
             )

-            if args.deepspeed:
-                save_deepspeed_checkpoint(distributed_model, epoch, args)
-            else:
-                save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
+            save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)


 if __name__ == "__main__":

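The "lr decay" part of the commit is the total_training_steps change: the scheduler horizon must count optimizer steps, not per-GPU batches. A worked example with made-up run sizes (the attribute names follow train.py; "laion" is just an illustrative dataset name):

# Made-up run configuration, for the arithmetic only.
train_num_samples_laion = 10_000_000
batch_size_laion = 32                 # per-GPU batch size
gradient_accumulation_steps = 2
world_size = 8                        # number of ranks
num_epochs = 1

# Old: divided by the per-GPU batch size only, so the schedule decayed roughly
# grad_accum * world_size times too slowly over the actual run.
old_total_steps = (train_num_samples_laion // batch_size_laion) * num_epochs

# New: one optimizer step consumes batch_size * grad_accum * world_size samples.
samples_per_optimizer_step = (
    batch_size_laion * gradient_accumulation_steps * world_size
)
total_training_steps = (
    train_num_samples_laion // samples_per_optimizer_step
) * num_epochs

print(old_total_steps, total_training_steps)  # 312500 19531
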
requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 einops
 einops-exts
-transformers==4.28.1
+transformers
 torch>=2.0.1
 pillow
 open_clip_torch>=2.16.0
