Skip to content

Commit cfc5c23

Browse files
authored
[Improvement] Support Best-of-N evaluation for InternVL2.5 (open-compass#854)
1 parent c656fd6 commit cfc5c23

File tree

2 files changed: 59 additions and 3 deletions (+59 / -3 lines changed)

2 files changed: 59 additions and 3 deletions (+59 / -3 lines changed)

vlmeval/config.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -648,6 +648,11 @@
648648
"InternVL2_5-78B": partial(
649649
InternVLChat, model_path="OpenGVLab/InternVL2_5-78B", version="V2.0"
650650
),
651+
# InternVL2.5 series with Best-of-N evaluation
652+
"InternVL2_5-8B-BoN-8": partial(
653+
InternVLChat, model_path="OpenGVLab/InternVL2_5-8B", version="V2.0",
654+
best_of_n=8, reward_model_path="OpenGVLab/VisualPRM-8B",
655+
),
651656
# InternVL2.5-MPO series
652657
"InternVL2_5-1B-MPO": partial(
653658
InternVLChat,

vlmeval/vlm/internvl/internvl_chat.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,12 @@ def __init__(self,
3535
load_in_8bit=False,
3636
use_mpo_prompt=False,
3737
version='V1.0',
38+
# Best-of-N parameters
39+
best_of_n=1,
40+
reward_model_path=None,
3841
**kwargs):
3942

43+
assert best_of_n >= 1
4044
assert model_path is not None
4145
assert version_cmp(transformers.__version__, '4.37.2', 'ge')
4246

@@ -78,8 +82,37 @@ def __init__(self,
7882
low_cpu_mem_usage=True).eval().cuda()
7983
self.device = 'cuda'
8084

85+
if best_of_n > 1:
86+
assert version == 'V2.0', 'only support BoN evaluation with version==V2.0'
87+
assert reward_model_path is not None
88+
89+
if auto_split_flag():
90+
rm_device_map, visible_devices = split_model(model_path=reward_model_path)
91+
rm_kwargs = {'device_map': rm_device_map}
92+
else:
93+
rm_kwargs = {}
94+
95+
self.reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_path, trust_remote_code=True, use_fast=False)
96+
self.reward_model = AutoModel.from_pretrained(
97+
reward_model_path,
98+
torch_dtype=torch.bfloat16,
99+
load_in_8bit=load_in_8bit,
100+
trust_remote_code=True,
101+
low_cpu_mem_usage=True, **rm_kwargs).eval()
102+
103+
if not auto_split_flag():
104+
self.reward_model = self.reward_model.to(self.device)
105+
106+
if not self.use_cot:
107+
os.environ['USE_COT'] = '1'
108+
self.use_cot = True
109+
print('[Warning] Since Best-of-N is enabled, USE_COT is forced to be set to 1.')
110+
111+
print(f'Enable Best-of-N evaluation with PRM: {reward_model_path}')
112+
81113
self.image_size = self.model.config.vision_config.image_size
82114
self.version = version
115+
self.best_of_n = best_of_n
83116
kwargs_default = dict(do_sample=False, max_new_tokens=4096, top_p=None)
84117
kwargs_default.update(kwargs)
85118
self.kwargs = kwargs_default
@@ -206,6 +239,7 @@ def generate_v1_5(self, message, dataset=None):
206239
verbose=True)
207240
return response
208241

242+
@torch.no_grad()
209243
def generate_v2(self, message, dataset=None):
210244

211245
use_mpo_prompt = self.use_mpo_prompt and (self.use_cot or dataset in ['MMStar', 'HallusionBench', 'OCRBench'])
@@ -237,15 +271,32 @@ def generate_v2(self, message, dataset=None):
237271
pixel_values = None
238272
num_patches_list = []
239273

240-
with torch.no_grad():
274+
response_list = []
275+
for idx in range(self.best_of_n):
276+
kwargs_default = self.kwargs.copy()
277+
kwargs_default['do_sample'] = idx > 0
278+
kwargs_default['temperature'] = 0.7
279+
kwargs_default['top_p'] = 0.95
280+
241281
response = self.model.chat(
242282
self.tokenizer,
243283
pixel_values=pixel_values,
244284
num_patches_list=num_patches_list,
245285
question=prompt,
246-
generation_config=self.kwargs,
247-
verbose=True
286+
generation_config=kwargs_default,
287+
verbose=idx == 0,
288+
)
289+
response_list.append(response)
290+
291+
if self.best_of_n > 1:
292+
response_list = self.reward_model.select_best_response(
293+
tokenizer=self.reward_tokenizer,
294+
question=prompt,
295+
response_list=response_list,
296+
pixel_values=pixel_values,
297+
num_patches_list=num_patches_list,
248298
)
299+
response = response_list[0]
249300

250301
if use_mpo_prompt:
251302
response = mpo_post_processing(response, dataset)

0 commit comments

Comments
 (0)