@@ -185,17 +185,16 @@ def tok_decode(self, tokens):
185
185
except :
186
186
return self .tokenizer .decode ([tokens ])
187
187
188
-
189
188
def flatten (self , input ):
190
- if not input or any (i is None for i in input ):
191
- return []
189
+ if not input or any (i is None for i in input ):
190
+ return []
192
191
new_list = []
193
192
for i in input :
194
- if i :
193
+ if i :
195
194
for j in i :
196
195
new_list .append (j )
197
196
return new_list
198
-
197
+
199
198
def loglikelihood (self , requests : List [Instance ]) -> List [Tuple [float , bool ]]:
200
199
# TODO
201
200
res = []
@@ -238,17 +237,42 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
238
237
239
238
msg = Message ()
240
239
msg .add_message (prompts_input )
240
+
241
+ # Process text input and get input_ids
241
242
contxt_id = self ._text_processor (msg .messages , mode = "eval" )["input_ids" ]
242
- # Add the answer of the second role
243
+
244
+ # Set the continuation as the second role's response
243
245
msg ._messages [1 ]["value" ] = continuation
244
246
input_ids = self ._text_processor (msg .messages , mode = "eval" )["input_ids" ]
245
247
248
+ # Prepare labels and ensure the correct shape
246
249
labels = input_ids .clone ()
247
- # Context part no need to calculate for loss
248
- labels [0 , : contxt_id .shape [1 ]] = - 100
250
+ if labels .dim () == 1 :
251
+ labels = labels .unsqueeze (0 ) # Convert to (1, seq_len) if needed
252
+
253
+ if len (contxt_id .shape ) == 1 :
254
+ contxt_id = contxt_id .unsqueeze (0 ) # Convert to (1, context_len)
255
+
256
+ # Mask the context part to ignore it in loss computation
257
+ labels [:, : contxt_id .shape [1 ]] = - 100
258
+
259
+ # Move tensors to the correct device
260
+ device = self .device
261
+ input_ids = input_ids .to (device )
262
+ labels = labels .to (device )
263
+
264
+ if len (input_ids .shape ) == 1 :
265
+ input_ids = input_ids .unsqueeze (0 ) # Ensure it is (batch_size, seq_len)
266
+
267
+ # Handle image input if available
268
+ if image is None :
269
+ image_sizes = []
270
+ with torch .inference_mode ():
271
+ outputs = self .model (input_ids = input_ids , labels = labels , use_cache = True )
272
+ else :
273
+ with torch .inference_mode ():
274
+ outputs = self .model (input_ids = input_ids , labels = labels , images = image , use_cache = True , image_sizes = image_sizes )
249
275
250
- with torch .inference_mode ():
251
- outputs = self .model (input_ids = input_ids , labels = labels , images = image , use_cache = True , image_sizes = image_sizes )
252
276
loss = outputs ["loss" ]
253
277
# loss = torch.exp(loss)
254
278
logits = outputs ["logits" ]
0 commit comments