added HLE init + option to pass format to open ai models
clefourrier committed Jan 27, 2025
1 parent cb075a5 commit b7e2e80
Showing 3 changed files with 241 additions and 9 deletions.
34 changes: 25 additions & 9 deletions src/lighteval/metrics/llm_as_judge.py
@@ -26,6 +26,7 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Literal

from pydantic import BaseModel
from tqdm import tqdm

from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
@@ -76,6 +77,7 @@ def __init__(
        judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
        url: str | None = None,
        api_key: str | None = None,
        response_format: BaseModel | None = None,
    ):
        self.model = model
        self.template = templates
@@ -91,6 +93,8 @@ def __init__(
        self.api_key = api_key
        self.backend = judge_backend

        self.response_format = response_format

    def __lazy_load_client(self):
        match self.backend:
            # Whether we use openai or TGI models, we go through the openai API
Expand Down Expand Up @@ -244,15 +248,27 @@ def __call_api_parallel(self, prompts):
    def __call_api(self, prompt):
        for _ in range(self.API_MAX_RETRY):
            try:
-               response = self.client.chat.completions.create(
-                   model=self.model,
-                   messages=prompt,
-                   response_format={"type": "text"},
-                   max_tokens=512,
-                   n=1,
-               )
-               text = response.choices[0].message.content
-               return text
+               if self.response_format:
+                   # Structured-output path: the beta parse endpoint validates the
+                   # reply against the provided Pydantic schema
+                   response = self.client.beta.chat.completions.parse(
+                       model=self.model,
+                       messages=prompt,
+                       response_format=self.response_format,
+                       max_tokens=4096,
+                       temperature=0.0,
+                       n=1,
+                   )
+                   answer = response.choices[0].message.parsed
+                   return answer
+               else:
+                   response = self.client.chat.completions.create(
+                       model=self.model,
+                       messages=prompt,
+                       response_format={"type": "text"},
+                       max_tokens=512,
+                       n=1,
+                   )
+                   text = response.choices[0].message.content
+                   return text
            except Exception as e:
                logger.warning(f"{type(e), e}")
                time.sleep(self.API_RETRY_SLEEP)
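For context, a minimal sketch of the structured-output call the new branch performs. The `Verdict` schema and the prompt are illustrative, not part of the commit, and an OPENAI_API_KEY is assumed to be configured:

from openai import OpenAI
from pydantic import BaseModel


class Verdict(BaseModel):  # illustrative schema, not from this commit
    correct: bool
    reasoning: str


client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
response = client.beta.chat.completions.parse(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "Is 17 prime? Reply as a verdict."}],
    response_format=Verdict,
    max_tokens=512,
    temperature=0.0,
)
verdict = response.choices[0].message.parsed  # a validated Verdict instance, not raw text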
3 changes: 3 additions & 0 deletions src/lighteval/metrics/metrics_sample.py
@@ -35,6 +35,7 @@
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordTokenizer
from nltk.translate.bleu_score import sentence_bleu
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from lighteval.metrics.imports.bert_scorer import BERTScorer
@@ -861,6 +862,7 @@ def __init__(
        process_judge_response: Callable,
        judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
        short_judge_name: str | None = None,
        response_format: BaseModel | None = None,
    ) -> None:
        match judge_backend:
            case "openai":
Expand Down Expand Up @@ -893,6 +895,7 @@ def __init__(
            api_key=api_key,
            url=url,
            judge_backend=judge_backend,
            response_format=response_format,
        )

    def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
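A sketch of how a task might wire a judge through the new parameter. `Verdict` and `judge_template` are hypothetical names for illustration; the keyword names follow the signature above:

from pydantic import BaseModel

from lighteval.metrics.metrics_sample import JudgeLLM


class Verdict(BaseModel):  # hypothetical response schema
    correct: bool
    reasoning: str


def judge_template(question, answer, gold, **kwargs):
    # builds the chat messages sent to the judge model
    return [{"role": "user", "content": f"Question: {question}\nAnswer: {answer}\nGold: {gold}"}]


judge = JudgeLLM(
    judge_model_name="gpt-4o-2024-08-06",
    template=judge_template,
    process_judge_response=lambda r: {"correct": r.correct},
    judge_backend="openai",
    response_format=Verdict,  # omit to keep the plain-text path
)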
213 changes: 213 additions & 0 deletions src/lighteval/tasks/extended/hle/main.py
@@ -0,0 +1,213 @@
# MIT License

# Copyright (c) 2024 The HuggingFace Team

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.


import math
from typing import List, Literal

import numpy as np
from aenum import extend_enum
from pydantic import BaseModel

from lighteval.metrics.metrics import Metrics
from lighteval.metrics.metrics_sample import JudgeLLM
from lighteval.metrics.utils.metric_utils import (
    MetricCategory,
    MetricUseCase,
    SampleLevelMetricGrouping,
)
from lighteval.tasks.lighteval_task import LightevalTaskConfig
from lighteval.tasks.requests import Doc


class ExtractedAnswer(BaseModel):
    extracted_final_answer: str
    reasoning: str
    correct: Literal["yes", "no"]
    confidence: int
    strict: Literal[True]  # 100% reliability
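# This schema is passed to the judge as response_format (see JudgeLLMHLE
# below), so the OpenAI client parses the judgement directly into typed fields.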


"""Adaptation from https://github.com/centerforaisafety/hle/blob/main/hle_eval/run_judge_results.py
"""


def get_judge_prompt(question: str, answer: str, gold: str, **kwargs):
    return [
        {
            "role": "user",
            "content": rf"""Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.
[question]: {question}
[response]: {answer}
Your judgement must be in the format and criteria specified below:
extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response.
[correct_answer]: {gold}
reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match.
correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect.
confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available.""",
        },
    ]


def process_judge_response_hle(response: ExtractedAnswer):
    return {
        # "correct_answer": correct_answer,
        "model_answer": response.extracted_final_answer,
        "reasoning": response.reasoning,
        "correct": response.correct,
        "confidence": response.confidence,
    }


class JudgeLLMHLE(JudgeLLM):
    def __init__(self):
        super().__init__(
            judge_model_name="gpt-4o-2024-08-06",
            template=get_judge_prompt,
            process_judge_response=process_judge_response_hle,
            judge_backend="openai",
            short_judge_name="hle_judge",
            response_format=ExtractedAnswer,
        )

    def compute(self, predictions, formatted_doc: Doc, **kwargs):
        # If we are evaluating a multiturn task, we need specific fields in the formatted doc
        gold = formatted_doc.get_golds()[0]

        score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold)

        score["correct_answer"] = gold
        return score

    def compute_corpus(self, scores: List[dict]):
        n = len(scores)

        correct = []
        confidence = []
        for score in scores:
            correct.append("yes" in score["correct"])
            confidence.append(score["confidence"])

        correct = np.array(correct)
        confidence = np.array(confidence)

        # sometimes the model collapses on the same questions, leaving fewer predictions than documents
        if len(correct) != n:
            print(f"Available predictions: {len(correct)} | Total questions: {n}")

        accuracy = round(100 * sum(correct) / n, 2)
        # Wald estimator, 95% confidence interval
        confidence_half_width = round(1.96 * math.sqrt(accuracy * (100 - accuracy) / n), 2)
        # confidence is reported on a 0-100 scale; normalize it so it is
        # comparable with the 0/1 correctness indicator, then report in percent
        calibration_error = round(100 * calib_err(confidence / 100.0, correct, p="2", beta=100), 2)

        return {
            "accuracy": accuracy,
            "confidence_half_width": confidence_half_width,
            "calibration_error": calibration_error,
        }


# source: https://github.com/hendrycks/outlier-exposure/blob/master/utils/calibration_tools.py
def calib_err(confidence, correct, p="2", beta=100):
    # beta is the target bin size: predictions are sorted by confidence, split
    # into bins of ~beta examples, and the per-bin gaps between mean confidence
    # and mean accuracy are aggregated under an L1, L2 (RMS), or max norm
    idxs = np.argsort(confidence)
    confidence = confidence[idxs]
    correct = correct[idxs]
    bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)]
    # fold the remainder into the last bin
    bins[-1] = [bins[-1][0], len(confidence)]

    cerr = 0
    total_examples = len(confidence)
    for i in range(len(bins) - 1):
        bin_confidence = confidence[bins[i][0] : bins[i][1]]
        bin_correct = correct[bins[i][0] : bins[i][1]]
        num_examples_in_bin = len(bin_confidence)

        if num_examples_in_bin > 0:
            difference = np.abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct))

            if p == "2":
                cerr += num_examples_in_bin / total_examples * np.square(difference)
            elif p == "1":
                cerr += num_examples_in_bin / total_examples * difference
            elif p == "infty" or p == "infinity" or p == "max":
                cerr = np.maximum(cerr, difference)
            else:
                assert False, "p must be '1', '2', or 'infty'"

    if p == "2":
        cerr = np.sqrt(cerr)

    return cerr


def hle_text_only(line, task_name: str = None):
    if line["image"] is not None:
        return

    return Doc(
        task_name=task_name,
        query=f"Question: {line['question']}\nAnswer:",
        choices=[line["answer"]],
        gold_index=0,
    )


hle_metrics = SampleLevelMetricGrouping(
    metric_name=["accuracy", "confidence_half_width", "calibration_error"],
    # a wider confidence interval and a larger calibration error are worse
    higher_is_better={"accuracy": True, "confidence_half_width": False, "calibration_error": False},
    category=MetricCategory.GENERATIVE,
    use_case=MetricUseCase.ACCURACY,
    sample_level_fn=JudgeLLMHLE().compute,
    corpus_level_fn=JudgeLLMHLE().compute_corpus,
)
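# sample_level_fn runs the judge once per prediction; corpus_level_fn then
# aggregates the per-sample score dicts into the three corpus numbers above.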

hle = LightevalTaskConfig(
    name="hle",
    suite=["lighteval"],
    prompt_function=hle_text_only,
    hf_repo="cais/hle",
    hf_subset="main",
    hf_avail_splits=["test"],
    evaluation_splits=["test"],
    few_shots_split=None,
    few_shots_select=None,
    generation_size=1,  # TODO
    metric=[Metrics.exact_match],
    stop_sequence=["\n"],
    trust_dataset=True,
    version=0,
)


TASKS_TABLE = [hle]

extend_enum(Metrics, "hle_metrics", hle_metrics)
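
As a sanity check on the calibration metric, a small hand-workable example; the numbers are made up, and it mirrors the normalization done in compute_corpus (assumes calib_err from the module above is in scope):

import numpy as np

# ten toy judgements: reported confidence (percent) and 0/1 correctness
confidence = np.array([60, 65, 70, 75, 80, 85, 90, 92, 95, 99], dtype=float)
correct = np.array([0, 0, 0, 1, 1, 0, 1, 1, 1, 1], dtype=float)

# beta=5 gives two bins of five; only the first contributes, because the
# loop in calib_err stops before the merged last bin
err = 100 * calib_err(confidence / 100.0, correct, p="2", beta=5)
print(round(err, 1))  # 21.2: the first bin averages 70% confidence vs 40% accuracy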
