Skip to content

Commit 01803d3

Browse files
committed
Added a `failed` flag to `ModelResponse`.
1 parent a00af73 commit 01803d3

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

src/lighteval/metrics/llm_as_judge.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,18 +206,20 @@ def __call_api(prompt):
206206
}
207207
response = litellm.completion(**kwargs)
208208
text = response.choices[0].message.content
209-
if text is None:
209+
if not text or response.failed:
210210
kwargs["caching"] = False
211211
response = litellm.completion(**kwargs)
212212
text = response.choices[0].message.content
213-
if text is None:
213+
if not text or response.failed:
214214
# Just return an error response if the second attempt fails too
215-
return ModelResponse(text="Failed to get response from the API.", model=self.model)
215+
return ModelResponse(
216+
text="Failed to get response from the API.", model=self.model, failed=True
217+
)
216218
return text
217219
except Exception as e:
218220
logger.warning(f"{type(e), e}")
219221
time.sleep(self.API_RETRY_SLEEP)
220-
return ModelResponse(text="Failed to get response from the API.", model=self.model)
222+
return ModelResponse(text="Failed to get response from the API.", model=self.model, failed=True)
221223

222224
results = []
223225
with ThreadPoolExecutor(100) as executor:

src/lighteval/models/model_output.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ModelResponse:
3333
generated_tokens: list[int] = field(default_factory=list) # model generations
3434
truncated_tokens_count: Optional[int] = 0 # How many tokens truncated
3535
padded_tokens_count: Optional[int] = 0 # How many tokens of padding
36+
failed: bool = False
3637

3738
def get_result_for_eval(self):
3839
raise NotImplementedError()

0 commit comments

Comments (0)