
Commit 79ad152

committed: "some more fixes"
1 parent: 0b1c926

17 files changed (+159 additions, -153 deletions)


open_flamingo/eval/eval_datasets.py

Lines changed: 4 additions & 4 deletions
@@ -9,9 +9,9 @@
 
 SUPPORTED_TASKS = [
     "coco",
-    "flickr",
+    "flickr30",
     "vqav2",
-    "ok_vqa",
+    "okvqa",
     "vizwiz",
     "textvqa",
     "hateful_memes",
@@ -87,15 +87,15 @@ def __init__(
         self.image_dir_path = image_dir_path
         self.is_train = is_train
         self.dataset_name = dataset_name
-        if self.dataset_name in {"vqav2", "ok_vqa"}:
+        if self.dataset_name in {"vqav2", "okvqa"}:
             self.img_coco_split = self.image_dir_path.strip("/").split("/")[-1]
             assert self.img_coco_split in {"train2014", "val2014", "test2015"}
 
     def __len__(self):
         return len(self.questions)
 
     def get_img_path(self, question):
-        if self.dataset_name in {"vqav2", "ok_vqa"}:
+        if self.dataset_name in {"vqav2", "okvqa"}:
             return os.path.join(
                 self.image_dir_path,
                 f"COCO_{self.img_coco_split}_{question['image_id']:012d}.jpg"

open_flamingo/eval/eval_models/blip.py

Lines changed: 10 additions & 11 deletions
@@ -12,16 +12,15 @@
 class EvalModel(BaseEvalModel):
     """BLIP-2 model evaluation."""
 
-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
-        with self.init_ctx:
-            self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
-            self.model = Blip2ForConditionalGeneration.from_pretrained(
-                model_args["lm_path"]
-            )
-            self.tokenizer = self.processor.tokenizer
+    def __init__(self, model_args):
+        super().__init__(model_args)
+        self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
+        self.model = Blip2ForConditionalGeneration.from_pretrained(
+            model_args["lm_path"]
+        )
+        self.tokenizer = self.processor.tokenizer
+
         self._check_init()
-
     @property
     def required_args(self):
         return ["processor_path", "lm_path"]
@@ -100,7 +99,7 @@ def get_outputs(
     def get_vqav2_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
+    def get_okvqa_prompt(self, question, answer=None) -> str:
         return f"Question:{question} Short answer:{answer if answer is not None else ''}"
 
     def get_vizwiz_prompt(self, question, answer=None) -> str:
@@ -112,5 +111,5 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
     def get_coco_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"
 
-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"

open_flamingo/eval/eval_models/idefics.py

Lines changed: 10 additions & 11 deletions
@@ -17,16 +17,15 @@
 class EvalModel(BaseEvalModel):
     """IDEFICS model evaluation."""
 
-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
-        with self.init_ctx:
-            self.model = IdeficsForVisionText2Text.from_pretrained(
-                model_args["lm_path"]
-            )
-            self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
-            self.tokenizer = self.processor.tokenizer
+    def __init__(self, model_args):
+        super().__init__(model_args)
+        self.model = IdeficsForVisionText2Text.from_pretrained(
+            model_args["lm_path"]
+        )
+        self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
+        self.tokenizer = self.processor.tokenizer
+
         self._check_init()
-
     @property
     def required_args(self):
         return ["lm_path", "processor_path"]
@@ -171,7 +170,7 @@ def get_vqav2_prompt(self, question, answer=None) -> str:
         # TODO: handle prefix prompts
         return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
+    def get_okvqa_prompt(self, question, answer=None) -> str:
         # TODO: handle prefix prompts
         return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
@@ -187,6 +186,6 @@ def get_coco_prompt(self, caption=None) -> str:
         # TODO: handle prefix prompts
         return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
 
-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
         # TODO: handle prefix prompts
         return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"

open_flamingo/eval/eval_models/open_flamingo.py

Lines changed: 33 additions & 24 deletions
@@ -13,31 +13,40 @@
 class EvalModel(BaseEvalModel):
     """OpenFlamingo model evaluation."""
 
-    def __init__(self, model_args, init_on_device=False):
-        super().__init__(model_args, init_on_device)
+    def __init__(self, model_args):
+        super().__init__(model_args)
+
+        if model_args["model_family"] == "openflamingo":
+            assert "cross_attn_every_n_layers" in model_args, "cross_attn_every_n_layers is required for Flamingo models"
+        else:
+            assert "cross_attn_every_n_layers" not in model_args, "cross_attn_every_n_layers is only for Flamingo models"
+
         # initialize the model
-        with self.init_ctx:
-            (
-                self.model,
-                self.image_processor,
-                self.tokenizer,
-            ) = create_model_and_transforms(
-                clip_vision_encoder_path=model_args["vision_encoder_path"],
-                clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"],
-                lang_model_path=model_args["lm_path"],
-                tokenizer_path=model_args["tokenizer_path"],
-                model_family=model_args["model_family"],
-                cross_attn_every_n_layers=int(
-                    model_args.get("cross_attn_every_n_layers", 1)
-                ),
-            )
+        additional_kwargs = (
+            {"cross_attn_every_n_layers": model_args.get("cross_attn_every_n_layers", 1)}
+            if model_args["model_family"] == "flamingo"
+            else {}
+        )
+        (
+            self.model,
+            self.image_processor,
+            self.tokenizer,
+        ) = create_model_and_transforms(
+            clip_vision_encoder_path=model_args["vision_encoder_path"],
+            clip_vision_encoder_pretrained=model_args["vision_encoder_pretrained"],
+            lang_model_path=model_args["lm_path"],
+            tokenizer_path=model_args["tokenizer_path"],
+            model_family=model_args["model_family"],
+            **additional_kwargs,
+        )
 
         # load the checkpoint
         checkpoint = torch.load(model_args["checkpoint_path"], map_location="cpu")
         if "model_state_dict" in checkpoint:
             checkpoint = checkpoint["model_state_dict"]
         checkpoint = {k.replace("module.", ""): v for k, v in checkpoint.items()}
         self.model.load_state_dict(checkpoint, strict=False)
+        self.model_family = model_args["model_family"]
 
         self._check_init()
 
@@ -46,11 +55,10 @@ def required_args(self):
         """Return list of required arguments to initialize model."""
         return [
             "vision_encoder_path",
-            "model_familyl",
+            "model_family",
             "lm_path",
             "checkpoint_path",
             "tokenizer_path",
-            "cross_attn_every_n_layers",
             "vision_encoder_pretrained",
         ]
 
@@ -170,8 +178,9 @@ def get_outputs(
             **decode_kwargs,
         )
 
-        # Extract only the new generated tokens
-        outputs = outputs[:, len(input_ids[0]) :]
+        if self.model_family == "flamingo":
+            # Extract only the new generated tokens
+            outputs = outputs[:, len(input_ids[0]) :]
         return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
     def get_rank_classifications(
@@ -270,8 +279,8 @@ def get_rank_classifications(
     def get_vqav2_prompt(self, question, answer=None) -> str:
         return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
-    def get_ok_vqa_prompt(self, question, answer=None) -> str:
-        return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+    def get_okvqa_prompt(self, question, answer=None) -> str:
+        return f"<image>Instruct: {question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
     def get_vizwiz_prompt(self, question, answer=None) -> str:
         return f"<image>Question:{question} Short answer:{answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
@@ -282,7 +291,7 @@ def get_textvqa_prompt(self, question, answer=None) -> str:
     def get_coco_prompt(self, caption=None) -> str:
         return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
 
-    def get_flickr_prompt(self, caption=None) -> str:
+    def get_flickr30_prompt(self, caption=None) -> str:
         return f"<image>Output:{caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
 
     def get_imagenet_prompt(self, label=None) -> str:
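Note: cross_attn_every_n_layers is now forwarded only when the model family takes it, via a conditionally built kwargs dict. The pattern in isolation (create_model here is a toy stand-in, not the real factory):

    # Sketch of the conditional-kwargs pattern; `create_model` is hypothetical.
    def create_model(model_family, cross_attn_every_n_layers=None):
        return {"family": model_family, "xattn": cross_attn_every_n_layers}

    model_args = {"model_family": "flamingo", "cross_attn_every_n_layers": 4}
    additional_kwargs = (
        {"cross_attn_every_n_layers": model_args.get("cross_attn_every_n_layers", 1)}
        if model_args["model_family"] == "flamingo"
        else {}  # other families never see the argument
    )
    print(create_model(model_family=model_args["model_family"], **additional_kwargs))
    # -> {'family': 'flamingo', 'xattn': 4}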

open_flamingo/eval/evaluate.py

Lines changed: 18 additions & 15 deletions
@@ -122,7 +122,7 @@
     help="Whether to evaluate on VQAV2.",
 )
 parser.add_argument(
-    "--eval_ok_vqa",
+    "--eval_okvqa",
     action="store_true",
     default=False,
     help="Whether to evaluate on OK-VQA.",
@@ -408,10 +408,8 @@ def main():
     model_args["device"] = device_id
 
     # initialize model
-    eval_model = get_eval_model(args.model, model_args, init_on_device=False)
-    eval_model.init_distributed(
-        local_rank=args.local_rank,
-    )
+    eval_model = get_eval_model(args.model, model_args)
+    eval_model.init_distributed()
 
     # Validate args
     if args.model in ZERO_SHOT_ONLY_MODELS and args.shots != [0]:
@@ -504,7 +502,7 @@ def main():
         }
     )
 
-    if args.eval_ok_vqa:
+    if args.eval_okvqa:
         print("Evaluating on OK-VQA...")
 
         # load cached demonstration features for RICES
@@ -523,7 +521,7 @@
                 eval_model=eval_model,
                 num_shots=shot,
                 seed=seed,
-                dataset_name="ok_vqa",
+                dataset_name="okvqa",
                 cached_features=cached_features,
             )
             if args.rank == 0:
@@ -919,7 +917,7 @@ def evaluate_vqa(
     seed: int = 42,
     min_new_tokens: int = 0,
    max_new_tokens: int = 5,
-    num_beams: int = 3,
+    num_beams: int = 5,
     length_penalty: float = 0.0,
     num_shots: int = 8,
     dataset_name: str = "vqav2",
@@ -936,13 +934,13 @@
         num_beams (int, optional): number of beams to use for beam search. Defaults to 3.
         length_penalty (float, optional): length penalty for beam search. Defaults to -2.0.
         num_shots (int, optional): number of shots to use. Defaults to 8.
-        dataset_name (string): type of vqa dataset: currently supports vqav2, ok_vqa. Defaults to vqav2.
+        dataset_name (string): type of vqa dataset: currently supports vqav2, okvqa. Defaults to vqav2.
         cached_features (tensor, optional): cached demonstration features for RICES. Defaults to None.
     Returns:
         float: accuracy score
     """
 
-    if dataset_name == "ok_vqa":
+    if dataset_name == "okvqa":
         train_image_dir_path = args.ok_vqa_train_image_dir_path
         train_questions_json_path = args.ok_vqa_train_questions_json_path
         train_annotations_json_path = args.ok_vqa_train_annotations_json_path
@@ -989,7 +987,7 @@
         dataset_name=dataset_name,
     )
 
-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model)
 
     np.random.seed(seed)
     test_dataloader = utils.prepare_eval_samples(
@@ -1012,6 +1010,11 @@
 
     utils.random_seed(seed, args.rank)
     predictions = []
+
+    get_vqa_prompt = getattr(
+        eval_model, f"get_{dataset_name}_prompt"
+    )
+
     for batch in tqdm(
         test_dataloader,
         desc=f"Running inference {dataset_name}",
@@ -1034,7 +1037,7 @@
 
             context_text = "".join(
                 [
-                    eval_model.get_vqa_prompt(
+                    get_vqa_prompt(
                         question=x["question"], answer=x["answers"][0]
                     )
                     + "\n"
@@ -1047,9 +1050,9 @@
             context_text = context_text.replace("<image>", "")
 
             batch_text.append(
-                context_text + eval_model.get_vqa_prompt(question=batch["question"][i])
+                context_text + get_vqa_prompt(question=batch["question"][i])
             )
-
+
         outputs = eval_model.get_outputs(
             batch_images=batch_images,
             batch_text=batch_text,
@@ -1186,7 +1189,7 @@ def evaluate_classification(
 
     class_id_to_name = dict(zip(range(len(all_class_names)), all_class_names))
 
-    effective_num_shots = utils.compute_effective_num_shots(num_shots, args.model)
+    effective_num_shots = num_shots #utils.compute_effective_num_shots(num_shots, args.model)
 
     np.random.seed(seed)
     test_dataloader = utils.prepare_eval_samples(
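Note: the per-dataset prompt builder is now resolved once via getattr instead of a hard-coded eval_model.get_vqa_prompt call. How the dispatch behaves, with a toy stand-in class:

    # Sketch of the dynamic prompt dispatch; ToyEvalModel is made up for illustration.
    class ToyEvalModel:
        def get_okvqa_prompt(self, question, answer=None):
            return f"Question:{question} Short answer:{answer if answer is not None else ''}"

    dataset_name = "okvqa"
    get_vqa_prompt = getattr(ToyEvalModel(), f"get_{dataset_name}_prompt")
    print(get_vqa_prompt(question="What season is it?"))
    # -> Question:What season is it? Short answer: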

open_flamingo/eval/ok_vqa_utils.py

Lines changed: 1 addition & 0 deletions
@@ -210,6 +210,7 @@ def stem(self, input_string):
 
 def postprocess_ok_vqa_generation(predictions) -> str:
     prediction = re.split("Question|Answer|Short", predictions, 1)[0]
+    prediction = prediction.split(". ", 1)[0]
     prediction = re.split(", ", prediction, 1)[0]
     prediction_stem = stemmer.stem(prediction)
     return prediction_stem
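Note: the added split on ". " truncates multi-sentence generations before the existing comma split and stemming. Roughly, on an illustrative prediction:

    import re

    prediction = "a frisbee. It is on the grass, near a dog"  # made-up model output
    prediction = re.split("Question|Answer|Short", prediction, 1)[0]  # unchanged step
    prediction = prediction.split(". ", 1)[0]   # new: keep only the first sentence
    prediction = re.split(", ", prediction, 1)[0]  # keep only the first comma chunk
    print(prediction)  # -> "a frisbee" (this string then goes through stemmer.stem)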

open_flamingo/eval/vqa_metric.py

Lines changed: 1 addition & 0 deletions
@@ -556,5 +556,6 @@ def compute_vqa_accuracy(result_json_path, question_json_path, annotation_json_p
 
 def postprocess_vqa_generation(predictions):
     answer = re.split("Question|Answer|Short", predictions, 1)[0]
+    answer = answer.split(". ", 1)[0]
     answer = re.split(", ", answer, 1)[0]
     return answer

open_flamingo/src/blip.py

Lines changed: 0 additions & 8 deletions
@@ -50,14 +50,6 @@ def set_trainable(self):
         """
         self.requires_grad_(False)
         self.vision_tokenizer.requires_grad_(True)
-        self.lang_model.get_output_embeddings().set_requires_grad(
-            require_regular_grad=False,
-            require_additional_grad=True,
-        )
-        self.lang_model.get_input_embeddings().set_requires_grad(
-            require_regular_grad=False,
-            require_additional_grad=True,
-        )
 
     def _should_apply_weight_decay(self, parameter_name):
         """BLIP applies 0.05 weight decay to everything"""

open_flamingo/src/factory.py

Lines changed: 8 additions & 3 deletions
@@ -58,11 +58,12 @@ def create_model_and_transforms(
         clip_vision_encoder_path,
         pretrained=clip_vision_encoder_pretrained,
         cache_dir=cache_dir,
+        force_image_size=490,
     )
     vision_encoder.visual.output_tokens = True
     vision_encoder = vision_encoder.visual
     vision_encoder_config = open_clip.get_model_config(clip_vision_encoder_path)
-    if "SigLIP" in clip_vision_encoder_path:  # SigLIP models have a different config format
+    if "SigLIP" in clip_vision_encoder_path or "EVA" in clip_vision_encoder_path:  # SigLIP models have a different config format
         vis_hidden_dim = vision_encoder_config["embed_dim"]
     else:
         vis_hidden_dim = vision_encoder_config["vision_cfg"]["width"]
@@ -74,8 +75,9 @@
         trust_remote_code=True,
         cache_dir=cache_dir,
     )
-    if text_tokenizer.pad_token is None:
-        text_tokenizer.pad_token_id = text_tokenizer.eos_token_id
+    if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
+        # add a pad token if it doesn't exist
+        text_tokenizer.add_special_tokens({"pad_token": "<pad>"})
 
     # load langauge model
     lang_model = AutoModelForCausalLM.from_pretrained(
@@ -150,6 +152,9 @@ def _infer_decoder_layers_attr_name(model):
     "gemma": "model.layers",
     "phi": "model.layers",
     "minicpm": "model.layers",
+    "stablelm": "model.layers",
+    "qwen": "model.layers",
+    "mistral": "model.layers"
 }
 
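Note: the tokenizer now gets a dedicated <pad> token whenever pad_token is missing or aliases eos_token, rather than reusing eos for padding. The check in isolation (GPT-2 is just a convenient tokenizer that ships without a pad token; if the vocabulary grows, the language model's embeddings typically need resizing as well, which is outside this snippet):

    from transformers import AutoTokenizer

    text_tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example model, not from this commit
    if text_tokenizer.pad_token is None or text_tokenizer.pad_token == text_tokenizer.eos_token:
        # add a distinct pad token so padding and end-of-sequence stay distinguishable
        text_tokenizer.add_special_tokens({"pad_token": "<pad>"})
    print(text_tokenizer.pad_token, text_tokenizer.pad_token_id)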
