Commit 08678c1: Add eval for cot

kcz358 committed Jan 27, 2025
1 parent f6fe367
Showing 7 changed files with 160 additions and 42 deletions.
13 changes: 9 additions & 4 deletions lmms_eval/models/qwen2_vl.py
@@ -22,6 +22,13 @@
except ImportError:
eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")

SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)


@register_model("qwen2_vl")
class Qwen2_VL(lmms):
@@ -198,7 +205,7 @@ def _collate(x):
if "<image>" in context:
context = context.replace("<image>", "")

message = [{"role": "system", "content": "You are a helpful assistant."}]
message = [{"role": "system", "content": SYSTEM_PROMPT}]

if len(visuals) > 0:
visual = visuals[i] if i < len(visuals) else None
@@ -264,9 +271,7 @@ def _collate(x):
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=pad_token_id,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
temperature=gen_kwargs["temperature"],
top_p=gen_kwargs["top_p"],
num_beams=gen_kwargs["num_beams"],
temperature=0,
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
)
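The new SYSTEM_PROMPT instructs the model to wrap its reasoning in `<think> </think>` and its final answer in `<answer> </answer>` tags. A minimal sketch, not part of this commit, of how downstream code could split such a response (the helper name and regex are assumptions for illustration):

```python
import re


def split_cot_response(text: str) -> tuple[str, str]:
    """Split a CoT-style response into (reasoning, final answer)."""
    think = re.search(r"<think>(.*?)</think>", text, re.DOTALL)
    answer = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
    reasoning = think.group(1).strip() if think else ""
    # Fall back to the raw text when the model does not emit <answer> tags.
    final = answer.group(1).strip() if answer else text.strip()
    return reasoning, final


print(split_cot_response("<think>2 + 2 = 4</think><answer>4</answer>")[1])  # -> 4
```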
2 changes: 1 addition & 1 deletion lmms_eval/tasks/mathvista/mathvista.yaml
@@ -4,5 +4,5 @@ task:
- mathvista_test
metadata:
version: 0.0
gpt_eval_model_name: "gpt-3.5-turbo"
gpt_eval_model_name: "gpt-4o-2024-08-06"
quick_extract: false
6 changes: 3 additions & 3 deletions lmms_eval/tasks/mathvista/mathvista_evals.py
@@ -170,7 +170,7 @@ def __init__(self, api_key, gpt_model="gpt-3.5-turbo", quick_extract=False):

def _post_request(self, payload):
headers = {
"Authorization": f"Bearer {self.api_key}",
"api-key": self.api_key,
"Content-Type": "application/json",
}
response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
@@ -183,8 +183,8 @@ def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience
]
payload = {"model": self.gpt_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": n}

if self.API_TYPE == "azure":
payload.pop("model")
# if self.API_TYPE == "azure":
# payload.pop("model")

while patience > 0:
patience -= 1
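The header change above swaps the OpenAI-style `Authorization: Bearer <key>` header for Azure OpenAI's `api-key` header, and the `payload.pop("model")` branch is commented out so the model name stays in the request body. A small illustrative helper, assumed rather than taken from this commit, that selects the header style from an `api_type` flag:

```python
def build_headers(api_key: str, api_type: str = "azure") -> dict:
    """Return request headers for an Azure OpenAI or OpenAI-compatible endpoint."""
    if api_type == "azure":
        return {"api-key": api_key, "Content-Type": "application/json"}
    return {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}


headers = build_headers("YOUR_API_KEY", api_type="azure")
```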
2 changes: 1 addition & 1 deletion lmms_eval/tasks/mathvista/mathvista_testmini.yaml
@@ -5,5 +5,5 @@ task:
- mathvista_testmini_format
metadata:
version: 0.0
gpt_eval_model_name: "gpt-3.5-turbo"
gpt_eval_model_name: "gpt-4o-2024-08-06"
quick_extract: false
2 changes: 1 addition & 1 deletion lmms_eval/tasks/mmmu/_default_template_yaml
@@ -1,5 +1,5 @@
generation_kwargs:
max_new_tokens: 16
max_new_tokens: 4096

metadata:
version: 0.0
6 changes: 3 additions & 3 deletions lmms_eval/tasks/mmmu/mmmu_val.yaml
@@ -6,11 +6,11 @@ doc_to_visual: !function utils.mmmu_doc_to_visual
doc_to_text: !function utils.mmmu_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
process_results: !function utils.mmmu_process_results
process_results: !function utils.mmmu_reasoning_process_results

metric_list:
- metric: mmmu_acc
aggregation: !function utils.mmmu_aggregate_results
- metric: mmmu_judge_acc
aggregation: !function utils.mmmu_aggregate_judge_results
higher_is_better: true

include: _default_template_yaml
171 changes: 142 additions & 29 deletions lmms_eval/tasks/mmmu/utils.py
@@ -3,18 +3,18 @@
import os
import random
import re
import time
from collections import defaultdict
from pathlib import Path

import numpy as np
import requests
import yaml
from loguru import logger as eval_logger
from openai import AzureOpenAI, OpenAI

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

MULTI_CHOICE_PROMPT = "Answer with the option's letter from the given choices directly."
OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."

with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
@@ -26,6 +26,96 @@
config = yaml.safe_load("".join(safe_data))


with open(Path(__file__).parent / "mmmu_val.yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
for i, line in enumerate(raw_data):
# remove function definition since yaml load cannot handle it
if "!function" not in line:
safe_data.append(line)

reasoning_config = yaml.safe_load("".join(safe_data))
# MC_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["multiple_choice_prompt"]
# OPEN_ENDED_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["open_ended_prompt"]


NUM_SECONDS_TO_SLEEP = 5
API_TYPE = os.getenv("API_TYPE", "openai")
MODEL_VERSION = os.getenv("MODEL_VERSION", "gpt-4o-2024-08-06")

JUDGE_RULES = """You are a strict evaluator assessing answer correctness. You must output 1 for fully correct answers and 0 for any other case.
# Input
Question:
```
{question}
```
Ground Truth Answer:
```
{answer}
```
Model Prediction:
```
{pred}
```
# Evaluation Rules
- The model prediction contains the reasoning process; you should extract the final answer from it.
- For multiple-choice questions: Score 1 if the predicted answer matches the correct answer.
- For open-ended questions:
* Score 1 if the prediction matches the answer semantically and contains all key elements
* Score 0 for partially correct answers or answers with extra incorrect information, even if the reasoning process is correct.
- Ignore minor differences in formatting, capitalization, or spacing since the model may explain in a different way.
- Treat numerical answers as correct if they match within reasonable precision
- For questions requiring units, both value and unit must be correct
# Strict Output format
[0/1]"""

if API_TYPE == "openai":
API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
client = OpenAI(api_key=API_KEY)
elif API_TYPE == "azure":
API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
client = AzureOpenAI(azure_endpoint=API_URL, api_version="2023-07-01-preview", api_key=API_KEY)


def get_chat_response(content: str, max_tokens: int, retries: int = 5):
global MODEL_VERSION
global client

messages = [
{
"role": "system",
"content": "You are a helpful and precise assistant for checking the correctness of the answer.",
},
{"role": "user", "content": content},
]

payload = {
"model": MODEL_VERSION,
"messages": messages,
"temperature": 0.2,
"max_tokens": max_tokens,
}

for attempt in range(retries):
try:
response = client.chat.completions.create(**payload)
content = response.choices[0].message.content.strip()
return content
except requests.exceptions.RequestException as e:
eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}")
time.sleep(NUM_SECONDS_TO_SLEEP)
if attempt == retries - 1:
eval_logger.error(f"Failed to get response after {retries} attempts")
return ""
except Exception as e:
eval_logger.error(f"Error on attempt {attempt+1}: {e}")
return ""


def replace_images_tokens(input_string):
for i in range(1, 8):
question_text = f"<image {i}>"
@@ -41,15 +131,15 @@ def parse_options(options):
return choices_str


def construct_prompt(doc):
def construct_prompt(doc, mc_prompt="", open_ended_prompt=""):
question = doc["question"]
if doc["question_type"] == "multiple-choice":
# Weirdly, data["options"] is a string in MMMU Huggingface dataset
parsed_options = parse_options(ast.literal_eval(doc["options"]))
# parsed_options already prepends a newline so no need to add space here
question = f"{question}\n{parsed_options}\n\n{MULTI_CHOICE_PROMPT}"
question = f"{question}\n{parsed_options}"
else:
question = f"{question}\n\n{OPEN_ENDED_PROMPT}"
question = f"{question}"
return question


@@ -70,20 +160,28 @@ def mmmu_doc_to_visual(doc):


def mmmu_process_results(doc, results):
parsed_preds = []
for pred in results:
if doc["question_type"] == "multiple-choice":
index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
else:
parsed_pred = parse_open_response(pred)

parsed_preds.append(parsed_pred)

mmmu_exact_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_preds}
return {"mmmu_acc": mmmu_exact_acc, "mmmu_acc_pass_at_k": mmmu_exact_acc}


def mmmu_reasoning_process_results(doc, results):
pred = results[0]
if doc["question_type"] == "multiple-choice":
index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
else:
parsed_pred = parse_open_response(pred)
id = doc["id"]
mmmu_acc = {"id": id, "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_pred}
return {
"mmmu_acc": mmmu_acc,
"submission": {
id: pred,
},
}
# formatted_question = construct_prompt(doc, MC_PROMPT, OPEN_ENDED_PROMPT)
formatted_question = construct_prompt(doc)
llm_judge_prompt = JUDGE_RULES.format(question=formatted_question, answer=doc["answer"], pred=pred)
llm_judge_score = get_chat_response(llm_judge_prompt, max_tokens=20, retries=3)
mmmu_judge_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "pred": pred, "score": llm_judge_score}
return {"mmmu_judge_acc": mmmu_judge_acc}


def extract_subset_name(input_string):
@@ -143,6 +241,17 @@ def mmmu_aggregate_results(results):
return printable_results["Overall"]["acc"]


def mmmu_aggregate_judge_results(results):
total_score = 0
for result in results:
try:
total_score += int(result["score"])
except:
eval_logger.warning(f"Failed to convert score to int for {result['id']}: {result['score']}")
total_score += 0
return total_score / len(results)


##################
# Helper functions written by official MMMU repo.
##################
Expand Down Expand Up @@ -253,16 +362,20 @@ def evaluate_mmmu(samples):
judge_dict = dict()
for sample in samples:
gold_i = sample["answer"]
pred_i = sample["parsed_pred"]
if sample["question_type"] == "multiple-choice":
correct = eval_multi_choice(gold_i, pred_i)
else: # open question
correct = eval_open(gold_i, pred_i)

if correct:
judge_dict[sample["id"]] = "Correct"
pred_correct += 1
else:
pred_list = sample["parsed_pred"]
correct = False
for pred_i in pred_list:
if sample["question_type"] == "multiple-choice":
correct = eval_multi_choice(gold_i, pred_i)
else: # open question
correct = eval_open(gold_i, pred_i)

if correct:
judge_dict[sample["id"]] = "Correct"
pred_correct += 1
break

if not correct:
judge_dict[sample["id"]] = "Wrong"

if len(samples) == 0:
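Taken together, `mmmu_reasoning_process_results` formats each question into the `JUDGE_RULES` template, asks the judge model for a 0/1 verdict via `get_chat_response`, and `mmmu_aggregate_judge_results` averages those scores. A self-contained sketch of that flow under stated assumptions: `JUDGE_TEMPLATE` is a condensed stand-in for `JUDGE_RULES`, and `fake_judge` replaces the real API call:

```python
# Condensed stand-in for the JUDGE_RULES prompt used in utils.py.
JUDGE_TEMPLATE = (
    "Question:\n{question}\n"
    "Ground Truth Answer:\n{answer}\n"
    "Model Prediction:\n{pred}\n"
    "Output 1 if the prediction is fully correct, otherwise 0."
)


def fake_judge(prompt: str) -> str:
    """Placeholder for get_chat_response(); a real judge model returns "0" or "1"."""
    return "1"


def judge_and_aggregate(samples: list[dict]) -> float:
    """Score each sample with the judge, then average, mirroring mmmu_aggregate_judge_results."""
    scores = []
    for s in samples:
        prompt = JUDGE_TEMPLATE.format(question=s["question"], answer=s["answer"], pred=s["pred"])
        reply = fake_judge(prompt).strip()
        try:
            scores.append(int(reply))
        except ValueError:
            scores.append(0)  # unparseable judge output counts as wrong
    return sum(scores) / len(scores) if scores else 0.0


acc = judge_and_aggregate(
    [{"question": "2 + 2 = ?", "answer": "4", "pred": "<think>2 + 2 = 4</think><answer>4</answer>"}]
)
print(acc)  # -> 1.0
```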
