metrics working
clefourrier committed Jan 29, 2025
1 parent db4c4e8 commit 59ce9e1
Showing 2 changed files with 45 additions and 21 deletions.
38 changes: 24 additions & 14 deletions src/lighteval/metrics/llm_as_judge.py
@@ -30,13 +30,17 @@
from tqdm import tqdm

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
from lighteval.utils.utils import as_list


logging.getLogger("openai").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)


DEFAULT_FORMAT = {"type": "text"}


class JudgeLM:
"""
A class representing a judge for evaluating answers using either the OpenAI or Transformers library.
@@ -93,7 +97,7 @@ def __init__(
self.api_key = api_key
self.backend = judge_backend

self.response_format = response_format
self.response_format = response_format if response_format is not None else DEFAULT_FORMAT

def __lazy_load_client(self):
match self.backend:
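
For context, a minimal standalone sketch of the default-format fallback introduced here (an illustration, not lighteval's code; DEFAULT_FORMAT is the constant added in the first hunk):

```python
# Hypothetical helper illustrating the intended fallback: plain-text output
# is requested whenever no Pydantic schema is supplied.
DEFAULT_FORMAT = {"type": "text"}


def resolve_response_format(response_format=None):
    return response_format if response_format is not None else DEFAULT_FORMAT


assert resolve_response_format() == {"type": "text"}
assert resolve_response_format({"type": "json_object"}) == {"type": "json_object"}
```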
@@ -248,28 +252,34 @@ def __call_api_parallel(self, prompts):
def __call_api(self, prompt):
for _ in range(self.API_MAX_RETRY):
try:
if self.response_format:
# Base model
response = self.client.beta.chat.completions.parse(
model=self.model,
messages=as_list(prompt),
response_format=self.response_format,
max_tokens=4096,
temperature=0.0,
n=1,
)
answer = response.choices[0].message.parsed
return answer
except TypeError:
try:
# Finetune
response = self.client.chat.completions.create(
model=self.model,
messages=prompt,
messages=as_list(prompt),
response_format=self.response_format,
max_tokens=4096,
temperature=0.0,
n=1,
)
answer = response.choices[0].message.parsed
return answer
else:
response = self.client.chat.completions.create(
model=self.model,
messages=prompt,
response_format={"type": "text"},
max_tokens=512,
n=1,
)
text = response.choices[0].message.content
return text
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)
except Exception as e:
logger.warning(f"{type(e), e}")
time.sleep(self.API_RETRY_SLEEP)

raise Exception("Failed to get response from the API")
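
To make the new control flow easier to follow, here is a condensed, self-contained sketch of the pattern the rewritten __call_api aims at: try the structured-output endpoint first, then fall back to a plain completion when the backend rejects the Pydantic response_format. This is an illustration written against the OpenAI Python SDK, not the file's exact code; Verdict and the model name are placeholders.

```python
from openai import OpenAI
from pydantic import BaseModel


class Verdict(BaseModel):
    # Placeholder schema standing in for the judge's response_format.
    extracted_final_answer: str
    reasoning: str


client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def call_judge(messages: list[dict], model: str = "gpt-4o-mini"):
    try:
        # Structured-output path (the "Base model" branch in the diff).
        response = client.beta.chat.completions.parse(
            model=model,
            messages=messages,
            response_format=Verdict,
            max_tokens=4096,
            temperature=0.0,
        )
        return response.choices[0].message.parsed  # a Verdict instance
    except TypeError:
        # Fallback (the "Finetune" branch): plain completion, raw text back.
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            response_format={"type": "text"},
            max_tokens=4096,
            temperature=0.0,
        )
        return response.choices[0].message.content
```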
28 changes: 21 additions & 7 deletions src/lighteval/tasks/extended/hle/main.py
@@ -21,6 +21,7 @@
# SOFTWARE.


import logging
import math
from typing import List, Literal

@@ -31,14 +32,17 @@
from lighteval.metrics.metrics import Metrics
from lighteval.metrics.metrics_sample import JudgeLLM
from lighteval.metrics.utils.metric_utils import (
CorpusLevelMetricGrouping,
MetricCategory,
MetricUseCase,
SampleLevelMetricGrouping,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


logger = logging.getLogger(__name__)


class ExtractedAnswer(BaseModel):
extracted_final_answer: str
reasoning: str
@@ -77,7 +81,10 @@ def get_judge_prompt(question: str, answer: str, gold: str, **kwargs):
]


def process_judge_response_hle(response: ExtractedAnswer):
def process_judge_response_hle(response: ExtractedAnswer | List[ExtractedAnswer]):
# todo: add support for batched responses
if isinstance(response, list):
response = response[0]
return {
# "correct_answer": correct_answer,
"model_answer": response.extracted_final_answer,
@@ -105,7 +112,11 @@ def compute(self, predictions, formatted_doc: Doc, **kwargs):
score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold)

score["correct_answer"] = gold
return score
return {
"accuracy": score,
"confidence_half_width": score,
"calibration_error": score,
}

def compute_corpus(self, scores: List[dict]):
n = len(scores)
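
For orientation, a toy sketch of the fan-out/reduce shape implied by the change above: the sample-level compute now emits the judge score under every declared metric name, and the corpus-level function reduces the collected dicts in one pass. This is not lighteval's actual plumbing, and judge_correct is a hypothetical field name.

```python
METRIC_NAMES = ["accuracy", "confidence_half_width", "calibration_error"]


def fan_out(judge_score: dict) -> dict:
    # One entry per declared metric name, all pointing at the same judge score.
    return {name: judge_score for name in METRIC_NAMES}


def reduce_corpus(per_sample: list) -> dict:
    # Toy reduction over the hypothetical `judge_correct` field.
    n = len(per_sample)
    correct = sum(s["accuracy"]["judge_correct"] for s in per_sample)
    return {"accuracy": correct / n if n else 0.0}


print(reduce_corpus([fan_out({"judge_correct": 1}), fan_out({"judge_correct": 0})]))
# {'accuracy': 0.5}
```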
@@ -142,6 +153,10 @@ def calib_err(confidence, correct, p="2", beta=100):
confidence = confidence[idxs]
correct = correct[idxs]
bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)]
if len(bins) == 0:
logger.warning("Error when computing the bins for calibration error")
return -1

bins[-1] = [bins[-1][0], len(confidence)]

cerr = 0
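
For reference, a standalone numpy sketch of the binned (p="2") calibration error this guard protects; with fewer than beta samples no bin can be formed, which is the case that now returns early:

```python
import numpy as np


def calib_err_sketch(confidence: np.ndarray, correct: np.ndarray, beta: int = 100) -> float:
    # Sort by confidence and split into bins of `beta` samples each.
    order = np.argsort(confidence)
    confidence, correct = confidence[order], correct[order]
    bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)]
    if len(bins) == 0:
        # The newly guarded case: not enough samples for a single bin.
        return -1.0
    bins[-1] = [bins[-1][0], len(confidence)]  # last bin absorbs the remainder

    cerr = 0.0
    for lo, hi in bins:
        gap = abs(np.mean(confidence[lo:hi]) - np.mean(correct[lo:hi]))
        cerr += (hi - lo) / len(confidence) * gap**2  # the L2 ("2") variant
    return float(np.sqrt(cerr))


# Example: 200 samples, confidences in [0, 1], binary correctness.
rng = np.random.default_rng(0)
conf = rng.random(200)
corr = (rng.random(200) < conf).astype(float)
print(calib_err_sketch(conf, corr, beta=100))
```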
@@ -181,14 +196,15 @@ def hle_text_only(line, task_name: str = None):
)


hle_metrics = SampleLevelMetricGrouping(
hle_metrics = CorpusLevelMetricGrouping(
metric_name=["accuracy", "confidence_half_width", "calibration_error"],
higher_is_better={n: True for n in ["accuracy", "confidence_half_width", "calibration_error"]},
category=MetricCategory.GENERATIVE,
use_case=MetricUseCase.ACCURACY,
sample_level_fn=JudgeLLMHLE().compute,
corpus_level_fn=JudgeLLMHLE().compute_corpus,
)
extend_enum(Metrics, "hle_metrics", hle_metrics)

hle = LightevalTaskConfig(
name="hle",
@@ -201,13 +217,11 @@ def hle_text_only(line, task_name: str = None):
few_shots_split=None,
few_shots_select=None,
generation_size=8192, # TODO
metric=[Metrics.exact_match, hle_metrics],
metric=[Metrics.exact_match, Metrics.hle_metrics],
stop_sequence=["\n"],
trust_dataset=True,
version=0,
)


TASKS_TABLE = [hle]

extend_enum(Metrics, "hle_metrics", hle_metrics)
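
The enum registration now has to run before the task config is built, since the config refers to Metrics.hle_metrics rather than the local hle_metrics object. A toy illustration of that ordering constraint, using a stand-in Enum and aenum's extend_enum (presumably what this file imports):

```python
from enum import Enum

from aenum import extend_enum


class Metrics(Enum):  # toy stand-in for lighteval's Metrics enum
    exact_match = "exact_match"


# Referencing Metrics.hle_metrics before this call would raise AttributeError.
extend_enum(Metrics, "hle_metrics", "hle_metrics")

metrics = [Metrics.exact_match, Metrics.hle_metrics]  # now resolves
print(metrics)
```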
