@@ -35,8 +35,12 @@ def __init__(self,
                 load_in_8bit=False,
                 use_mpo_prompt=False,
                 version='V1.0',
+                 # Best-of-N parameters
+                 best_of_n=1,
+                 reward_model_path=None,
                 **kwargs):
 
+        assert best_of_n >= 1
         assert model_path is not None
         assert version_cmp(transformers.__version__, '4.37.2', 'ge')
 
@@ -78,8 +82,37 @@ def __init__(self,
             low_cpu_mem_usage=True).eval().cuda()
         self.device = 'cuda'
 
+        if best_of_n > 1:
+            assert version == 'V2.0', 'only support BoN evaluation with version==V2.0'
+            assert reward_model_path is not None
+
+            if auto_split_flag():
+                rm_device_map, visible_devices = split_model(model_path=reward_model_path)
+                rm_kwargs = {'device_map': rm_device_map}
+            else:
+                rm_kwargs = {}
+
+            self.reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_path, trust_remote_code=True, use_fast=False)
+            self.reward_model = AutoModel.from_pretrained(
+                reward_model_path,
+                torch_dtype=torch.bfloat16,
+                load_in_8bit=load_in_8bit,
+                trust_remote_code=True,
+                low_cpu_mem_usage=True, **rm_kwargs).eval()
+
+            if not auto_split_flag():
+                self.reward_model = self.reward_model.to(self.device)
+
+            if not self.use_cot:
+                os.environ['USE_COT'] = '1'
+                self.use_cot = True
+                print('[Warning] Since Best-of-N is enabled, USE_COT is forced to be set to 1.')
+
+            print(f'Enable Best-of-N evaluation with PRM: {reward_model_path}')
+
         self.image_size = self.model.config.vision_config.image_size
         self.version = version
+        self.best_of_n = best_of_n
         kwargs_default = dict(do_sample=False, max_new_tokens=4096, top_p=None)
         kwargs_default.update(kwargs)
         self.kwargs = kwargs_default
@@ -206,6 +239,7 @@ def generate_v1_5(self, message, dataset=None):
             verbose=True)
         return response
 
+    @torch.no_grad()
     def generate_v2(self, message, dataset=None):
 
         use_mpo_prompt = self.use_mpo_prompt and (self.use_cot or dataset in ['MMStar', 'HallusionBench', 'OCRBench'])
@@ -237,15 +271,32 @@ def generate_v2(self, message, dataset=None):
         pixel_values = None
         num_patches_list = []
 
-        with torch.no_grad():
+        response_list = []
+        for idx in range(self.best_of_n):
+            kwargs_default = self.kwargs.copy()
+            kwargs_default['do_sample'] = idx > 0
+            kwargs_default['temperature'] = 0.7
+            kwargs_default['top_p'] = 0.95
+
             response = self.model.chat(
                 self.tokenizer,
                 pixel_values=pixel_values,
                 num_patches_list=num_patches_list,
                 question=prompt,
-                generation_config=self.kwargs,
-                verbose=True
+                generation_config=kwargs_default,
+                verbose=idx == 0,
+            )
+            response_list.append(response)
+
+        if self.best_of_n > 1:
+            response_list = self.reward_model.select_best_response(
+                tokenizer=self.reward_tokenizer,
+                question=prompt,
+                response_list=response_list,
+                pixel_values=pixel_values,
+                num_patches_list=num_patches_list,
             )
+            response = response_list[0]
 
         if use_mpo_prompt:
             response = mpo_post_processing(response, dataset)
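
Usage note: below is a minimal, hypothetical sketch of how the new Best-of-N options might be passed to the wrapper. The class name `InternVLChat`, its import path, and both checkpoint paths are assumptions not shown in this diff; only `best_of_n`, `reward_model_path`, the `version == 'V2.0'` requirement, and the sampling behavior described in the comments come from the change above.

```python
# Hypothetical usage sketch: the import path, class name, and checkpoint paths are
# assumptions; best_of_n / reward_model_path / version='V2.0' come from this PR.
from vlmeval.vlm import InternVLChat  # assumed import location

model = InternVLChat(
    model_path='OpenGVLab/InternVL2-8B',  # assumed policy checkpoint
    version='V2.0',                       # BoN is only supported with version V2.0
    best_of_n=8,                          # candidate 0 is greedy, candidates 1..7 are sampled
    reward_model_path='/path/to/prm',     # placeholder for the process reward model checkpoint
)
# Candidates 1..N-1 are sampled with temperature=0.7 / top_p=0.95, then the PRM's
# select_best_response() reorders the candidates and the first one is returned.
```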