diff --git a/lmms_eval/models/qwen2_vl.py b/lmms_eval/models/qwen2_vl.py
index a7ac04f16..dfd1d8247 100755
--- a/lmms_eval/models/qwen2_vl.py
+++ b/lmms_eval/models/qwen2_vl.py
@@ -22,6 +22,13 @@
except ImportError:
eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")
+SYSTEM_PROMPT = (
+ "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
+ "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
+ "process and answer are enclosed within and tags, respectively, i.e., "
+ " reasoning process here answer here "
+)
+
@register_model("qwen2_vl")
class Qwen2_VL(lmms):
@@ -198,7 +205,7 @@ def _collate(x):
if "" in context:
context = context.replace("", "")
- message = [{"role": "system", "content": "You are a helpful assistant."}]
+ message = [{"role": "system", "content": SYSTEM_PROMPT}]
if len(visuals) > 0:
visual = visuals[i] if i < len(visuals) else None
@@ -264,9 +271,7 @@ def _collate(x):
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=pad_token_id,
do_sample=True if gen_kwargs["temperature"] > 0 else False,
- temperature=gen_kwargs["temperature"],
- top_p=gen_kwargs["top_p"],
- num_beams=gen_kwargs["num_beams"],
+ temperature=0,
max_new_tokens=gen_kwargs["max_new_tokens"],
use_cache=self.use_cache,
)
diff --git a/lmms_eval/tasks/mathvista/mathvista.yaml b/lmms_eval/tasks/mathvista/mathvista.yaml
index 24a1b09eb..8df8ea95a 100755
--- a/lmms_eval/tasks/mathvista/mathvista.yaml
+++ b/lmms_eval/tasks/mathvista/mathvista.yaml
@@ -4,5 +4,5 @@ task:
- mathvista_test
metadata:
version: 0.0
- gpt_eval_model_name: "gpt-3.5-turbo"
+ gpt_eval_model_name: "gpt-4o-2024-08-06"
quick_extract: false
\ No newline at end of file
diff --git a/lmms_eval/tasks/mathvista/mathvista_evals.py b/lmms_eval/tasks/mathvista/mathvista_evals.py
index 2dbbc2b5b..126e84da6 100755
--- a/lmms_eval/tasks/mathvista/mathvista_evals.py
+++ b/lmms_eval/tasks/mathvista/mathvista_evals.py
@@ -170,7 +170,7 @@ def __init__(self, api_key, gpt_model="gpt-3.5-turbo", quick_extract=False):
def _post_request(self, payload):
headers = {
- "Authorization": f"Bearer {self.api_key}",
+ "api-key": self.api_key,
"Content-Type": "application/json",
}
response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
@@ -183,8 +183,8 @@ def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience
]
payload = {"model": self.gpt_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "n": n}
- if self.API_TYPE == "azure":
- payload.pop("model")
+ # if self.API_TYPE == "azure":
+ # payload.pop("model")
while patience > 0:
patience -= 1
diff --git a/lmms_eval/tasks/mathvista/mathvista_testmini.yaml b/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
index 289ece3c4..0f8a311f8 100755
--- a/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
+++ b/lmms_eval/tasks/mathvista/mathvista_testmini.yaml
@@ -5,5 +5,5 @@ task:
- mathvista_testmini_format
metadata:
version: 0.0
- gpt_eval_model_name: "gpt-3.5-turbo"
+ gpt_eval_model_name: "gpt-4o-2024-08-06"
quick_extract: false
\ No newline at end of file
diff --git a/lmms_eval/tasks/mmmu/_default_template_yaml b/lmms_eval/tasks/mmmu/_default_template_yaml
index a53675347..1eb65bc02 100644
--- a/lmms_eval/tasks/mmmu/_default_template_yaml
+++ b/lmms_eval/tasks/mmmu/_default_template_yaml
@@ -1,5 +1,5 @@
generation_kwargs:
- max_new_tokens: 16
+ max_new_tokens: 4096
metadata:
version: 0.0
diff --git a/lmms_eval/tasks/mmmu/mmmu_val.yaml b/lmms_eval/tasks/mmmu/mmmu_val.yaml
index a301f7cb8..4e6b27960 100755
--- a/lmms_eval/tasks/mmmu/mmmu_val.yaml
+++ b/lmms_eval/tasks/mmmu/mmmu_val.yaml
@@ -6,11 +6,11 @@ doc_to_visual: !function utils.mmmu_doc_to_visual
doc_to_text: !function utils.mmmu_doc_to_text
doc_to_target: "answer"
# The return value of process_results will be used by metrics
-process_results: !function utils.mmmu_process_results
+process_results: !function utils.mmmu_reasoning_process_results
metric_list:
- - metric: mmmu_acc
- aggregation: !function utils.mmmu_aggregate_results
+ - metric: mmmu_judge_acc
+ aggregation: !function utils.mmmu_aggregate_judge_results
higher_is_better: true
include: _default_template_yaml
\ No newline at end of file
diff --git a/lmms_eval/tasks/mmmu/utils.py b/lmms_eval/tasks/mmmu/utils.py
index 789bd5a69..1951e198b 100755
--- a/lmms_eval/tasks/mmmu/utils.py
+++ b/lmms_eval/tasks/mmmu/utils.py
@@ -3,18 +3,18 @@
import os
import random
import re
+import time
from collections import defaultdict
from pathlib import Path
import numpy as np
+import requests
import yaml
from loguru import logger as eval_logger
+from openai import AzureOpenAI, OpenAI
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
-MULTI_CHOICE_PROMPT = "Answer with the option's letter from the given choices directly."
-OPEN_ENDED_PROMPT = "Answer the question using a single word or phrase."
-
with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
raw_data = f.readlines()
safe_data = []
@@ -26,6 +26,96 @@
config = yaml.safe_load("".join(safe_data))
+with open(Path(__file__).parent / "mmmu_val.yaml", "r") as f:
+ raw_data = f.readlines()
+ safe_data = []
+ for i, line in enumerate(raw_data):
+ # remove function definition since yaml load cannot handle it
+ if "!function" not in line:
+ safe_data.append(line)
+
+ reasoning_config = yaml.safe_load("".join(safe_data))
+ # MC_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["multiple_choice_prompt"]
+ # OPEN_ENDED_PROMPT = reasoning_config["lmms_eval_specific_kwargs"]["default"]["open_ended_prompt"]
+
+
+NUM_SECONDS_TO_SLEEP = 5
+API_TYPE = os.getenv("API_TYPE", "openai")
+MODEL_VERSION = os.getenv("MODEL_VERSION", "gpt-4o-2024-08-06")
+
+JUDGE_RULES = """You are a strict evaluator assessing answer correctness. You must output 1 for fully correct answers and 0 for any other case.
+# Input
+Question:
+```
+{question}
+```
+Ground Truth Answer:
+```
+{answer}
+```
+Model Prediction:
+```
+{pred}
+```
+
+# Evaluation Rules
+- The model prediction contains the reasoning process; you should extract the final answer from it.
+- For multiple-choice questions: Score 1 if the predicted answer matches the correct answer.
+- For open-ended questions:
+ * Score 1 if the prediction matches the answer semantically and contains all key elements.
+ * Score 0 for partially correct answers or answers with extra incorrect information, even if the reasoning process is correct.
+- Ignore minor differences in formatting, capitalization, or spacing since the model may phrase its answer differently.
+- Treat numerical answers as correct if they match within reasonable precision.
+- For questions requiring units, both value and unit must be correct.
+
+# Strict Output format
+[0/1]"""
+
+if API_TYPE == "openai":
+ API_URL = os.getenv("OPENAI_API_URL", "https://api.openai.com/v1/chat/completions")
+ API_KEY = os.getenv("OPENAI_API_KEY", "YOUR_API_KEY")
+ client = OpenAI(api_key=API_KEY)
+elif API_TYPE == "azure":
+ API_URL = os.getenv("AZURE_ENDPOINT", "https://api.cognitive.microsoft.com/sts/v1.0/issueToken")
+ API_KEY = os.getenv("AZURE_API_KEY", "YOUR_API_KEY")
+ client = AzureOpenAI(azure_endpoint=API_URL, api_version="2023-07-01-preview", api_key=API_KEY)
+
+
+def get_chat_response(content: str, max_tokens: int, retries: int = 5):
+ global MODEL_VERSION
+ global client
+
+ messages = [
+ {
+ "role": "system",
+ "content": "You are a helpful and precise assistant for checking the correctness of the answer.",
+ },
+ {"role": "user", "content": content},
+ ]
+
+ payload = {
+ "model": MODEL_VERSION,
+ "messages": messages,
+ "temperature": 0.2,
+ "max_tokens": max_tokens,
+ }
+
+ for attempt in range(retries):
+ try:
+ response = client.chat.completions.create(**payload)
+ content = response.choices[0].message.content.strip()
+ return content
+ except requests.exceptions.RequestException as e:
+ eval_logger.warning(f"Request failed on attempt {attempt+1}: {e}")
+ time.sleep(NUM_SECONDS_TO_SLEEP)
+ if attempt == retries - 1:
+ eval_logger.error(f"Failed to get response after {retries} attempts")
+ return ""
+ except Exception as e:
+ eval_logger.error(f"Error on attempt {attempt+1}: {e}")
+ return ""
+
+
def replace_images_tokens(input_string):
for i in range(1, 8):
question_text = f""
@@ -41,15 +131,15 @@ def parse_options(options):
return choices_str
-def construct_prompt(doc):
+def construct_prompt(doc, mc_prompt="", open_ended_prompt=""):
question = doc["question"]
if doc["question_type"] == "multiple-choice":
# Weirdly, data["options"] is a string in MMMU Huggingface dataset
parsed_options = parse_options(ast.literal_eval(doc["options"]))
# parsed_options already prepends a newline so no need to add space here
- question = f"{question}\n{parsed_options}\n\n{MULTI_CHOICE_PROMPT}"
+ question = f"{question}\n{parsed_options}"
else:
- question = f"{question}\n\n{OPEN_ENDED_PROMPT}"
+ question = f"{question}"
return question
@@ -70,20 +160,28 @@ def mmmu_doc_to_visual(doc):
def mmmu_process_results(doc, results):
+ parsed_preds = []
+ for pred in results:
+ if doc["question_type"] == "multiple-choice":
+ index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
+ parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
+ else:
+ parsed_pred = parse_open_response(pred)
+
+ parsed_preds.append(parsed_pred)
+
+ mmmu_exact_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_preds}
+ return {"mmmu_acc": mmmu_exact_acc, "mmmu_acc_pass_at_k": mmmu_exact_acc}
+
+
+def mmmu_reasoning_process_results(doc, results):
pred = results[0]
- if doc["question_type"] == "multiple-choice":
- index2ans, all_choices = get_multi_choice_info(ast.literal_eval(doc["options"]))
- parsed_pred = parse_multi_choice_response(pred, all_choices, index2ans)
- else:
- parsed_pred = parse_open_response(pred)
- id = doc["id"]
- mmmu_acc = {"id": id, "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "parsed_pred": parsed_pred}
- return {
- "mmmu_acc": mmmu_acc,
- "submission": {
- id: pred,
- },
- }
+ # formatted_question = construct_prompt(doc, MC_PROMPT, OPEN_ENDED_PROMPT)
+ formatted_question = construct_prompt(doc)
+ llm_judge_prompt = JUDGE_RULES.format(question=formatted_question, answer=doc["answer"], pred=pred)
+ llm_judge_score = get_chat_response(llm_judge_prompt, max_tokens=20, retries=3)
+ mmmu_judge_acc = {"id": doc["id"], "subdomain": extract_subset_name(doc["id"]), "question_type": doc["question_type"], "answer": doc["answer"], "pred": pred, "score": llm_judge_score}
+ return {"mmmu_judge_acc": mmmu_judge_acc}
def extract_subset_name(input_string):
@@ -143,6 +241,17 @@ def mmmu_aggregate_results(results):
return printable_results["Overall"]["acc"]
+def mmmu_aggregate_judge_results(results):
+ total_score = 0
+ for result in results:
+ try:
+ total_score += int(result["score"])
+ except (ValueError, TypeError):
+ eval_logger.warning(f"Failed to convert score to int for {result['id']}: {result['score']}")
+ total_score += 0
+ return total_score / len(results)
+
+
##################
# Helper functions written by official MMMU repo.
##################
@@ -253,16 +362,20 @@ def evaluate_mmmu(samples):
judge_dict = dict()
for sample in samples:
gold_i = sample["answer"]
- pred_i = sample["parsed_pred"]
- if sample["question_type"] == "multiple-choice":
- correct = eval_multi_choice(gold_i, pred_i)
- else: # open question
- correct = eval_open(gold_i, pred_i)
-
- if correct:
- judge_dict[sample["id"]] = "Correct"
- pred_correct += 1
- else:
+ pred_list = sample["parsed_pred"]
+ correct = False
+ for pred_i in pred_list:
+ if sample["question_type"] == "multiple-choice":
+ correct = eval_multi_choice(gold_i, pred_i)
+ else: # open question
+ correct = eval_open(gold_i, pred_i)
+
+ if correct:
+ judge_dict[sample["id"]] = "Correct"
+ pred_correct += 1
+ break
+
+ if not correct:
judge_dict[sample["id"]] = "Wrong"
if len(samples) == 0: