From 08678c111fc0d17207661303790798cc00c8550a Mon Sep 17 00:00:00 2001
From: kcz358
Date: Mon, 27 Jan 2025 07:55:54 +0000
Subject: [PATCH] Add eval for cot

---
 lmms_eval/models/qwen2_vl.py                  |  13 +-
 lmms_eval/tasks/mathvista/mathvista.yaml      |   2 +-
 lmms_eval/tasks/mathvista/mathvista_evals.py  |   6 +-
 .../tasks/mathvista/mathvista_testmini.yaml   |   2 +-
 lmms_eval/tasks/mmmu/_default_template_yaml   |   2 +-
 lmms_eval/tasks/mmmu/mmmu_val.yaml            |   6 +-
 lmms_eval/tasks/mmmu/utils.py                 | 171 +++++++++++++++---
 7 files changed, 160 insertions(+), 42 deletions(-)

diff --git a/lmms_eval/models/qwen2_vl.py b/lmms_eval/models/qwen2_vl.py
index a7ac04f16..dfd1d8247 100755
--- a/lmms_eval/models/qwen2_vl.py
+++ b/lmms_eval/models/qwen2_vl.py
@@ -22,6 +22,13 @@
 except ImportError:
     eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")
 
+SYSTEM_PROMPT = (
+    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+    "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+    "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
+    "<think> reasoning process here </think> <answer> answer here </answer>"
+)
+
 
 @register_model("qwen2_vl")
 class Qwen2_VL(lmms):
@@ -198,7 +205,7 @@ def _collate(x):
             if "<image>" in context:
                 context = context.replace("<image>", "")
 
-            message = [{"role": "system", "content": "You are a helpful assistant."}]
+            message = [{"role": "system", "content": SYSTEM_PROMPT}]
 
             if len(visuals) > 0:
                 visual = visuals[i] if i < len(visuals) else None
@@ -264,9 +271,7 @@ def _collate(x):
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=pad_token_id,
                 do_sample=True if gen_kwargs["temperature"] > 0 else False,
-                temperature=gen_kwargs["temperature"],
-                top_p=gen_kwargs["top_p"],
-                num_beams=gen_kwargs["num_beams"],
+                temperature=0,
                 max_new_tokens=gen_kwargs["max_new_tokens"],
                 use_cache=self.use_cache,
             )
diff --git a/lmms_eval/tasks/mathvista/mathvista.yaml b/lmms_eval/tasks/mathvista/mathvista.yaml
index 24a1b09eb..8df8ea95a 100755
--- a/lmms_eval/tasks/mathvista/mathvista.yaml
+++ b/lmms_eval/tasks/mathvista/mathvista.yaml
@@ -4,5 +4,5 @@ task:
   - mathvista_test
 metadata:
   version: 0.0
-  gpt_eval_model_name: "gpt-3.5-turbo"
+  gpt_eval_model_name: "gpt-4o-2024-08-06"
   quick_extract: false
\ No newline at end of file
diff --git a/lmms_eval/tasks/mathvista/mathvista_evals.py b/lmms_eval/tasks/mathvista/mathvista_evals.py
index 2dbbc2b5b..126e84da6 100755
--- a/lmms_eval/tasks/mathvista/mathvista_evals.py
+++ b/lmms_eval/tasks/mathvista/mathvista_evals.py
@@ -170,7 +170,7 @@ def __init__(self, api_key, gpt_model="gpt-3.5-turbo", quick_extract=False):
 
     def _post_request(self, payload):
         headers = {
-            "Authorization": f"Bearer {self.api_key}",
+            "api-key": self.api_key,
             "Content-Type": "application/json",
         }
         response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
@@ -183,8 +183,8 @@ def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience
         ]
         payload = {"model": self.gpt_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": n}
 
-        if self.API_TYPE == "azure":
-            payload.pop("model")
+        # if self.API_TYPE == "azure":
+        #     payload.pop("model")
 
         while patience > 0:
             patience -= 1
diff --git a/lmms_eval/tasks/mathvista/mathvista_testmini.yaml b/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
index 289ece3c4..0f8a311f8 100755
--- a/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
+++ b/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
@@ -5,5 +5,5 @@ task:
   - mathvista_testmini_format
 metadata:
   version: 0.0
-  gpt_eval_model_name: "gpt-3.5-turbo"
+  gpt_eval_model_name: "gpt-4o-2024-08-06"
   quick_extract: false
\ No newline at end of file
diff --git a/lmms_eval/tasks/mmmu/_default_template_yaml b/lmms_eval/tasks/mmmu/_default_template_yaml
index a53675347..1eb65bc02 100644
--- a/lmms_eval/tasks/mmmu/_default_template_yaml
+++ b/lmms_eval/tasks/mmmu/_default_template_yaml
@@ -1,5 +1,5 @@
 generation_kwargs:
-  max_new_tokens: 16
+  max_new_tokens: 4096
 
 metadata:
   version: 0.0
diff --git a/lmms_eval/tasks/mmmu/mmmu_val.yaml b/lmms_eval/tasks/mmmu/mmmu_val.yaml
index a301f7cb8..4e6b27960 100755
--- a/lmms_eval/tasks/mmmu/mmmu_val.yaml
+++ b/lmms_eval/tasks/mmmu/mmmu_val.yaml
@@ -6,11 +6,11 @@ doc_to_visual: !function utils.mmmu_doc_to_visual
 doc_to_text: !function utils.mmmu_doc_to_text
 doc_to_target: "answer"
 # The return value of process_results will be used by metrics
-process_results: !function utils.mmmu_process_results
+process_results: !function utils.mmmu_reasoning_process_results
 metric_list:
-  - metric: mmmu_acc
-    aggregation: !function utils.mmmu_aggregate_results
+  - metric: mmmu_judge_acc
+    aggregation: !function utils.mmmu_aggregate_judge_results
     higher_is_better: true
 
 include: _default_template_yaml
\ No newline at end of file
diff --git a/lmms_eval/tasks/mmmu/utils.py b/lmms_eval/tasks/mmmu/utils.py
index 789bd5a69..1951e198b 100755
--- a/lmms_eval/tasks/mmmu/utils.py
+++ b/lmms_eval/tasks/mmmu/utils.py
@@ -3,18 +3,18 @@
 import os
 import random
 import re
+import time
 from collections import defaultdict
 from pathlib import Path
 
 import numpy as np
+import requests
 import yaml
 from loguru import logger as eval_logger
+from openai import AzureOpenAI, OpenAI
 
 from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
 
-MULTI_CHOICE_PROMPT = "Answer with the option's letter from the given choices directly."
-OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."
-
 with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
     raw_data = f.readlines()
     safe_data = []
@@ -26,6 +26,96 @@
     config = yaml.safe_load("".join(safe_data))
 
+
+with open(Path(__file__).parent / "mmmu_val.yaml", "r") as f:
+    raw_data = f.readlines()
+    safe_data = []
+    for i, line in enumerate(raw_data):
+        # remove function definition since yaml load cannot handle it
+        if "!function" not in line:
+            safe_data.append(line)
+
+    reasoning_config = yaml.safe_load("".join(safe_data))
+    # MC_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["multiple_choice_prompt"]
+    # OPEN_ENDED_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["open_ended_prompt"]
+
+
+NUM_SECONDS_TO_SLEEP = 5
+API_TYPE = os.getenv("API_TYPE", "openai")
+MODEL_VERSION = os.getenv("MODEL_VERSION", "gpt-4o-2024-08-06")
+
+JUDGE_RULES = """You are a strict evaluator assessing answer correctness. You must output 1 for fully correct answers and 0 for any other case.
+# Input
+Question:
+```
+{question}
+```
+Ground Truth Answer:
+```
+{answer}
+```
+Model Prediction:
+```
+{pred}
+```
+
+# Evaluation Rules
+- The model prediction contains the reasoning process; you should extract the final answer from it.
+- For multiple-choice questions: Score 1 if the predicted answer matches the correct answer.
+- For open-ended questions:
+  * Score 1 if the prediction matches the answer semantically and contains all key elements
+  * Score 0 for partially correct answers or answers with extra incorrect information, even if the reasoning process is correct.
+- Ignore minor differences in formatting, capitalization, or spacing since the model may explain in a different way.
+- Treat numerical answers as correct if they match within reasonable precision
+- For questions requiring units, both value and unit must be correct
+
+# Strict Output format
+[0/1]"""
+
+if API_TYPE == "openai":
+    API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+    API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+    client = OpenAI(api_key=API_KEY)
+elif API_TYPE == "azure":
+    API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+    API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+    client = AzureOpenAI(azure_endpoint=API_URL, api_version="2023-07-01-preview", api_key=API_KEY)
+
+
+def get_chat_response(content: str, max_tokens: int, retries: int = 5):
+    global MODEL_VERSION
+    global client
+
+    messages = [
+        {
+            "role": "system",
+            "content": "You are a helpful and precise assistant for checking the correctness of the answer.",
+        },
+        {"role": "user", "content": content},
+    ]
+
+    payload = {
+        "model": MODEL_VERSION,
+        "messages": messages,
+        "temperature": 0.2,
+        "max_tokens": max_tokens,
+    }
+
+    for attempt in range(retries):
+        try:
+            response = client.chat.completions.create(**payload)
+            content = response.choices[0].message.content.strip()
+            return content
+        except requests.exceptions.RequestException as e:
+            eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}")
+            time.sleep(NUM_SECONDS_TO_SLEEP)
+            if attempt == retries - 1:
+                eval_logger.error(f"Failed to get response after {retries} attempts")
+                return ""
+        except Exception as e:
+            eval_logger.error(f"Error on attempt {attempt+1}: {e}")
+            return ""
+
+
 def replace_images_tokens(input_string):
     for i in range(1, 8):
         question_text = f"<image {i}>"
@@ -41,15 +131,15 @@ def parse_options(options):
     return choices_str
 
 
-def construct_prompt(doc):
+def construct_prompt(doc, mc_prompt="", open_ended_prompt=""):
     question = doc["question"]
     if doc["question_type"] == "multiple-choice":
         # Weirdly, data["options"] is a string in MMMU Huggingface dataset
         parsed_options = parse_options(ast.literal_eval(doc["options"]))
         # parsed_options already prepends a newline so no need to add space here
-        question = f"{question}\n{parsed_options}\n\n{MULTI_CHOICE_PROMPT}"
+        question = f"{question}\n{parsed_options}"
     else:
-        question = f"{question}\n\n{OPEN_ENDED_PROMPT}"
+        question = f"{question}"
     return question
@@ -70,20 +160,28 @@ def mmmu_doc_to_visual(doc):
 
 
 def mmmu_process_results(doc, results):
+    parsed_preds = []
+    for pred in results:
+        if doc["question_type"] == "multiple-choice":
+            index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
+            parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
+        else:
+            parsed_pred = parse_open_response(pred)
+
+        parsed_preds.append(parsed_pred)
+
+    mmmu_exact_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_preds}
+    return {"mmmu_acc": mmmu_exact_acc, "mmmu_acc_pass_at_k": mmmu_exact_acc}
+
+
+def mmmu_reasoning_process_results(doc, results):
     pred = results[0]
-    if doc["question_type"] == "multiple-choice":
-        index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
-        parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
-    else:
-        parsed_pred = parse_open_response(pred)
-    id = doc["id"]
-    mmmu_acc = {"id": id, "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_pred}
-    return {
-        "mmmu_acc": mmmu_acc,
-        "submission": {
-            id: pred,
-        },
-    }
+    # formatted_question = construct_prompt(doc, MC_PROMPT, OPEN_ENDED_PROMPT)
+    formatted_question = construct_prompt(doc)
+    llm_judge_prompt = JUDGE_RULES.format(question=formatted_question, answer=doc["answer"], pred=pred)
+    llm_judge_score = get_chat_response(llm_judge_prompt, max_tokens=20, retries=3)
+    mmmu_judge_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "pred": pred, "score": llm_judge_score}
+    return {"mmmu_judge_acc": mmmu_judge_acc}
 
 
 def extract_subset_name(input_string):
@@ -143,6 +241,17 @@ def mmmu_aggregate_results(results):
     return printable_results["Overall"]["acc"]
 
 
+def mmmu_aggregate_judge_results(results):
+    total_score = 0
+    for result in results:
+        try:
+            total_score += int(result["score"])
+        except:
+            eval_logger.warning(f"Failed to convert score to int for {result['id']}: {result['score']}")
+            total_score += 0
+    return total_score / len(results)
+
+
 ##################
 # Helper functions written by official MMMU repo.
 ##################
@@ -253,16 +362,20 @@ def evaluate_mmmu(samples):
     judge_dict = dict()
     for sample in samples:
         gold_i = sample["answer"]
-        pred_i = sample["parsed_pred"]
-        if sample["question_type"] == "multiple-choice":
-            correct = eval_multi_choice(gold_i, pred_i)
-        else:  # open question
-            correct = eval_open(gold_i, pred_i)
-
-        if correct:
-            judge_dict[sample["id"]] = "Correct"
-            pred_correct += 1
-        else:
+        pred_list = sample["parsed_pred"]
+        correct = False
+        for pred_i in pred_list:
+            if sample["question_type"] == "multiple-choice":
+                correct = eval_multi_choice(gold_i, pred_i)
+            else:  # open question
+                correct = eval_open(gold_i, pred_i)
+
+            if correct:
+                judge_dict[sample["id"]] = "Correct"
+                pred_correct += 1
+                break
+
+        if not correct:
             judge_dict[sample["id"]] = "Wrong"
 
     if len(samples) == 0: