Skip to content

Commit fdd39ce

Browse files
Jinhe Bi
authored and committed
update text only tinyllava
1 parent 720d28e commit fdd39ce

File tree

3 files changed

+40
-18
lines changed

3 files changed

+40
-18
lines changed

lmms_eval/models/llava.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -278,16 +278,15 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
278278
return res
279279

280280
def flatten(self, input):
281-
if not input or any(i is None for i in input):
282-
return []
281+
if not input or any(i is None for i in input):
282+
return []
283283
new_list = []
284284
for i in input:
285-
if i:
285+
if i:
286286
for j in i:
287287
new_list.append(j)
288288
return new_list
289289

290-
291290
def generate_until(self, requests: List[Instance]) -> List[str]:
292291
res = []
293292

lmms_eval/models/llava_onevision.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -361,13 +361,12 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
361361
pbar.close()
362362
return res
363363

364-
365364
def flatten(self, input):
366-
if not input or any(i is None for i in input):
367-
return []
365+
if not input or any(i is None for i in input):
366+
return []
368367
new_list = []
369368
for i in input:
370-
if i:
369+
if i:
371370
for j in i:
372371
new_list.append(j)
373372
return new_list

lmms_eval/models/tinyllava.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,17 +185,16 @@ def tok_decode(self, tokens):
185185
except:
186186
return self.tokenizer.decode([tokens])
187187

188-
189188
def flatten(self, input):
190-
if not input or any(i is None for i in input):
191-
return []
189+
if not input or any(i is None for i in input):
190+
return []
192191
new_list = []
193192
for i in input:
194-
if i:
193+
if i:
195194
for j in i:
196195
new_list.append(j)
197196
return new_list
198-
197+
199198
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
200199
# TODO
201200
res = []
@@ -238,17 +237,42 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
238237

239238
msg = Message()
240239
msg.add_message(prompts_input)
240+
241+
# Process text input and get input_ids
241242
contxt_id = self._text_processor(msg.messages, mode="eval")["input_ids"]
242-
# Add the answer of the second role
243+
244+
# Set the continuation as the second role's response
243245
msg._messages[1]["value"] = continuation
244246
input_ids = self._text_processor(msg.messages, mode="eval")["input_ids"]
245247

248+
# Prepare labels and ensure the correct shape
246249
labels = input_ids.clone()
247-
# Context part no need to calculate for loss
248-
labels[0, : contxt_id.shape[1]] = -100
250+
if labels.dim() == 1:
251+
labels = labels.unsqueeze(0) # Convert to (1, seq_len) if needed
252+
253+
if len(contxt_id.shape) == 1:
254+
contxt_id = contxt_id.unsqueeze(0) # Convert to (1, context_len)
255+
256+
# Mask the context part to ignore it in loss computation
257+
labels[:, : contxt_id.shape[1]] = -100
258+
259+
# Move tensors to the correct device
260+
device = self.device
261+
input_ids = input_ids.to(device)
262+
labels = labels.to(device)
263+
264+
if len(input_ids.shape) == 1:
265+
input_ids = input_ids.unsqueeze(0) # Ensure it is (batch_size, seq_len)
266+
267+
# Handle image input if available
268+
if image is None:
269+
image_sizes = []
270+
with torch.inference_mode():
271+
outputs = self.model(input_ids=input_ids, labels=labels, use_cache=True)
272+
else:
273+
with torch.inference_mode():
274+
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
249275

250-
with torch.inference_mode():
251-
outputs = self.model(input_ids=input_ids, labels=labels, images=image, use_cache=True, image_sizes=image_sizes)
252276
loss = outputs["loss"]
253277
# loss = torch.exp(loss)
254278
logits = outputs["logits"]

0 commit comments

Comments
 (0)