 import torch
 from contextlib import suppress
 
+SUPPORTED_MODELS = ["open_flamingo", "blip", "idefics"]
+ZERO_SHOT_ONLY_MODELS = ["blip"]
+
+
+def get_eval_model(name, *args, **kwargs):
+    """Return an EvalModel object."""
+    if name == "open_flamingo":
+        from .open_flamingo import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    elif name == "blip":
+        from .blip import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    elif name == "idefics":
+        from .idefics import EvalModel
+
+        return EvalModel(*args, **kwargs)
+    else:
+        raise ValueError(f"Unsupported EvalModel type {name}")
+
+
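A minimal usage sketch of the new factory. The checkpoint path and argument values are placeholders; `lm_path`, `device`, and `precision` are the `model_args` keys this diff actually reads:

    # Sketch only; values are illustrative placeholders.
    model_args = {
        "lm_path": "facebook/opt-1.3b",  # placeholder; lm_name becomes "opt-1.3b"
        "device": 0,   # int >= 0 selects a CUDA device; a negative int falls back to "cpu"
        "precision": "fp32",
    }
    eval_model = get_eval_model("open_flamingo", model_args)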
 class BaseEvalModel(abc.ABC):
     """Base class encapsulating functionality needed to evaluate a model."""
 
-    def __init__(self, model_args: List[str]):
+    def __init__(self, model_args: List[str], init_on_device=False):
         """Initialize model.
 
         Args:
             args: arguments to model. These should be parsed, or if the model
                 has no applicable arguments, an error should be thrown if `args`
                 is non-empty.
         """
+        # check model args
+        assert all(
+            arg in model_args for arg in self.required_args
+        ), f"Missing required args for {self.__class__.__name__}: {self.required_args}"
+        self.lm_name = model_args["lm_path"].split("/")[-1]
 
-    def __init__(self, model_args, init_on_device=False):
-        assert "lm_path" in model_args, "All models require the lm_path argument"
+        # set device and precision
         self.device = (
             model_args["device"]
-            if ("device" in model_args and (type(model_args["device"]) != int or model_args["device"] >= 0))
+            if (
+                "device" in model_args
+                and (type(model_args["device"]) != int or model_args["device"] >= 0)
+            )
             else "cpu"
         )
+        print("Using device:", self.device)
         self.precision = model_args.get("precision", "fp32")
-        self.lm_name = model_args["lm_path"].split("/")[-1]
         self.autocast = get_autocast(self.precision)
         self.cast_dtype = get_cast_dtype(self.precision)
+
+        # initialization context
         if init_on_device:
-            # for deepspeed, must init on device, or likely CPU OOM
+            # for deepspeed, must init on device, or likely CPU OOM
             import deepspeed
-            self.init_ctx = deepspeed.OnDevice(dtype=self.cast_dtype, device=self.device)
+
+            self.init_ctx = deepspeed.OnDevice(
+                dtype=self.cast_dtype, device=self.device
+            )
         else:
             self.init_ctx = suppress()
 
+    @property
+    def required_args(self):
+        """Return list of required arguments to initialize model."""
+        return ["lm_path"]
+
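Subclasses can extend `required_args` to get the same assertion in `__init__` for free. A hypothetical override ("processor_path" is an invented extra requirement):

    # Hypothetical subclass; only required_args is overridden.
    class MyEvalModel(BaseEvalModel):
        @property
        def required_args(self):
            return ["lm_path", "processor_path"]
    # __init__ now raises a clear AssertionError if either key is missing.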
     def _check_init(self):
         """Finish model initialization."""
         assert hasattr(self, "model"), "Model has not been initialized"
@@ -49,6 +88,7 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
         if use_deepspeed:
             assert "amp" not in self.precision, "Deepspeed does not support amp"
             import deepspeed
+
             self.ds_engine = deepspeed.init_inference(
                 self.model,
                 mp_size=world_size,
@@ -61,12 +101,6 @@ def init_distributed(self, world_size=None, use_deepspeed=False):
         else:
             self.model = DDP(self.model, device_ids=[self.device])
 
-    def set_device(self, device):
-        """Set device for model."""
-        torch.cuda.set_device(device)
-        self.device = torch.device("cuda", device)
-        self.model = self.model.to(device, dtype=self.cast_dtype)
-
     def __call__(
         self,
         lang_x: torch.Tensor,
@@ -76,12 +110,13 @@ def __call__(
         use_cache: bool = False,
     ):
         """
-        Calls the forward function of the model.
-        Special logic to handle the case if past_key_values is not None:
+        Calls the forward function of the model, and returns an object that includes logits.
+        Note: implementations should handle the case if past_key_values is not None:
             then lang_x is assumed to contain the tokens to be generated
             *excluding* the tokens already in past_key_values.
             We then repeatedly call forward, updating the past_key_values.
         """
+        raise NotImplementedError
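The `past_key_values` contract above is the standard incremental-decoding loop. A sketch of what a subclass might do; this is illustrative only, not taken from any of the concrete implementations, and the `outputs.past_key_values` attribute follows the usual Hugging Face convention:

    # Sketch of the contract: feed only the new tokens, one step at a time,
    # carrying the KV cache forward.
    def __call__(self, lang_x, past_key_values=None, use_cache=False, **kwargs):
        if past_key_values is None:
            return self.model(lang_x, use_cache=use_cache, **kwargs)
        outputs = None
        for i in range(lang_x.shape[1]):
            outputs = self.model(
                lang_x[:, i : i + 1],
                past_key_values=past_key_values,
                use_cache=True,
            )
            past_key_values = outputs.past_key_values
        return outputs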
 
     def prepare_text(
         self,
@@ -92,7 +127,7 @@ def prepare_text(
         add_special_tokens=True,
     ):
         """
-        Prepare text for model.
+        Prepare text for model. Note that padding is always on the left.
 
         Args:
             batch: list of text strings
@@ -101,36 +136,38 @@ def prepare_text(
             max_length: maximum length of the text
 
         Returns:
-            input_ids: tensor of shape (B, T)
-            attention_mask: tensor of shape (B, T)
+            input_ids: tensor of shape (B, T_txt)
+            attention_mask: tensor of shape (B, T_txt)
         """
+        raise NotImplementedError
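Left padding matters because generation continues from the last position of `input_ids`; with right padding the model would be asked to continue past pad tokens. A sketch of an implementation, assuming a Hugging Face tokenizer on `self.tokenizer` (the defaults here are illustrative):

    # Sketch; self.tokenizer is assumed to be a Hugging Face tokenizer.
    def prepare_text(self, batch, padding="longest", truncation=True,
                     max_length=2000, add_special_tokens=True):
        self.tokenizer.padding_side = "left"  # the invariant noted in the docstring
        enc = self.tokenizer(
            batch,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            add_special_tokens=add_special_tokens,
            return_tensors="pt",
        )
        return enc["input_ids"], enc["attention_mask"]  # each (B, T_txt)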
 
     def prepare_images(self, batch: List[List[Image.Image]]):
         """
         Prepare images for model.
         Args:
             batch: list of lists of PIL images
         Returns:
-            tensor of shape (B, T, *, C, H, W)
+            tensor of shape (B, T_img, F, C, H, W)
         """
+        raise NotImplementedError
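Here `B` is the batch size, `T_img` the number of images per example (padded to the batch maximum), and `F` the frames per image (1 for stills). A sketch under those assumptions; `self.image_processor` (PIL image to a `(C, H, W)` tensor) and the 224x224 size are placeholders:

    # Sketch; the image_processor attribute and output size are assumptions.
    def prepare_images(self, batch):
        B = len(batch)
        T_img = max(len(example) for example in batch)
        out = torch.zeros(B, T_img, 1, 3, 224, 224)  # F = 1 for still images
        for b, example in enumerate(batch):
            for t, img in enumerate(example):
                out[b, t, 0] = self.image_processor(img)
        return out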
 
     def get_outputs(
         self,
         batch_text: List[str],
         batch_images: List[List[Image.Image]],
         **decode_kwargs,
     ) -> List[str]:
-        """Get outputs for a batch of images and text.
+        """Call generate on a batch of images and text.
 
         Args:
-            batch_text: list of text strings, with the text "<image>" in place
-                of any images to be included.
+            batch_text: list of text strings
             batch_images: images to provide to model. Should be a list of lists,
                 where each list contains the images for a single example.
 
         Returns:
             List of decoded output strings.
         """
+        raise NotImplementedError
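A hypothetical call, assuming `decode_kwargs` are forwarded to the underlying `generate()`; the prompt format is model-specific now that the `"<image>"` convention is gone from the docstring:

    # Hypothetical usage; "img" is a PIL.Image and the kwargs are
    # standard Hugging Face generation arguments.
    outputs = eval_model.get_outputs(
        batch_text=["A photo of"],   # prompt format depends on the model
        batch_images=[[img]],        # one list of images per example
        max_new_tokens=20,
        num_beams=3,
    )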
 
     def get_rank_classifications(
         self,
@@ -150,7 +187,29 @@ def get_rank_classifications(
             all_class_names: list of all class names.
             use_cache: whether to cache the context to speed up evaluations.
             normalize_length: whether to normalize logprobs by the length of the
-                class name
+                class name; use with caution, as this can change predictions quite a bit.
         Returns:
             (B, |all_class_names|) tensor containing the logprobs for each class name.
         """
+        raise NotImplementedError
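Why the caution: summed logprobs favor short class names, and dividing by token count removes that bias but can flip the ranking. A toy illustration with made-up numbers:

    import torch

    # Per-token logprobs for two candidate class names.
    cat = torch.tensor([-1.0])                     # 1 token
    retriever = torch.tensor([-0.4, -0.4, -0.4])   # 3 tokens
    print(cat.sum(), retriever.sum())    # -1.0 vs -1.2 -> "cat" wins
    print(cat.mean(), retriever.mean())  # -1.0 vs -0.4 -> "retriever" wins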
+
+    @property
+    def supported_tasks(self):
+        """
+        Return list of tasks that this model can be evaluated on.
+        Parsed by checking whether the model has a method called `get_{task}_prompt`.
+        """
+        return [
+            task.split("_")[1]
+            for task in dir(self)
+            if task.startswith("get_") and task.endswith("_prompt")
+        ]
+
+    def _validate_text(self, batch_text):
+        """
+        Checks for trailing whitespaces in the text and prints a warning.
+        """
+        if any([x.endswith(" ") for x in batch_text]):
+            print(
+                "Warning: trailing whitespace detected in text. This can cause unexpected behavior."
+            )
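A small demonstration of how `supported_tasks` derives task names from method names; the subclass and its `get_vqa_prompt` method are hypothetical:

    # Hypothetical subclass; only the method name matters for discovery.
    class ToyEvalModel(BaseEvalModel):
        def get_vqa_prompt(self, question):
            return f"Question: {question} Short answer:"

    # ToyEvalModel({"lm_path": "org/some-lm"}).supported_tasks == ["vqa"]
    # Note: split("_")[1] assumes one-word task names ("get_vqa_prompt" -> "vqa").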