
Commit b0ff9a4

update eval code to match new src args

1 parent be9a4dd

File tree

8 files changed (+262 -204 lines)


open_flamingo/eval/eval_datasets.py

Lines changed: 2 additions & 0 deletions

@@ -7,6 +7,8 @@
 
 from open_flamingo.eval.classification_utils import IMAGENET_CLASSNAMES
 
+SUPPORTED_TASKS = ["coco", "flickr", "vqav2", "ok_vqa", "vizwiz", "textvqa", "hateful_memes", "imagenet"]
+
 
 class CaptionDataset(Dataset):
     def __init__(
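The new SUPPORTED_TASKS constant gives callers one place to validate dataset names before any dataset is constructed. A minimal sketch of such a guard; the validate_task helper is hypothetical, not part of this commit:

    from open_flamingo.eval.eval_datasets import SUPPORTED_TASKS

    def validate_task(task_name: str) -> None:
        # Hypothetical helper: fail fast on unknown task names instead of
        # erroring deep inside dataset construction.
        if task_name not in SUPPORTED_TASKS:
            raise ValueError(
                f"Unsupported task {task_name!r}; expected one of {SUPPORTED_TASKS}"
            )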
open_flamingo/eval/eval_models/__init__.py (new file)

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .eval_model import *

open_flamingo/eval/models/blip.py renamed to open_flamingo/eval/eval_models/blip.py

Lines changed: 21 additions & 38 deletions

@@ -13,9 +13,6 @@ class EvalModel(BaseEvalModel):
     """BLIP-2 model evaluation."""
 
     def __init__(self, model_args, init_on_device=False):
-        assert (
-            "processor_path" in model_args and "lm_path" in model_args
-        ), "BLIP-2 requires processor_path, lm_path, and device arguments to be specified"
         super().__init__(model_args, init_on_device)
         with self.init_ctx:
             self.processor = Blip2Processor.from_pretrained(model_args["processor_path"])
@@ -25,6 +22,10 @@ def __init__(self, model_args, init_on_device=False):
         self.tokenizer = self.processor.tokenizer
         self._check_init()
 
+    @property
+    def required_args(self):
+        return ["processor_path", "lm_path"]
+
     def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         batch_images = None
         assert all(
@@ -58,6 +59,7 @@ def prepare_text(
         max_length=2000,
         add_special_tokens=True,
     ):
+        self._validate_text(batch)
         encodings = self.tokenizer(
             batch,
             padding=padding,
@@ -95,39 +97,20 @@ def get_outputs(
 
         return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
 
-    def get_vqa_prompt(self, question, answer=None) -> str:
-        return (
-            f"Question:{question} Short answer:{answer if answer is not None else ''}"
-        )
-
-    def get_caption_prompt(self, caption=None) -> str:
+    def get_vqav2_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_ok_vqa_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_vizwiz_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_textvqa_prompt(self, question, answer=None) -> str:
+        return f"Question:{question} Short answer:{answer if answer is not None else ''}"
+
+    def get_coco_prompt(self, caption=None) -> str:
+        return f"A photo of {caption if caption is not None else ''}"
+
+    def get_flickr_prompt(self, caption=None) -> str:
         return f"A photo of {caption if caption is not None else ''}"
-
-    def __call__(
-        self,
-        lang_x: torch.Tensor,
-        vision_x: torch.Tensor,
-        attention_mask: torch.Tensor,
-    ):
-        with self.autocast():
-            outputs = self.model(
-                pixel_values=vision_x,
-                input_ids=lang_x,
-                attention_mask=attention_mask,
-            )
-
-        # remove vision tokens
-        outputs.logits = outputs.logits[:, -lang_x.size(1) :, :]
-        return outputs
-
-    def get_rank_classifications(
-        self,
-        batch_text: List[str],
-        batch_images: List[List[Image.Image]],
-        all_class_names: List[str],
-        use_cache: bool,
-        normalize_length: bool,
-    ):
-        raise NotImplementedError(
-            "BLIP-2 classification-based evaluation not implemented"
-        )
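Splitting the old get_vqa_prompt/get_caption_prompt pair into per-task hooks is what lets the base class discover supported tasks by reflection (see eval_model.py below). A hedged sketch of how a harness might assemble a few-shot VQAv2 prompt from these hooks; the checkpoint name, QA pairs, and the space-joining convention are placeholders, not taken from this commit:

    # Placeholders throughout: any BLIP-2 checkpoint and QA pairs would do.
    model = EvalModel({
        "processor_path": "Salesforce/blip2-opt-2.7b",
        "lm_path": "Salesforce/blip2-opt-2.7b",
    })
    shots = [("What color is the bus?", "blue"), ("How many dogs are shown?", "two")]
    context = " ".join(model.get_vqav2_prompt(q, a) for q, a in shots)
    # The final query leaves the answer blank, so the prompt ends with "Short answer:".
    prompt = context + " " + model.get_vqav2_prompt("What is on the table?")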

open_flamingo/eval/eval_model.py renamed to open_flamingo/eval/eval_models/eval_model.py

Lines changed: 82 additions & 23 deletions

@@ -6,36 +6,75 @@
 import torch
 from contextlib import suppress
 
+SUPPORTED_MODELS = ["open_flamingo", "blip", "idefics"]
+ZERO_SHOT_ONLY_MODELS = ["blip"]
+
+
+def get_eval_model(name, *args, **kwargs):
+    """Return an EvalModel object."""
+    if name == "open_flamingo":
+        from .open_flamingo import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    elif name == "blip":
+        from .blip import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    elif name == "idefics":
+        from .idefics import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    else:
+        raise ValueError(f"Unsupported EvalModel type {name}")
+
+
 class BaseEvalModel(abc.ABC):
     """Base class encapsulating functionality needed to evaluate a model."""
 
-    def __init__(self, model_args: List[str]):
+    def __init__(self, model_args: List[str], init_on_device=False):
         """Initialize model.
 
         Args:
             args: arguments to model. These should be parsed, or if the model
                 has no applicable arguments, an error should be thrown if `args`
                 is non-empty.
         """
+        # check model args
+        assert all(
+            arg in model_args for arg in self.required_args
+        ), f"Missing required args for {self.__class__.__name__}: {self.required_args}"
+        self.lm_name = model_args["lm_path"].split("/")[-1]
 
-    def __init__(self, model_args, init_on_device=False):
-        assert "lm_path" in model_args, "All models require the lm_path argument"
+        # set device and precision
         self.device = (
             model_args["device"]
-            if ("device" in model_args and (type(model_args["device"]) != int or model_args["device"] >= 0))
+            if (
+                "device" in model_args
+                and (type(model_args["device"]) != int or model_args["device"] >= 0)
+            )
             else "cpu"
         )
+        print("Using device:", self.device)
         self.precision = model_args.get("precision", "fp32")
-        self.lm_name = model_args["lm_path"].split("/")[-1]
         self.autocast = get_autocast(self.precision)
         self.cast_dtype = get_cast_dtype(self.precision)
+
+        # initialization context
         if init_on_device:
-            # for deepspeed, must init on device, or likely CPU OOM
+            # for deepspeed, must init on device, or likely CPU OOM
             import deepspeed
-            self.init_ctx = deepspeed.OnDevice(dtype=self.cast_dtype, device=self.device)
+
+            self.init_ctx = deepspeed.OnDevice(
+                dtype=self.cast_dtype, device=self.device
+            )
         else:
             self.init_ctx = suppress()
 
+    @property
+    def required_args(self):
+        """Return list of required arguments to initialize model."""
+        return ["lm_path"]
+
     def _check_init(self):
         """Finish model initialization."""
         assert hasattr(self, "model"), "Model has not been initialized"
@@ -49,6 +88,7 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
         if use_deepspeed:
             assert "amp" not in self.precision, "Deepspeed does not support amp"
             import deepspeed
+
             self.ds_engine = deepspeed.init_inference(
                 self.model,
                 mp_size=world_size,
@@ -61,12 +101,6 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
         else:
             self.model = DDP(self.model, device_ids=[self.device])
 
-    def set_device(self, device):
-        """Set device for model."""
-        torch.cuda.set_device(device)
-        self.device = torch.device("cuda", device)
-        self.model = self.model.to(device, dtype=self.cast_dtype)
-
     def __call__(
         self,
         lang_x: torch.Tensor,
@@ -76,12 +110,13 @@ def __call__(
         use_cache: bool = False,
     ):
         """
-        Calls the forward function of the model.
-        Special logic to handle the case if past_key_values is not None:
+        Calls the forward function of the model, and returns an object that includes logits.
+        Note: implementations should handle the case if past_key_values is not None:
         then lang_x is assumed to contain the tokens to be generated
         *excluding* the tokens already in past_key_values.
        We then repeatedly call forward, updating the past_key_values.
         """
+        raise NotImplementedError
 
     def prepare_text(
         self,
@@ -92,7 +127,7 @@ def prepare_text(
         add_special_tokens=True,
     ):
         """
-        Prepare text for model.
+        Prepare text for model. Note that padding is always on the left.
 
         Args:
             batch: list of text strings
@@ -101,36 +136,38 @@ def prepare_text(
             max_length: maximum length of the text
 
         Returns:
-            input_ids: tensor of shape (B, T)
-            attention_mask: tensor of shape (B, T)
+            input_ids: tensor of shape (B, T_txt)
+            attention_mask: tensor of shape (B, T_txt)
         """
+        raise NotImplementedError
 
     def prepare_images(self, batch: List[List[Image.Image]]):
         """
         Prepare images for model.
         Args:
             batch: list of lists of PIL images
         Returns:
-            tensor of shape (B, T, *, C, H, W)
+            tensor of shape (B, T_img, F, C, H, W)
         """
+        raise NotImplementedError
 
     def get_outputs(
         self,
         batch_text: List[str],
         batch_images: List[List[Image.Image]],
         **decode_kwargs,
     ) -> List[str]:
-        """Get outputs for a batch of images and text.
+        """Call generate on a batch of images and text.
 
         Args:
-            batch_text: list of text strings, with the text "<image>" in place
-                of any images to be included.
+            batch_text: list of text strings
             batch_images: images to provide to model. Should be a list of lists,
                 where each list contains the images for a single example.
 
         Returns:
             List of decoded output strings.
         """
+        raise NotImplementedError
 
     def get_rank_classifications(
         self,
@@ -150,7 +187,29 @@ def get_rank_classifications(
             all_class_names: list of all class names.
             use_cache: whether to cache the context to speed up evaluations.
             normalize_length: whether to normalize logprobs by the length of the
-                class name
+                class name; use with caution, as this can change predictions quite a bit.
         Returns:
             (B, |all_class_names|) tensor containing the logprobs for each class name.
         """
+        raise NotImplementedError
+
+    @property
+    def supported_tasks(self):
+        """
+        Return list of tasks that this model can be evaluated on.
+        Parsed by checking whether the model has a method called `get_{task}_prompt`.
+        """
+        return [
+            task.split("_")[1]
+            for task in dir(self)
+            if task.startswith("get_") and task.endswith("_prompt")
+        ]
+
+    def _validate_text(self, batch_text):
+        """
+        Checks for trailing whitespaces in the text and prints a warning.
+        """
+        if any([x.endswith(" ") for x in batch_text]):
+            print(
+                "Warning: trailing whitespace detected in text. This can cause unexpected behavior."
            )
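Taken together, the factory and the two new properties move argument checking out of each subclass: get_eval_model selects the implementation, the base __init__ asserts that required_args are present, and supported_tasks is derived by reflection from the get_{task}_prompt hooks. A hedged sketch of the intended flow, assuming the package re-export in the new __init__.py and using placeholder checkpoint paths:

    from open_flamingo.eval.eval_models import ZERO_SHOT_ONLY_MODELS, get_eval_model

    model_args = {
        "processor_path": "Salesforce/blip2-opt-2.7b",  # placeholder checkpoint
        "lm_path": "Salesforce/blip2-opt-2.7b",
    }
    model = get_eval_model("blip", model_args)  # required_args checked in base __init__
    assert "vqav2" in model.supported_tasks    # discovered via get_vqav2_prompt
    num_shots = 0 if "blip" in ZERO_SHOT_ONLY_MODELS else 4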

open_flamingo/eval/models/idefics.py renamed to open_flamingo/eval/eval_models/idefics.py

Lines changed: 21 additions & 24 deletions

@@ -18,16 +18,17 @@ class EvalModel(BaseEvalModel):
     """IDEFICS model evaluation."""
 
     def __init__(self, model_args, init_on_device=False):
-        assert (
-            "lm_path" in model_args and "processor_path" in model_args
-        ), "IDEFICS requires lm_path and lm_tokenizer_path"
         super().__init__(model_args, init_on_device)
         with self.init_ctx:
             self.model = IdeficsForVisionText2Text.from_pretrained(model_args["lm_path"])
             self.processor = AutoProcessor.from_pretrained(model_args["processor_path"])
         self.tokenizer = self.processor.tokenizer
         self._check_init()
 
+    @property
+    def required_args(self):
+        return ["lm_path", "processor_path"]
+
     def prepare_images(self, batch: List[List[Image.Image]]) -> torch.Tensor:
         batch_images = self.processor(batch)["pixel_values"]
         if batch_images is not None:
@@ -44,6 +45,7 @@ def prepare_text(
         max_length=2000,
         add_special_tokens=True,
     ):
+        self._validate_text(batch)
         # check to see if there any <image> without <fake_token_around_image> wrapping it
         for i, text in enumerate(batch):
             if "<image>" in text and "<fake_token_around_image>" not in text:
@@ -88,19 +90,6 @@ def _compute_image_attention_mask(self, batch_tokens: torch.Tensor) -> torch.Tensor:
         )
         return image_attention_mask
 
-    def get_rank_classifications(
-        self,
-        batch_text: List[str],
-        batch_images: List[List[Image.Image]],
-        all_class_names: List[str],
-        use_cache: bool,
-        normalize_length: bool,
-    ):
-        """
-        Returns a (B, |all_class_names|) tensor containing the logprobs for each class name.
-        """
-        raise NotImplementedError
-
     def get_outputs(
         self,
         batch_text: List[str],
@@ -176,18 +165,26 @@ def __call__(
             past_key_values=past_key_values,
         )
 
-    def get_vqa_prompt(self, question, answer=None) -> str:
+    def get_vqav2_prompt(self, question, answer=None) -> str:
         # TODO: handle prefix prompts
         return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
-
-    def get_caption_prompt(self, caption=None) -> str:
+
+    def get_ok_vqa_prompt(self, question, answer=None) -> str:
         # TODO: handle prefix prompts
-        return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
+        return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
+
+    def get_vizwiz_prompt(self, question, answer=None) -> str:
+        # TODO: handle prefix prompts
+        return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
-    def get_imagenet_prompt(self, label=None) -> str:
+    def get_textvqa_prompt(self, question, answer=None) -> str:
         # TODO: handle prefix prompts
-        return f"<image>Output:{label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
+        return f"<image>Question:{question} Answer: {answer if answer is not None else ''}{'<|endofchunk|>' if answer is not None else ''}"
 
-    def get_hateful_memes_prompt(self, text, label=None) -> str:
+    def get_coco_prompt(self, caption=None) -> str:
+        # TODO: handle prefix prompts
+        return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
+
+    def get_flickr_prompt(self, caption=None) -> str:
         # TODO: handle prefix prompts
-        return f"<image>is an image with: '{text}' written on it. Is it hateful? Answer: {label if label is not None else ''}{'<|endofchunk|>' if label is not None else ''}"
+        return f"<image>Caption: {caption if caption is not None else ''}{'<|endofchunk|>' if caption is not None else ''}"
