From b7e2e808bee7de64fbb00551294ea51788819748 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9mentine?=
Date: Mon, 27 Jan 2025 14:58:24 +0100
Subject: [PATCH] added HLE init + option to pass format to OpenAI models

---
 src/lighteval/metrics/llm_as_judge.py    |  34 +++-
 src/lighteval/metrics/metrics_sample.py  |   3 +
 src/lighteval/tasks/extended/hle/main.py | 213 +++++++++++++++++++++++
 3 files changed, 241 insertions(+), 9 deletions(-)
 create mode 100644 src/lighteval/tasks/extended/hle/main.py

diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py
index 23beda76f..8350fb088 100644
--- a/src/lighteval/metrics/llm_as_judge.py
+++ b/src/lighteval/metrics/llm_as_judge.py
@@ -26,6 +26,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from typing import Callable, Literal
 
+from pydantic import BaseModel
 from tqdm import tqdm
 
 from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available
@@ -76,6 +77,7 @@ def __init__(
         judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"],
         url: str | None = None,
         api_key: str | None = None,
+        response_format: BaseModel = None,
     ):
         self.model = model
         self.template = templates
@@ -91,6 +93,8 @@ def __init__(
         self.api_key = api_key
         self.backend = judge_backend
 
+        self.response_format = response_format
+
     def __lazy_load_client(self):
         match self.backend:
             # Wether we use openai or TGI models, we go through the openai API
@@ -244,15 +248,27 @@ def __call_api_parallel(self, prompts):
     def __call_api(self, prompt):
         for _ in range(self.API_MAX_RETRY):
             try:
-                response = self.client.chat.completions.create(
-                    model=self.model,
-                    messages=prompt,
-                    response_format={"type": "text"},
-                    max_tokens=512,
-                    n=1,
-                )
-                text = response.choices[0].message.content
-                return text
+                if self.response_format:
+                    response = self.client.beta.chat.completions.parse(
+                        model=self.model,
+                        messages=prompt,
+                        response_format=self.response_format,
+                        max_tokens=4096,
+                        temperature=0.0,
+                        n=1,
+                    )
+                    answer = response.choices[0].message.parsed
+                    return answer
+                else:
+                    response = self.client.chat.completions.create(
+                        model=self.model,
+                        messages=prompt,
+                        response_format={"type": "text"},
+                        max_tokens=512,
+                        n=1,
+                    )
+                    text = response.choices[0].message.content
+                    return text
             except Exception as e:
                 logger.warning(f"{type(e), e}")
                 time.sleep(self.API_RETRY_SLEEP)
diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py
index 352c2b98e..cf680422c 100644
--- a/src/lighteval/metrics/metrics_sample.py
+++ b/src/lighteval/metrics/metrics_sample.py
@@ -35,6 +35,7 @@
 from nltk.tokenize import word_tokenize
 from nltk.tokenize.treebank import TreebankWordTokenizer
 from nltk.translate.bleu_score import sentence_bleu
+from pydantic import BaseModel
 from transformers import AutoModelForSequenceClassification, AutoTokenizer
 
 from lighteval.metrics.imports.bert_scorer import BERTScorer
@@ -861,6 +862,7 @@ def __init__(
         process_judge_response: Callable,
         judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"],
         short_judge_name: str | None = None,
+        response_format: BaseModel = None,
     ) -> None:
         match judge_backend:
             case "openai":
@@ -893,6 +895,7 @@ def __init__(
             api_key=api_key,
             url=url,
             judge_backend=judge_backend,
+            response_format=response_format,
         )
 
     def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]:
diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py
new file mode 100644
index 000000000..21744744d
--- /dev/null
+++ b/src/lighteval/tasks/extended/hle/main.py
@@ -0,0 +1,213 @@
+# MIT License
+
+# Copyright (c) 2024 The HuggingFace Team
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import math
+from typing import List, Literal
+
+import numpy as np
+from aenum import extend_enum
+from pydantic import BaseModel
+
+from lighteval.metrics.metrics import Metrics
+from lighteval.metrics.metrics_sample import JudgeLLM
+from lighteval.metrics.utils.metric_utils import (
+    MetricCategory,
+    MetricUseCase,
+    SampleLevelMetricGrouping,
+)
+from lighteval.tasks.lighteval_task import LightevalTaskConfig
+from lighteval.tasks.requests import Doc
+
+
+class ExtractedAnswer(BaseModel):
+    extracted_final_answer: str
+    reasoning: str
+    correct: Literal["yes", "no"]
+    confidence: int
+    strict: Literal[True]  # 100% reliability
+
+
+"""Adaptation from https://github.com/centerforaisafety/hle/blob/main/hle_eval/run_judge_results.py
+"""
+
+
+def get_judge_prompt(question: str, answer: str, gold: str, **kwargs):
+    return [
+        {
+            "role": "user",
+            "content": rf"""Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below.
+
+[question]: {question}
+
+[response]: {answer}
+
+Your judgement must be in the format and criteria specified below:
+
+extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response.
+
+[correct_answer]: {gold}
+
+reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match.
+
+correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect.
+
+
+confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. Put 100 if there is no confidence score available.""",
+        },
+    ]
+
+
+def process_judge_response_hle(response: ExtractedAnswer):
+    return {
+        # "correct_answer": correct_answer,
+        "model_answer": response.extracted_final_answer,
+        "reasoning": response.reasoning,
+        "correct": response.correct,
+        "confidence": response.confidence,
+    }
+
+
+class JudgeLLMHLE(JudgeLLM):
+    def __init__(self):
+        super().__init__(
+            judge_model_name="gpt-4o-2024-08-06",
+            template=get_judge_prompt,
+            process_judge_response=process_judge_response_hle,
+            judge_backend="openai",
+            short_judge_name="hle_judge",
+            response_format=ExtractedAnswer,
+        )
+
+    def compute(self, predictions, formatted_doc: Doc, **kwargs):
+        # If we are evaluating a multiturn task, we need to have a specific field in the formatted doc
+        gold = formatted_doc.get_golds()[0]
+
+        score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold)
+
+        score["correct_answer"] = gold
+        return score
+
+    def compute_corpus(self, scores: List[dict]):
+        n = len(scores)
+
+        correct = []
+        confidence = []
+        for score in scores:
+            correct.append("yes" in score["correct"])
+            confidence.append(score["confidence"])
+
+        correct = np.array(correct)
+        confidence = np.array(confidence) / 100
+
+        # sometimes the model collapses on the same questions
+        if len(correct) != n:
+            print(f"Available predictions: {len(correct)} | Total questions: {n}")
+
+        accuracy = round(100 * sum(correct) / n, 2)
+        # Wald estimator, 95% confidence interval
+        confidence_half_width = round(1.96 * math.sqrt(accuracy * (100 - accuracy) / n), 2)
+        calibration_error = round(100 * calib_err(confidence, correct, p="2", beta=100), 2)
+
+        return {
+            "accuracy": accuracy,
+            "confidence_half_width": confidence_half_width,
+            "calibration_error": calibration_error,
+        }
+
+
+# source: https://github.com/hendrycks/outlier-exposure/blob/master/utils/calibration_tools.py
+def calib_err(confidence, correct, p="2", beta=100):
+    # beta is target bin size
+    idxs = np.argsort(confidence)
+    confidence = confidence[idxs]
+    correct = correct[idxs]
+    bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)]
+    bins[-1] = [bins[-1][0], len(confidence)]
+
+    cerr = 0
+    total_examples = len(confidence)
+    for i in range(len(bins) - 1):
+        bin_confidence = confidence[bins[i][0] : bins[i][1]]
+        bin_correct = correct[bins[i][0] : bins[i][1]]
+        num_examples_in_bin = len(bin_confidence)
+
+        if num_examples_in_bin > 0:
+            difference = np.abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct))
+
+            if p == "2":
+                cerr += num_examples_in_bin / total_examples * np.square(difference)
+            elif p == "1":
+                cerr += num_examples_in_bin / total_examples * difference
+            elif p == "infty" or p == "infinity" or p == "max":
+                cerr = np.maximum(cerr, difference)
+            else:
+                assert False, "p must be '1', '2', or 'infty'"
+
+    if p == "2":
+        cerr = np.sqrt(cerr)
+
+    return cerr
+
+
+def hle_text_only(line, task_name: str = None):
+    if line["image"] is not None:
+        return
+
+    return Doc(
+        task_name=task_name,
+        query=f"Question: {line['question']}\nAnswer:",
+        choices=[line["answer"]],
+        gold_index=0,
+    )
+
+
+hle_metrics = SampleLevelMetricGrouping(
+    metric_name=["accuracy", "confidence_half_width", "calibration_error"],
+    higher_is_better={n: True for n in ["accuracy", "confidence_half_width", "calibration_error"]},
+    category=MetricCategory.GENERATIVE,
+    use_case=MetricUseCase.ACCURACY,
+    sample_level_fn=JudgeLLMHLE().compute,
+    corpus_level_fn=JudgeLLMHLE().compute_corpus,
+)
+
+hle = LightevalTaskConfig(
name="hle", + suite=["lighteval"], + prompt_function=hle_text_only, + hf_repo="cais/hle", + hf_subset="main", + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=1, # TODO + metric=[Metrics.exact_match], + stop_sequence=["\n"], + trust_dataset=True, + version=0, +) + + +TASKS_TABLE = [hle] + +extend_enum(Metrics, "hle_metrics", hle_metrics)