Skip to content

Commit

Permalink
edit olympiadbench
Browse files Browse the repository at this point in the history
  • Loading branch information
JvThunder committed Mar 28, 2024
1 parent 756ec8d commit 035f3df
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,58 +16,41 @@ def olympiadbench_doc_to_visual(doc):
def olympiadbench_doc_to_text(doc):
    """Build the Chinese (CEE) prompt for one OlympiadBench question.

    Args:
        doc: dataset record with keys "question", "subfield",
            "is_multiple_answer" (bool or None) and "answer_type".

    Returns:
        The full prompt string: preamble + question + answer-format instructions.
    """
    # NOTE(review): this span was a diff-garbled mix of pre- and post-commit
    # lines; reconstructed to the post-commit Chinese-only version, mirroring
    # the structure of the English variant in en_utils.py.
    question = doc["question"]
    subject = doc["subfield"]
    mul_ans = doc["is_multiple_answer"]
    if mul_ans is None:
        # Missing flag means a single-answer question.
        mul_ans = False
    ans_type = doc["answer_type"]
    if ans_type == "Need_human_evaluate":
        # Proof questions carry no machine-checkable answer type.
        ans_type = "proof based"

    pre_prompt = f"以下是中国{subject}竞赛中的解答题。\n"

    post_prompt = ""
    if not mul_ans:
        post_prompt += f"答案类型为{ans_type}\n"
    else:
        post_prompt += f"题目有多个答案,答案类型均为{ans_type}\n"
    post_prompt += "请根据题目的要求和所提供的信息计算得出答案。解答过程和结果中使用的变量和公式请使用LaTeX格式表示。请在最后以"
    if not mul_ans:
        post_prompt += '"所以最终答案是\\boxed{答案}。"\n'
    else:
        post_prompt += '"所以最终答案是\\boxed{用英⽂逗号连接的多个答案}。"\n'

    final_question = pre_prompt + question + "\n" + post_prompt
    return final_question

def olympiadbench_process_results(doc, results):
    """Score one model response for the Chinese split.

    Args:
        doc: dataset record; reads "error" (numeric tolerance or None),
            "source" (proof items contain "TP") and "final_answer".
        results: list of generated strings; only the first is scored.

    Returns:
        {"submission": text} for proof items (human-graded), otherwise
        {"exact_match": 0 or 1} from the evaluator's judgement.
    """
    # NOTE(review): reconstructed from a diff-garbled span; assumes the
    # post-commit version splits only on the Chinese marker, symmetric with
    # en_utils.py — confirm against the repository.
    precision = doc["error"]
    is_proving = "TP" in doc["source"]
    if precision is None:
        # No per-item tolerance recorded: fall back to exact comparison.
        precision = 0
    prediction = results[0].strip()

    if is_proving:
        # Theorem-proving items need human grading; forward verbatim.
        return {
            "submission": prediction
        }
    else:
        # Keep only the text after the final-answer marker, then normalize.
        prediction = prediction.split("所以最终答案是")[-1]
        prediction = prediction.replace('"', "").replace("\n", "").replace(" ", "").strip(".").strip("。")
        accuracy = olympiadbench_evaluator.judge(prediction, doc["final_answer"][0], precision)
        accuracy = int(accuracy)
        return {
            "exact_match": accuracy
        }

def olympiadbench_aggregate_results(results, args):
    """Persist all collected Chinese-split predictions as a JSON submission file.

    Args:
        results: list of per-item submission payloads.
        args: harness CLI namespace forwarded to generate_submission_file.
    """
    # NOTE: "%Y-%m%d-%H%M-%S" is unusual but matches the English variant; kept.
    now_date_time = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
    # Post-commit filename tags the split as "cn" (old untagged line dropped).
    submission_file_name = f"olympiadbench-test-cn-submission-{now_date_time}.json"
    path = generate_submission_file(submission_file_name, args)
    # ensure_ascii=False keeps Chinese answers human-readable in the JSON.
    with open(path, "w") as f:
        json.dump(results, f, ensure_ascii=False)
Expand Down
69 changes: 69 additions & 0 deletions lmms_eval/tasks/olympiadbench/en_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import json
import datetime
from lmms_eval.tasks.olympiadbench.olympiadbench_evals import OlympiadBenchEvaluator
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

import logging
# Shared logger name used across the lmms-eval harness.
eval_logger = logging.getLogger("lmms-eval")
# Absolute directory of this task module (available for task-relative paths).
dir_name = os.path.dirname(os.path.abspath(__file__))

# Single evaluator instance reused for every answer comparison in this module.
olympiadbench_evaluator = OlympiadBenchEvaluator()

def olympiadbench_doc_to_visual(doc):
    """Return the document's images, each normalized to RGB mode."""
    visuals = []
    for image in doc["images"]:
        visuals.append(image.convert("RGB"))
    return visuals

def olympiadbench_doc_to_text(doc):
    """Build the English competition prompt for one OlympiadBench question.

    Args:
        doc: dataset record with keys "question", "subfield",
            "is_multiple_answer" (bool or None) and "answer_type".

    Returns:
        The full prompt string: preamble + question + answer-format instructions.
    """
    question = doc["question"]
    subject = doc["subfield"]
    mul_ans = doc["is_multiple_answer"]
    if mul_ans is None:
        # Missing flag means a single-answer question.
        mul_ans = False
    ans_type = doc["answer_type"]
    if ans_type == "Need_human_evaluate":
        # Proof questions carry no machine-checkable answer type.
        ans_type = "proof based"

    pre_prompt = f"The following is a question from an International {subject} competition.\n"

    post_prompt = ""
    if not mul_ans:
        post_prompt += f"The answer of the question should be {ans_type}.\n"
    else:
        post_prompt += f"The question has multiple answers, each of them should be {ans_type}.\n"
    post_prompt += "Please calculate the answer according to the given requirements and the information provided. Please use LaTeX format to represent the variables and formulas used in the solution process and results. Please end your solution with "
    if not mul_ans:
        post_prompt += '"So the final answer is \\boxed{answer}."\n'
    else:
        # Fix: quote the sentence like the single-answer branch (and like the
        # Chinese variant) so the model is asked to copy a quoted sentence
        # verbatim in both cases.
        post_prompt += '"So the final answer is \\boxed{multiple answers connected with commas}."\n'

    final_question = pre_prompt + question + "\n" + post_prompt
    return final_question

def olympiadbench_process_results(doc, results):
    """Score one model response for the English split.

    Args:
        doc: dataset record; reads "error" (numeric tolerance or None),
            "source" (proof items contain "TP") and "final_answer".
        results: list of generated strings; only the first is scored.

    Returns:
        {"submission": text} for proof items (human-graded), otherwise
        {"exact_match": 0 or 1} from the evaluator's judgement.
    """
    tolerance = doc["error"]
    if tolerance is None:
        # No per-item tolerance recorded: fall back to exact comparison.
        tolerance = 0
    prediction = results[0].strip()

    if "TP" in doc["source"]:
        # Theorem-proving items need human grading; forward verbatim.
        return {"submission": prediction}

    # Keep only the text after the final-answer marker, then normalize it.
    answer = prediction.split("final answer is")[-1]
    for unwanted in ('"', "\n", " "):
        answer = answer.replace(unwanted, "")
    answer = answer.strip(".").strip("。")
    verdict = olympiadbench_evaluator.judge(answer, doc["final_answer"][0], tolerance)
    return {"exact_match": int(verdict)}

def olympiadbench_aggregate_results(results, args):
    """Write every collected English-split prediction to a timestamped JSON file.

    Args:
        results: list of per-item submission payloads.
        args: harness CLI namespace forwarded to generate_submission_file.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m%d-%H%M-%S")
    file_name = f"olympiadbench-test-en-submission-{stamp}.json"
    out_path = generate_submission_file(file_name, args)
    # ensure_ascii=False keeps any non-ASCII answer text human-readable.
    with open(out_path, "w") as fp:
        json.dump(results, fp, ensure_ascii=False)
    print(f"Submission file saved to {out_path}")

7 changes: 2 additions & 5 deletions lmms_eval/tasks/olympiadbench/olympiadbench.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
group: olympiadbench
task:
# - olympiadbench_test_math_en_comp
# - olympiadbench_test_math_zh_comp
- olympiadbench_test_math_zh_cee
# - olympiadbench_test_physics_en_comp
# - olympiadbench_test_physics_zh_cee
- olympiadbench_test_en
- olympiadbench_test_cn
metadata:
- version: 0.0
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
dataset_path: lmms-lab/OlympiadBench
dataset_kwargs:
token: True
task : "olympiadbench_test_math_zh_cee"
test_split: test_math_zh_cee
task : "olympiadbench_test_cn"
test_split: test_cn
output_type: generate_until
doc_to_visual: !function utils.olympiadbench_doc_to_visual
doc_to_text: !function utils.olympiadbench_doc_to_text
doc_to_visual: !function cn_utils.olympiadbench_doc_to_visual
doc_to_text: !function cn_utils.olympiadbench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
until:
Expand All @@ -15,10 +15,10 @@ generation_kwargs:
top_p: 0
num_beams: 1
do_sample: false
process_results: !function utils.olympiadbench_process_results
process_results: !function cn_utils.olympiadbench_process_results
metric_list:
- metric: submission
aggregation: !function utils.olympiadbench_aggregate_results
aggregation: !function cn_utils.olympiadbench_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: mean
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
dataset_path: lmms-lab/OlympiadBench
dataset_kwargs:
token: True
task : "olympiadbench_test_math_en_comp"
test_split: test_math_en_comp
task : "olympiadbench_test_en"
test_split: test_en
output_type: generate_until
doc_to_visual: !function utils.olympiadbench_doc_to_visual
doc_to_text: !function utils.olympiadbench_doc_to_text
doc_to_visual: !function en_utils.olympiadbench_doc_to_visual
doc_to_text: !function en_utils.olympiadbench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
until:
Expand All @@ -15,10 +15,10 @@ generation_kwargs:
top_p: 0
num_beams: 1
do_sample: false
process_results: !function utils.olympiadbench_process_results
process_results: !function en_utils.olympiadbench_process_results
metric_list:
- metric: submission
aggregation: !function utils.olympiadbench_aggregate_results
aggregation: !function en_utils.olympiadbench_aggregate_results
higher_is_better: true
- metric: exact_match
aggregation: mean
Expand Down
25 changes: 0 additions & 25 deletions lmms_eval/tasks/olympiadbench/olympiadbench_test_math_zh_comp.yaml

This file was deleted.

This file was deleted.

This file was deleted.

0 comments on commit 035f3df

Please sign in to comment.