
Commit a781057

[Datasets & Models] Fuyu, HalluBench (w/Kaichen, commit 96d95b3) (EvolvingLMMs-Lab#33)
* add fuyu
* Merge commit '6d570ac1d98a03585c8119ccb362e13ab2172fed'
* Squashed commit of the following:
  * commit 09c64b7491cd19d4e6c4a6e1a38254eaa74d0032 — kcz358 <[email protected]>, Tue Jan 30 19:39:57 2024 +0800: Add hallu bench
  * commit 6d570ac — Pu Fanyi <[email protected]>, Tue Jan 30 14:52:51 2024 +0800: scienceqa for full set (EvolvingLMMs-Lab#32), which in turn squashes:
    * Remove unused code and configuration file
    * Remove docvqa.yaml and update vizwizvqa.yaml; lint; add dataset_kwargs to vizwizvqa.yaml
    * textvqa (EvolvingLMMs-Lab#27): update textvqa.yaml and utils.py, fix YAML formatting, remove unused files and metric, add textvqa val & test
    * Update progress bar description in evaluator.py
    * Update submission file names and paths in OKVQA and VizWizVQA tasks
    * Update output path to include log samples suffix
    * Refactor llava-in-the-wild.yaml and utils.py; update metric for llava evaluation
    * Refactor logging message in Task class
    * Merge commit 'f92c3d6d10a8b0b7a0b42baa60cb364b99525b4e'
    * Fix formatting issues and add progress bar closing statements
    * Update task from "infovqa_val" to "infovqa_test" in infovqa_test.yaml
    * Update tqdm progress bar in OtterHD model
    * Squashed kc/list_tasks_num history by kcz358 <[email protected]>, Jan 28 2024: eae210c3700a59b7d5cc9de46fcb855f443096aa (Black lint), 18e4a19e82357352ab25df77b5ae4f1b011d61ae (Merge branch 'main' into kc/list_tasks_num), e899be48f55f95172fdf96bd2a98d3b91ff2aaed (Enable list all tasks num), a999fc6889c6986c28ec5d95460a4ab5233e5d4f (Exclude train yaml file in the task list), fdb0c6785b0c5d6979d10e7ddf75ce9055038db8 (Fix key bugs)
    * commit f92c3d6 — Zhang Peiyuan <[email protected]>, Sun Jan 28 02:04:57 2024 +0800: Add InfoVQA, DocVQA, and QwenVL (EvolvingLMMs-Lab#28) — add mme, chartqa, ai2d, qwenvl, infovqa and docvqa; add model-specific prompt and gen kwargs; add yaml config to support multi-model eval; print table at the end; refactor multi-model code; black lint
    * Fix error handling in loading YAML config files
    * List task #num sorted
    * Update prompt messages for image-related tasks
    * Delete unused task configuration files; remove coco_train.yaml
    * Update task name in mmmu.yaml
    * Fix error message for missing tasks
    * Add wandb import and integration
    * Update generation kwargs for LMMS tasks; update lmms_eval MME task configuration and utils
    * Update doc_to_text function in coco and okvqa tasks
    * Add COCO 2017 version; update task name in coco_test2017.yaml
    * Squashed history: commit fbb7aa5 — Zhang Peiyuan <[email protected]>, Mon Jan 29 22:41:33 2024 +0800: Add/mmmu test (EvolvingLMMs-Lab#30); commit b8ba33c — Li Bo <[email protected]>, Sun Jan 28 22:19:13 2024 +0800: [Dataset Check] dataset check and add wandb logging (EvolvingLMMs-Lab#29), whose message repeats the history listed above
    * Remove scienceqa_img task configuration; eval scienceqa with no images
    * Co-authored-by: Fanyi Pu <[email protected]>, Bo Li <[email protected]>, kcz358 <[email protected]>
* Update hb_doc_to_text function to remove an unnecessary line break
* Add Fuyu model and update OtterHD model
* Refactor model response handling and fix an image processing bug
* Refactor flatten method to support getting only the first element
* Add support for specifying a timezone in the datetime string; update the flatten method in the OtterHD class; update get_datetime_str in utils.py
* Fix the condition for checking wandb_args_dict in __main__.py
* Comment out the batch-size assertions in the Fuyu model
* Add a warning message for an existing output file
1 parent 6d570ac commit a781057
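The commit threads a new --timezone flag (default Asia/Singapore) into utils.get_datetime_str so run timestamps are rendered in a chosen IANA timezone. As a rough sketch only — the actual helper in lmms_eval/utils.py may use a different format string or implementation — such a function could look like:

# Hypothetical sketch of a timezone-aware timestamp helper; only the
# signature get_datetime_str(timezone=...) is taken from the diff below.
from datetime import datetime
from zoneinfo import ZoneInfo  # stdlib, Python 3.9+


def get_datetime_str(timezone: str = "Asia/Singapore") -> str:
    """Return a filesystem-safe timestamp in the given IANA timezone."""
    now = datetime.now(ZoneInfo(timezone))
    return now.strftime("%Y%m%d_%H%M%S")  # assumed format, not from the repo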

File tree

8 files changed: +723 -50 lines changed


lmms_eval/__main__.py

Lines changed: 38 additions & 34 deletions
@@ -119,6 +119,11 @@ def parse_eval_args() -> argparse.Namespace:
         default="",
         help="Comma separated string arguments passed to wandb.init, e.g. `project=lmms-eval,job_type=eval",
     )
+    parser.add_argument(
+        "--timezone",
+        default="Asia/Singapore",
+        help="Timezone for datetime string, e.g. Asia/Singapore, America/New_York, America/Los_Angeles",
+    )
     args = parser.parse_args()
     return args

@@ -206,27 +211,11 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
     )
     # eval_logger.warn(f"Tasks {missing} were not found. Try `lmms-eval --tasks list` for list of available tasks.")

-    if args.output_path:
-        hash_input = f"{args.model_args}_{args.tasks}".encode("utf-8")
-        hash_output = hashlib.sha256(hash_input).hexdigest()[:6]
-        datetime_str = utils.get_datetime_str()
-        path = Path(args.output_path)
-        path = path.expanduser().resolve().joinpath(f"{args.model}_{datetime_str}_{hash_output}_{args.log_samples_suffix}")
-        # check if file or 'dir/results.json' exists
-        if path.is_file() or path.joinpath("results.json").is_file():
-            eval_logger.warning(f"File already exists at {path}. Results will be overwritten.")
-            output_path_file = path.joinpath("results.json")
-            assert not path.is_file(), "File already exists"
-        # if path json then get parent dir
-        elif path.suffix in (".json", ".jsonl"):
-            output_path_file = path
-        else:
-            output_path_file = path.joinpath("results.json")
-    elif args.log_samples and not args.output_path:
-        assert args.output_path, "Specify --output_path"
-
     eval_logger.info(f"Selected Tasks: {task_names}")

+    # set datetime before evaluation
+    datetime_str = utils.get_datetime_str(timezone=args.timezone)
+
     results = evaluator.simple_evaluate(
         model=args.model,
         model_args=args.model_args,
@@ -241,8 +230,22 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:
         gen_kwargs=args.gen_kwargs,
     )

-    if results is not None:
+    if args.output_path:
+        hash_input = f"{args.model_args}".encode("utf-8")
+        hash_output = hashlib.sha256(hash_input).hexdigest()[:6]
+        path = Path(args.output_path)
+        path = path.expanduser().resolve().joinpath(f"{args.model}").joinpath(f"model_args_{hash_output}").joinpath(f"{datetime_str}")
         path.mkdir(parents=True, exist_ok=True)
+        assert path.is_dir(), f"Output path {path} is not a directory"
+
+        output_path_file = path.joinpath("results.json")
+        if output_path_file.exists():
+            eval_logger.warning(f"Output file {output_path_file} already exists and will be overwritten.")
+
+    elif args.log_samples and not args.output_path:
+        assert args.output_path, "Specify --output_path"
+
+    if results is not None:
         if args.log_samples:
             samples = results.pop("samples")
         dumped = json.dumps(results, indent=4, default=_handle_non_serializable)
@@ -254,7 +257,7 @@ def cli_evaluate_single(args: Union[argparse.Namespace, None] = None) -> None:

         if args.log_samples:
             for task_name, config in results["configs"].items():
-                output_name = f"{args.model}_{task_name}_{args.log_samples_suffix}"
+                output_name = f"{task_name}_{args.log_samples_suffix}"
                 filename = path.joinpath(f"{output_name}.json")
                 # Structure the data with 'args' and 'logs' keys
                 data_to_dump = {"args": vars(args), "config": config, "logs": sorted(samples[task_name], key=lambda x: x["doc_id"])}  # Convert Namespace to dict
@@ -293,24 +296,25 @@ def print_results(args, results):
    # initialize Accelerator
    accelerator = Accelerator()
    all_args_dict = vars(args)
+    wandb_run = None

    if accelerator.is_main_process:
        # initialize a W&B run only on rank 0
        wandb_args_dict = utils.simple_parse_args_string(args.wandb_args)
-        if "name" not in wandb_args_dict:
-            if "config" not in all_args_dict:
-                # use the model name and task names as run name
-                task_names = args.tasks.replace(",", "_")
-                wandb_args_dict["name"] = f"{args.model}_{task_names}_{args.log_samples_suffix}"
-                if args.num_fewshot:
-                    wandb_args_dict["name"] += f"_{args.num_fewshot}shot"
-            else:
-                # use the name of the config file as run name
-                wandb_args_dict["name"] = all_args_dict["config"].split("/")[-1].split(".")[0]
-        wandb_run = wandb.init(**wandb_args_dict)
+        if wandb_args_dict:
+            if "name" not in wandb_args_dict:
+                if "config" not in all_args_dict:
+                    # use the model name and task names as run name
+                    task_names = args.tasks.replace(",", "_")
+                    wandb_args_dict["name"] = f"{args.model}_{task_names}_{args.log_samples_suffix}"
+                    if args.num_fewshot:
+                        wandb_args_dict["name"] += f"_{args.num_fewshot}shot"
+                else:
+                    # use the name of the config file as run name
+                    wandb_args_dict["name"] = all_args_dict["config"].split("/")[-1].split(".")[0]
+            wandb_run = wandb.init(**wandb_args_dict)
        is_main_process = True
    else:
-        wandb_run = None
        is_main_process = False

    # run each config
@@ -319,5 +323,5 @@ def print_results(args, results):
        results = cli_evaluate(args, wandb_run)
        results_list.append(results)

-    if is_main_process:
+    if is_main_process and wandb_run is not None:
        wandb_run.finish()
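Net effect of the __main__.py changes above: the output directory is now built after evaluation from the model name, a short hash of --model_args (no longer of the tasks), and the timezone-aware timestamp, and an existing results.json only triggers a warning instead of an assertion. A small illustration of the resulting layout; the values below are made-up examples, not taken from the repo:

# Example values only; hash_output is the first 6 hex chars of sha256(model_args),
# datetime_str is whatever utils.get_datetime_str(timezone=...) returns.
import hashlib
from pathlib import Path

output_path = "./logs"                   # example --output_path
model = "fuyu"                           # example --model
model_args = "pretrained=adept/fuyu-8b"  # example --model_args
datetime_str = "0130_1939"               # example timestamp

hash_output = hashlib.sha256(model_args.encode("utf-8")).hexdigest()[:6]
run_dir = Path(output_path).expanduser().resolve() / model / f"model_args_{hash_output}" / datetime_str
print(run_dir / "results.json")
# -> <cwd>/logs/fuyu/model_args_<6-hex-chars>/0130_1939/results.json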

lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 from .llava import Llava
 from .otterhd import OtterHD
 from .qwen_vl import Qwen_VL
+from .fuyu import Fuyu
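This import matters because lmms_eval resolves --model names through a decorator-based registry: importing lmms_eval.models.fuyu is what runs the @register_model("fuyu") decorator in the file below. A minimal sketch of that pattern, assuming a plain dict registry (the real lmms_eval.api.registry may differ):

# Minimal registry sketch; assumption: the real registry is richer than a dict.
MODEL_REGISTRY = {}

def register_model(name):
    def decorate(cls):
        MODEL_REGISTRY[name] = cls  # record the class under its CLI name
        return cls
    return decorate

@register_model("fuyu")
class Fuyu:
    """Stand-in for lmms_eval.models.fuyu.Fuyu."""

print(MODEL_REGISTRY["fuyu"])  # available only after the module has been imported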

lmms_eval/models/fuyu.py

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+from transformers import FuyuForCausalLM, AutoTokenizer, FuyuImageProcessor, FuyuProcessor
+from lmms_eval.api.model import lmms
+from lmms_eval.api.registry import register_model
+import torch
+from PIL import Image
+from typing import List, Optional, Union, Tuple
+from lmms_eval import utils
+from lmms_eval.api.instance import Instance
+from tqdm import tqdm
+
+
+@register_model("fuyu")
+class Fuyu(lmms):
+    """
+    Fuyu Model
+    """
+
+    def __init__(
+        self,
+        pretrained: str = "adept/fuyu-8b",
+        device: Optional[str] = "cuda",
+        max_new_tokens: int = 256,
+        batch_size: Optional[Union[int, str]] = 1,
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        # Do not use kwargs for now
+        assert kwargs == {}, f"Unexpected kwargs: {kwargs}"
+
+        self.device = device if torch.cuda.is_available() else "cpu"
+        self.model = FuyuForCausalLM.from_pretrained(pretrained, torch_dtype=torch.bfloat16, device_map=self.device)
+        self.model.eval()
+        self.tokenizer = AutoTokenizer.from_pretrained(pretrained)
+        self.image_processor = FuyuImageProcessor()
+        self.processor = FuyuProcessor(image_processor=self.image_processor, tokenizer=self.tokenizer)
+        self.max_new_tokens = max_new_tokens
+        self.batch_size_per_gpu = int(batch_size)
+
+    @property
+    def max_length(self):
+        # Assuming max_length is the sum of max context tokens and max new tokens
+        return self.tokenizer.model_max_length
+
+    # @property
+    # def max_gen_toks(self) -> int:
+    #     return self.max_new_tokens
+
+    @property
+    def batch_size(self):
+        return self.batch_size_per_gpu
+
+    def flatten(self, input, only_get_first=False):
+        new_list = []
+        for i in input:
+            for j in i:
+                new_list.append(j)
+                if only_get_first:
+                    break
+        return new_list
+
+    def generate_until(self, requests: List[Instance]) -> List[str]:
+        res = []
+
+        def _collate(x):
+            # the negative sign on len(toks) sorts descending - this has a few advantages:
+            # - time estimates will always be over not underestimates, which is more useful for planning
+            # - to know the size of a batch when going through the list, you know the first one is always the batch
+            #   padded context length. this is useful to simplify the batching logic and more importantly to make
+            #   automatic adaptive batches much much easier to implement
+            # - any OOMs will happen right away rather than near the end
+            toks = self.tok_encode(x[0])
+            return -len(toks), x[0]
+
+        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
+        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
+        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")
+
+        for chunk in chunks:
+            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            visuals = self.flatten(visuals, only_get_first=True)
+            gen_kwargs = all_gen_kwargs[0]
+
+            # if isinstance(visuals[0], list):
+            #     visuals = [visuals[idx][0] for idx in range(len(visuals))]  # get the first image in multi-image scenarios.
+
+            # assert len(contexts) == self.batch_size_per_gpu, f"Expected contexts batch size {self.batch_size_per_gpu}, got {len(contexts)}"
+            # assert len(visuals) == self.batch_size_per_gpu, f"Expected visuals batch size {self.batch_size_per_gpu}, got {len(visuals)}"
+            formatted_contexts = [f"{context}\n" for context in contexts]
+            model_inputs = self.processor(text=formatted_contexts, images=visuals, device=self.device)
+            for k, v in model_inputs.items():
+                model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]
+
+            for index in range(len(model_inputs["image_patches"])):
+                model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype)
+
+            # preconfigure gen_kwargs with defaults
+            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 1024
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+            generation_output = self.model.generate(**model_inputs, max_new_tokens=gen_kwargs["max_new_tokens"], pad_token_id=self.tokenizer.eos_token_id)
+            generation_texts = self.processor.batch_decode(generation_output, skip_special_tokens=True)
+            response = [gen_text.split("\x04")[1].strip(" ").strip("\n") for gen_text in generation_texts]
+            res.extend(response)
+            pbar.update(1)
+
+        pbar.close()
+        return res
+
+    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
+        # TODO
+        assert False, "We have not implemented this function for llava yet"
+
+    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
+        # TODO
+        assert False, "We have not implemented this function for llava yet"
+
+    def tok_encode(self, string: str, left_truncate_len=None, add_special_tokens=None) -> List[int]:
+        """ """
+        add_special_tokens = False if add_special_tokens is None else add_special_tokens
+        encoding = self.tokenizer.encode(string, add_special_tokens=add_special_tokens)
+        # left-truncate the encoded context to be at most `left_truncate_len` tokens long
+        if left_truncate_len:
+            encoding = encoding[-left_truncate_len:]
+        return encoding
+
+    def tok_decode(self, tokens):
+        return self.tokenizer.decode(tokens)
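The decoded Fuyu output contains the prompt, the "\x04" beginning-of-answer marker, and then the model's answer; generate_until keeps only the text after the marker. A toy illustration of that post-processing step (the decoded string below is made up, not real model output):

# Made-up decoded string, mimicking what processor.batch_decode returns for Fuyu.
decoded = "What is shown in the image?\n\x04 A red stop sign."
answer = decoded.split("\x04")[1].strip(" ").strip("\n")
print(answer)  # -> "A red stop sign."

With the registry import added above, this wrapper should then be selectable as --model fuyu from the CLI.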

lmms_eval/models/otterhd.py

Lines changed: 27 additions & 12 deletions
@@ -52,11 +52,13 @@ def max_length(self):
     def batch_size(self):
         return self.batch_size_per_gpu

-    def flatten(self, input):
+    def flatten(self, input, only_get_first=False):
         new_list = []
         for i in input:
             for j in i:
                 new_list.append(j)
+                if only_get_first:
+                    break
         return new_list

     def generate_until(self, requests: List[Instance]) -> List[str]:
@@ -72,31 +74,44 @@ def _collate(x):
             toks = self.tok_encode(x[0])
             return -len(toks), x[0]

-        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
         re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
         chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
+        num_iters = len(requests) // self.batch_size if len(requests) % self.batch_size == 0 else len(requests) // self.batch_size + 1
+        pbar = tqdm(total=num_iters, disable=(self.rank != 0), desc="Model Responding")

         for chunk in chunks:
-            contexts, all_gen_kwargs, visuals = zip(*chunk)
-            visuals = self.flatten(visuals)
+            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
+            task = task[0]
+            split = split[0]
+            visuals = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
+            visuals = self.flatten(visuals, only_get_first=True)
             gen_kwargs = all_gen_kwargs[0]

-            if isinstance(visuals, list):
-                visuals = [visuals[0]]
+            # if isinstance(visuals[0], list):
+            #     visuals = [visuals[idx][0] for idx in range(len(visuals))]  # get the first image in multi-image scenarios.

-            formatted_contexts = f"User: {contexts[0]} Assistant:"
+            formatted_contexts = [f"User: {context} Assistant:" for context in contexts]
             model_inputs = self.processor(text=[formatted_contexts], images=visuals, device=self.device)
             for k, v in model_inputs.items():
                 model_inputs[k] = v.to(self.device, non_blocking=True) if isinstance(v, torch.Tensor) else [vv.to(self.device, non_blocking=True) for vv in v]

             for index in range(len(model_inputs["image_patches"])):
                 model_inputs["image_patches"][index] = model_inputs["image_patches"][index].to(dtype=next(self.model.parameters()).dtype)

-            max_new_tokens = gen_kwargs.get("max_new_tokens", self.max_new_tokens)
-            generation_output = self.model.generate(**model_inputs, max_new_tokens=max_new_tokens, pad_token_id=self.tokenizer.eos_token_id)
-            generation_text = self.processor.batch_decode(generation_output, skip_special_tokens=True)
-            response = generation_text[0].split("\x04")[1].strip(" ").strip("\n")
-            res.append(response)
+            # preconfigure gen_kwargs with defaults
+            gen_kwargs["image_sizes"] = [visuals[idx].size for idx in range(len(visuals))]
+            if "max_new_tokens" not in gen_kwargs:
+                gen_kwargs["max_new_tokens"] = 1024
+            if "temperature" not in gen_kwargs:
+                gen_kwargs["temperature"] = 0
+            if "top_p" not in gen_kwargs:
+                gen_kwargs["top_p"] = None
+            if "num_beams" not in gen_kwargs:
+                gen_kwargs["num_beams"] = 1
+            generation_output = self.model.generate(**model_inputs, max_new_tokens=gen_kwargs["max_new_tokens"], pad_token_id=self.tokenizer.eos_token_id)
+            generation_texts = self.processor.batch_decode(generation_output, skip_special_tokens=True)
+            response = [gen_text.split("\x04")[1].strip(" ").strip("\n") for gen_text in generation_texts]
+            res.extend(response)
             pbar.update(1)

         pbar.close()
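Both Fuyu and OtterHD now share this flatten(..., only_get_first=True) behavior: with the break inside the inner loop, each document contributes only its first image, matching the commented-out "get the first image in multi-image scenarios" line. A toy example with placeholder strings standing in for PIL images:

# Placeholder strings stand in for PIL images returned by doc_to_visual.
def flatten(nested, only_get_first=False):
    new_list = []
    for inner in nested:
        for item in inner:
            new_list.append(item)
            if only_get_first:
                break
    return new_list

visuals = [["doc1_img_a", "doc1_img_b"], ["doc2_img_a"]]
print(flatten(visuals))                       # ['doc1_img_a', 'doc1_img_b', 'doc2_img_a']
print(flatten(visuals, only_get_first=True))  # ['doc1_img_a', 'doc2_img_a']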
