Skip to content

Commit 7f36065

Browse files
committed
Made comet score more robust.
1 parent f6b50b4 commit 7f36065

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

community_tasks/swiss_legal_evals.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -589,14 +589,15 @@ def compute(
589589
predictions = [response[0].result for response in responses]
590590
sources = [kwargs["formatted_doc"].specific["source"] for kwargs["formatted_doc"] in formatted_docs]
591591

592-
data = [
593-
{
594-
"src": src,
595-
"mt": pred if isinstance(pred, str) else pred[0],
596-
"ref": gold,
597-
}
598-
for src, pred, gold in zip(sources, predictions, golds)
599-
]
592+
def unpack(x):
593+
if isinstance(x, str):
594+
return x
595+
elif isinstance(x, (list, tuple)):
596+
return unpack(x[0])
597+
else:
598+
raise ValueError(f"Unknown type {type(x)} of prediction {x}")
599+
600+
data = [{"src": src, "mt": unpack(pred), "ref": gold} for src, pred, gold in zip(sources, predictions, golds)]
600601
model_output = self.model.predict(
601602
data,
602603
batch_size=self.batch_size,

0 commit comments

Comments
 (0)