Added GQA as eval dataset #298

Closed
wants to merge 64 commits

Changes from 1 commit

Commits (64)
e19133d
deepspeed running
anas-awadalla Aug 25, 2023
870f20c
more progress
anas-awadalla Aug 26, 2023
f9162a0
added ds checkpointing
anas-awadalla Aug 26, 2023
ded3485
more progress
anas-awadalla Aug 30, 2023
3672042
mllm
Aug 30, 2023
99c350f
merge deepspeed
anas-awadalla Aug 30, 2023
2f634f0
rewrite src: add VLM, Kosmos, Flamingo
i-gao Sep 7, 2023
7261639
fix kosmos models
i-gao Sep 11, 2023
09977ba
cosmetic: num_params helper fn
i-gao Sep 11, 2023
6bb9071
revert to deepspeed branch code for train/
i-gao Sep 11, 2023
7984adb
add BLIP
i-gao Sep 12, 2023
7eab26a
minor train script fixes
i-gao Sep 12, 2023
aed0f21
fix vocab len issues
i-gao Sep 13, 2023
47c8e19
fixes
i-gao Sep 13, 2023
11ab894
big refactor of training code
i-gao Sep 15, 2023
cd4f3aa
many fixes + rewrite FSDP for torch nightly
i-gao Sep 16, 2023
74686a7
fixes
i-gao Sep 16, 2023
61f5a3d
fixes
i-gao Sep 16, 2023
ccfcb0f
run linter & fix gradient ckpting
i-gao Sep 16, 2023
303e707
no need to untie embeddings for fsdp
i-gao Sep 16, 2023
fc660e7
add in missing kwarg
i-gao Sep 16, 2023
be9a4dd
Merge branch deepspeed: eval code only
i-gao Sep 16, 2023
b0ff9a4
update eval code to match new src args
i-gao Sep 16, 2023
92bc4b7
update documentation and example scripts
i-gao Sep 16, 2023
60a82d7
fix deepspeed train script
anas-awadalla Sep 17, 2023
82d1c69
removed non default loss scale window
anas-awadalla Sep 17, 2023
4875822
init flamingo embeds new weights
anas-awadalla Sep 17, 2023
8f2f040
init flamingo embeds new weights
anas-awadalla Sep 17, 2023
beba4d2
Merge branch 'main' into mllm
anas-awadalla Sep 17, 2023
b81379f
fix mmc4 sim threshold arg
anas-awadalla Sep 17, 2023
f91c14a
add z-loss
anas-awadalla Sep 17, 2023
df96979
Merge pull request #262 from mlfoundations/add-z-loss
anas-awadalla Sep 17, 2023
bcc5a8f
Update eval README.md
i-gao Sep 17, 2023
770e653
have a default stdev for init
Sep 17, 2023
ef268be
Update run_train_deepspeed.sh
anas-awadalla Sep 17, 2023
da07e35
fix loss impl and model vocab size
Sep 17, 2023
3fcda82
Merge branch 'mllm' of https://github.com/mlfoundations/open_flamingo…
Sep 17, 2023
bcd2cf5
remove ds act checkpointing exception
Sep 18, 2023
9b1a764
fixes from PR review
i-gao Sep 19, 2023
866a780
Merge branch 'mllm' of github.com:mlfoundations/open_flamingo into mllm
i-gao Sep 19, 2023
5ad05c4
add weight/bias init to decouple linear
anas-awadalla Sep 20, 2023
939d460
Language stream changes (#264)
anas-awadalla Sep 21, 2023
ae76178
grad checkpointing + ds saving patch (we should find a cleaner solution)
anas-awadalla Sep 21, 2023
d29c8b8
Update run_train_deepspeed.sh
anas-awadalla Oct 18, 2023
b7af1d6
clearer parameter count logging
anas-awadalla Oct 18, 2023
43ac961
Fix model vocab size (now it is len of tokenizer)
anas-awadalla Oct 18, 2023
e7684b5
Update code example
anas-awadalla Oct 18, 2023
735a880
fix LR schedule
anas-awadalla Oct 23, 2023
496e656
fix var naming in load_deepspeed_checkpoint
anas-awadalla Oct 24, 2023
c5feb97
Update losses.py
anas-awadalla Nov 30, 2023
dbb1ad8
train_utils media token fix
Dec 2, 2023
fa6af69
remove unnecessary model unwrap lines
Dec 2, 2023
eb6b8aa
Merge pull request #283 from mlfoundations/media_token_fix
anas-awadalla Dec 2, 2023
1e75320
remove deepspeed, some fixes, and llava
Feb 22, 2024
feba465
fix for siglip, llava, and lr decay
anas-awadalla Feb 24, 2024
0b1c926
remove z-loss mess
anas-awadalla Feb 24, 2024
79ad152
some more fixes
anas-awadalla Mar 17, 2024
3945c87
Update data.py
anas-awadalla Mar 17, 2024
52ca075
Update losses.py
anas-awadalla Mar 17, 2024
292afa1
fix flamingo init
anas-awadalla Mar 20, 2024
a72c96b
fix resampler projection
anas-awadalla Mar 21, 2024
c7a5ae5
Update helpers.py
anas-awadalla Mar 21, 2024
a5378a8
blip.py import and output truncation fix
Mar 28, 2024
358cecc
added gqa as eval dataset
May 2, 2024
fix for siglip, llava, and lr decay
anas-awadalla committed Feb 24, 2024
commit feba465c0eacdae3d2f3948ce019ca630d1f25b6
19 changes: 10 additions & 9 deletions open_flamingo/src/factory.py
@@ -61,9 +61,11 @@ def create_model_and_transforms(
     )
     vision_encoder.visual.output_tokens = True
     vision_encoder = vision_encoder.visual
-    vis_hidden_dim = open_clip.get_model_config(clip_vision_encoder_path)["vision_cfg"][
-        "width"
-    ]
+    vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
+    if "SigLIP" in clip_vision_encoder_path:  # SigLIP models have a different config format
+        vis_hidden_dim = vision_encoder_config["embed_dim"]
+    else:
+        vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]
 
     # load tokenizer and ensure there is a pad token
     text_tokenizer = AutoTokenizer.from_pretrained(
@@ -145,6 +147,9 @@ def _infer_decoder_layers_attr_name(model):
     "gptneoxforcausallm": "gpt_neox.layers",
     "mpt": "transformer.blocks",
     "mosaicgpt": "transformer.blocks",
+    "gemma": "model.layers",
+    "phi": "model.layers",
+    "minicpm": "model.layers",
 }
 
 
@@ -194,9 +199,5 @@ def check_embedding_fns(lang_model):
 
 
 def has_fn(model, fn_name):
-    """Try to call the fn_name function on the model"""
-    try:
-        getattr(model, fn_name)()
-        return True
-    except:
-        return False
+    """Check if model has a function fn_name"""
+    return callable(getattr(model, fn_name, None))
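
Note on the SigLIP branch above: open_clip's config for the SigLIP towers does not expose a "width" under "vision_cfg", so the feature width is read from the top-level "embed_dim" instead. A minimal sketch of the same lookup as a standalone helper, mirroring the change (the key layout is assumed from the diff, not verified against every open_clip version):

import open_clip

def get_vis_hidden_dim(clip_vision_encoder_path: str) -> int:
    # Return the vision feature width for an open_clip model name.
    # Assumes SigLIP configs carry the width as a top-level "embed_dim",
    # while other ViT configs carry it under "vision_cfg"["width"].
    cfg = open_clip.get_model_config(clip_vision_encoder_path)
    if "SigLIP" in clip_vision_encoder_path:
        return cfg["embed_dim"]
    return cfg["vision_cfg"]["width"]

# e.g. compare get_vis_hidden_dim("ViT-SO400M-14-SigLIP-384") with get_vis_hidden_dim("ViT-L-14")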
8 changes: 7 additions & 1 deletion open_flamingo/src/llava.py
@@ -31,11 +31,17 @@ def __init__(
             "media_token": "<image>",
         }
         lang_embedding_dim = lang_model.get_input_embeddings().weight.shape[1]
+
+        if vision_encoder.__class__.__name__ == "TimmModel":
+            grid_size = vision_encoder.trunk.patch_embed.grid_size
+        else:
+            grid_size = vision_encoder.grid_size
+
         super().__init__(
             vision_encoder=vision_encoder,
             vision_tokenizer=LinearPatchProjection(dim_visual=vis_feature_dim,
                                                    dim_out=lang_embedding_dim,
-                                                   num_patches=vision_encoder.grid_size[0] * vision_encoder.grid_size[1]),
+                                                   num_patches=grid_size[0] * grid_size[1]),
             lang_model=lang_model,
             initial_tokenizer_len=initial_tokenizer_len,
             gradient_checkpointing=gradient_checkpointing,
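
The grid-size branch is needed because open_clip wraps timm backbones (e.g. the SigLIP towers) in a TimmModel whose patch grid lives at trunk.patch_embed.grid_size, while open_clip's own VisionTransformer exposes grid_size directly. A hedged sketch of the same resolution logic, with attribute paths taken from the diff above and assumed to hold for the open_clip/timm versions in use:

def resolve_grid_size(vision_encoder):
    # Return the (rows, cols) patch grid for either encoder type.
    if vision_encoder.__class__.__name__ == "TimmModel":
        return vision_encoder.trunk.patch_embed.grid_size
    return vision_encoder.grid_size

# num_patches for the linear patch projection:
# grid = resolve_grid_size(vision_encoder); num_patches = grid[0] * grid[1]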
7 changes: 5 additions & 2 deletions open_flamingo/src/vlm.py
@@ -184,10 +184,13 @@ def _encode_vision_x(self, vision_x: torch.Tensor):
         """
         assert vision_x.ndim == 6, "vision_x should be of shape (b, T_img, F, C, H, W)"
         b, T, F = vision_x.shape[:3]
 
         vision_x = rearrange(vision_x, "b T F c h w -> (b T F) c h w")
         with torch.no_grad():
-            vision_x = self.vision_encoder(vision_x)[1]  # OpenCLIP returns tuples
+            if self.vision_encoder.__class__.__name__ == "TimmModel":
+                vision_x = self.vision_encoder.trunk.forward_features(vision_x)
+            else:
+                vision_x = self.vision_encoder(vision_x)[1]  # OpenCLIP returns tuples
         vision_x = rearrange(vision_x, "(b T F) v d -> b T F v d", b=b, T=T, F=F)
         return vision_x
 
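Same pattern in _encode_vision_x: timm backbones return patch features from trunk.forward_features, whereas open_clip's VisionTransformer (with output_tokens enabled, as set in factory.py) returns a tuple whose second element holds the patch tokens. A sketch of that dispatch under those assumptions:

import torch

def encode_patches(vision_encoder, pixels: torch.Tensor) -> torch.Tensor:
    # Return patch tokens of shape (batch, num_patches, dim).
    # Assumes a timm-wrapped encoder exposes trunk.forward_features, and an
    # OpenCLIP encoder returns (pooled, tokens) when output_tokens is True.
    with torch.no_grad():
        if vision_encoder.__class__.__name__ == "TimmModel":
            return vision_encoder.trunk.forward_features(pixels)
        return vision_encoder(pixels)[1]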
7 changes: 4 additions & 3 deletions open_flamingo/train/data.py
@@ -68,14 +68,15 @@ def preprocess_laion_image(sample, image_processor):
     return rearrange(sample, "(b t f) c h w -> b t f c h w", t=1, f=1)
 
 
-def preprocess_laion_text(sample, tokenizer, max_tokens=32):
+def preprocess_laion_text(sample, tokenizer, max_tokens=256):
     """
     Preprocess text for LAION. Applied to a batch of captions.
-    Captions are truncated to 32 tokens by default.
+    Captions are truncated to 256 tokens by default.
     """
     tokenizer.padding_side = "right"
     sample = [
-        (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        # (f"<image>{s.strip()}<|endofchunk|>{tokenizer.eos_token}") for s in sample
+        (f"<image>{s.strip()}{tokenizer.eos_token}") for s in sample
     ]
     text = tokenizer(
         sample,
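
For reference, the caption preprocessing now drops <|endofchunk|> and truncates at 256 tokens instead of 32. The rest of the tokenizer call is cut off in the diff above, so the sketch below fills it in with standard Hugging Face tokenizer kwargs as an assumption, not the file's exact arguments (the checkpoint name is also just illustrative, and in the real pipeline <image> is an added special token):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-1.3b")  # illustrative LM checkpoint
tokenizer.padding_side = "right"

captions = ["a dog on a beach", "two people riding bikes"]
sample = [f"<image>{s.strip()}{tokenizer.eos_token}" for s in captions]

# Pad/truncate to 256 tokens, matching the new max_tokens default.
text = tokenizer(
    sample,
    max_length=256,
    padding="longest",
    truncation=True,
    return_tensors="pt",
)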
72 changes: 13 additions & 59 deletions open_flamingo/train/train.py
@@ -3,7 +3,6 @@
 import os
 import torch
 import wandb
-import deepspeed
 import functools
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -16,16 +15,13 @@
     world_info_from_env,
     get_fsdp_config,
     get_fsdp_checkpoint_config,
-    get_deepspeed_config,
 )
 from open_flamingo.train.train_utils import (
     train_one_epoch,
     random_seed,
-    load_deepspeed_checkpoint,
     find_most_recent_checkpoint,
     load_checkpoint,
     save_checkpoint,
-    save_deepspeed_checkpoint,
 )
 from open_flamingo.train.losses import (
     SUPPORTED_LOSSES,
@@ -44,8 +40,8 @@ def main():
     parser.add_argument(
         "--model_family", default="flamingo", type=str, choices=SUPPORTED_MODEL_FAMILIES
     )
-    parser.add_argument("--vision_encoder_path", default="ViT-L-14", type=str)
-    parser.add_argument("--vision_encoder_pretrained", default="openai", type=str)
+    parser.add_argument("--vision_encoder_path", default="ViT-SO400M-14-SigLIP-384", type=str)
+    parser.add_argument("--vision_encoder_pretrained", default="webli", type=str)
     parser.add_argument("--lm_path", default="facebook/opt-1.3b", type=str)
     parser.add_argument(
         "--tokenizer_path",
@@ -73,7 +69,7 @@ def main():
     parser.add_argument(
         "--resume_from_checkpoint",
         type=str,
-        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default. If using deepspeed this should be a directory, not a file.",
+        help="path to checkpoint to resume from, this should contain model, optimizer, and lr_scheduler states. if there exists a checkpoint in the dir named run_name, we will resume from that checkpoint by default.",
         default=None,
     )
     parser.add_argument(
@@ -187,20 +183,6 @@ def main():
         "--fsdp_sharding_strategy", default="full", type=str, choices=["full", "hybrid"]
     )
 
-    # deepspeed args
-    parser.add_argument(
-        "--deepspeed",
-        default=False,
-        action="store_true",
-        help="Use deepspeed for distributed training.",
-    )
-    parser.add_argument(
-        "--deepspeed_stage",
-        default=2,
-        type=int,
-        help="DeepSpeed distributed training stage. 1: ZeRO-1 (optimizer sharding), 2: ZeRO-2 (optimizer + gradient sharding), 3: ZeRO-3 (optimizer + gradient + model sharding)",
-    )
-
     # wandb args
     parser.add_argument("--report_to_wandb", default=False, action="store_true")
     parser.add_argument(
@@ -251,16 +233,10 @@ def main():
     if args.save_checkpoints_to_wandb and not args.report_to_wandb:
         raise ValueError("save_checkpoints_to_wandb requires report_to_wandb")
 
-    if args.fsdp and args.deepspeed:
-        raise ValueError("Select either FSDP or deepspeed for distributed training.")
-
     if args.fsdp:
-        print(
-            "Warning: FSDP is experimental and not fully tested. Preference should be given to Deepspeed."
-        )
         assert (
-            "dev" in torch.__version__ and torch.__version__ > "2.0.1"
-        ), "FSDP requires torch nightly > 2.0.1"
+            torch.__version__ > "2.0.1"
+        ), "FSDP requires torch > 2.0.1"
 
     # Set up distributed training
     args.local_rank, args.rank, args.world_size = world_info_from_env()
@@ -269,13 +245,7 @@ def main():
     if args.offline:
         os.environ["WANDB_MODE"] = "offline"
        os.environ["TRANSFORMERS_OFFLINE"] = "1"
-    if args.deepspeed:
-        torch.cuda.set_device(args.local_rank)
-        deepspeed.init_distributed()
-        ds_config = get_deepspeed_config(args)
-        device_id = args.local_rank
-    else:
-        device_id = init_distributed_device(args)
+    device_id = init_distributed_device(args)
 
     random_seed(args.seed)
 
@@ -316,8 +286,8 @@ def main():
         args.resume_from_checkpoint = find_most_recent_checkpoint(args)
 
     if (
-        args.resume_from_checkpoint is not None and not args.deepspeed
-    ):  # deepspeed handles checkpoint loading
+        args.resume_from_checkpoint is not None
+    ):
         resume_from_epoch, checkpoint = load_checkpoint(args, model)
     else:
         resume_from_epoch = 0
@@ -327,7 +297,6 @@ def main():
         model.init_gradient_checkpointing()

     # Initialize FSDP / DDP, and ensure the model is on GPU
-    # Deepspeed is initialized later
     if args.fsdp:
         auto_wrap_policy = functools.partial(
             lambda_auto_wrap_policy, lambda_fn=model.get_fsdp_lambda_fn()
@@ -336,7 +305,7 @@ def main():
         distributed_model = FSDP(
             model, auto_wrap_policy=auto_wrap_policy, **wrapper_kwargs
         )
-    elif not args.deepspeed:
+    else:
         model = model.to(device_id)
         distributed_model = DDP(model, device_ids=[device_id])
 
@@ -351,7 +320,7 @@ def main():
     )
 
     # load optimizer checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         osd = checkpoint["optimizer_state_dict"]
         if args.fsdp:
             FSDP.set_state_dict_type(
@@ -370,7 +339,7 @@ def main():
     ]
     total_training_steps = (
         getattr(args, f"train_num_samples_{datasets_to_train_on[0]}")
-        // getattr(args, f"batch_size_{datasets_to_train_on[0]}")
+        // (getattr(args, f"batch_size_{datasets_to_train_on[0]}") * args.gradient_accumulation_steps * args.world_size)
     ) * args.num_epochs
 
     if args.rank == 0:
@@ -395,21 +364,9 @@ def main():
     )
 
     # load lr scheduler checkpoint
-    if args.resume_from_checkpoint is not None and not args.deepspeed:
+    if args.resume_from_checkpoint is not None:
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
 
-    if args.deepspeed:
-        distributed_model, optimizer, _, lr_scheduler = deepspeed.initialize(
-            model=model,
-            optimizer=optimizer,
-            args=args,
-            config=ds_config,
-            lr_scheduler=lr_scheduler,
-            dist_init_required=True,
-        )
-        if args.resume_from_checkpoint is not None:
-            resume_from_epoch = load_deepspeed_checkpoint(args, distributed_model)
-
     # Initialize the loss fn
     loss_fn = get_loss_fn(args.loss)
 
@@ -435,10 +392,7 @@ def main():
             wandb=wandb,
         )
 
-        if args.deepspeed:
-            save_deepspeed_checkpoint(distributed_model, epoch, args)
-        else:
-            save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
+        save_checkpoint(distributed_model, optimizer, lr_scheduler, epoch, args)
 
 
 if __name__ == "__main__":
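
The step-count change above divides the per-epoch sample budget by the effective global batch size (per-GPU batch size times gradient accumulation times world size) rather than the per-GPU batch size alone, so the LR schedule now spans the true number of optimizer steps. A back-of-the-envelope check with hypothetical values:

# Hypothetical settings for illustration only.
train_num_samples = 10_000_000   # samples per epoch for the dataset
batch_size = 32                  # per-GPU batch size
gradient_accumulation_steps = 4
world_size = 8                   # number of GPUs
num_epochs = 1

effective_batch = batch_size * gradient_accumulation_steps * world_size  # 1024 samples per optimizer step

# Old formula: ignored accumulation and world size -> 312,500 "steps"
old_steps = (train_num_samples // batch_size) * num_epochs

# New formula: counts real optimizer steps -> 9,765
new_steps = (train_num_samples // effective_batch) * num_epochs

print(old_steps, new_steps)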
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,6 +1,6 @@
 einops
 einops-exts
-transformers==4.28.1
+transformers
 torch>=2.0.1
 pillow
 open_clip_torch>=2.16.0