
Commit 9dfb53a
Merge pull request #28 from echo840/main: add_ocrbench
2 parents: d7b207f + e00d0ca

File tree: 4 files changed, +221 -2 lines

lmms_eval/tasks/_task_utils/file_utils.py (2 additions, 2 deletions): `generate_submission_file` gains an optional `subpath` argument so tasks can write somewhere other than `submissions/`.

```diff
@@ -1,8 +1,8 @@
 import os


-def generate_submission_file(file_name, args):
-    path = os.path.join(args.output_path, "submissions")
+def generate_submission_file(file_name, args, subpath="submissions"):
+    path = os.path.join(args.output_path, subpath)
     os.makedirs(path, exist_ok=True)
     path = os.path.join(path, file_name)
     return os.path.abspath(path)
```
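A minimal usage sketch of the widened signature; the `argparse.Namespace` here is a stand-in for the harness's real `args` object, and the paths are illustrative:

```python
import argparse

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

args = argparse.Namespace(output_path="./logs")  # stand-in for the real CLI args

# Default subpath: creates ./logs/submissions/ and returns the absolute file path.
default_path = generate_submission_file("ocrbench_results.txt", args)

# Override: the OCRBench aggregator below writes under ./logs/results/ instead.
results_path = generate_submission_file("ocrbench_results.txt", args, subpath="results")
```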
New file (22 additions): the OCRBench task configuration.

```yaml
dataset_path: echo840/OCRBench
dataset_kwargs:
  token: True
task: "ocrbench"
test_split: test
output_type: generate_until
doc_to_visual: !function utils.ocrbench_doc_to_visual
doc_to_text: !function utils.ocrbench_doc_to_text
doc_to_target: "answer"
generation_kwargs:
  max_new_tokens: 128
  temperature: 0
  top_p: 0
  num_beams: 1
  do_sample: false
process_results: !function utils.ocrbench_process_results
metric_list:
  - metric: ocrbench_accuracy
    aggregation: !function utils.ocrbench_aggregate_accuracy
    higher_is_better: true
metadata:
  - version: 0.0
```
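The `!function` tags bind config keys to callables defined in the task's `utils.py`. As a rough sketch of how such a tag can be resolved (illustrative PyYAML code, not lmms-eval's actual loader; the config filename is assumed):

```python
import importlib

import yaml


def _function_constructor(loader, node):
    # "utils.ocrbench_doc_to_text" -> import module "utils", fetch the attribute
    module_name, func_name = loader.construct_scalar(node).rsplit(".", 1)
    return getattr(importlib.import_module(module_name), func_name)


yaml.SafeLoader.add_constructor("!function", _function_constructor)

with open("ocrbench.yaml") as f:  # assumed filename for the config above
    config = yaml.safe_load(f)

# config["doc_to_text"] is now the ocrbench_doc_to_text callable
```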
New file (94 additions): a one-off script that packages the OCRBench annotations and images into a Hugging Face dataset and pushes it to the Hub. The `pathto/` placeholders must point at a local copy of the OCRBench data.

```python
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import io
import json

import datasets
import pandas as pd
from datasets import Dataset, Features
from PIL import Image as PIL_Image
from tqdm import tqdm

# Find for instance the citation on arxiv or on the dataset repo/website
_CITATION = """https://arxiv.org/abs/2305.07895"""
_DESCRIPTION = "OCRBench is a comprehensive evaluation benchmark designed to assess the OCR capabilities of Large Multimodal Models."


def image2byte(image):
    # Serialize a PIL image to raw JPEG bytes for the datasets.Image() feature.
    img_bytes = io.BytesIO()
    image.save(img_bytes, format="JPEG")
    return img_bytes.getvalue()


def get_builder_config(VERSION):
    builder_config = [
        datasets.BuilderConfig(
            name="ocrbench",
            version=VERSION,
            description="ocrbench",
        )
    ]
    return builder_config


ocrbench_json = "pathto/OCRBench/OCRBench.json"
img_dir = "pathto/OCRBench_Images/"

dataset_features = Features(
    {
        "dataset": datasets.Value("string"),
        "question": datasets.Value("string"),
        "question_type": datasets.Value("string"),
        "answer": datasets.features.Sequence(datasets.Value("string")),
        "image": datasets.Image(),
    }
)

df_items = {
    "dataset": [],
    "question": [],
    "question_type": [],
    "answer": [],
    "image": [],
}
# img_feature = datasets.Image(decode=False)
with open(ocrbench_json, "r") as f:
    data = json.load(f)
for i in tqdm(range(len(data))):
    dataset_name = data[i]["dataset_name"]
    image_path = img_dir + data[i]["image_path"]
    question = data[i]["question"]
    answers = data[i]["answers"]
    question_type = data[i]["type"]
    if isinstance(answers, str):  # normalize single answers to a list
        answers = [answers]
    img = PIL_Image.open(image_path).convert("RGB")
    byte_data = image2byte(img)
    image = {"bytes": byte_data, "path": ""}
    df_items["image"].append(image)
    df_items["question"].append(str(question))
    df_items["answer"].append(answers)
    df_items["question_type"].append(str(question_type))
    df_items["dataset"].append(str(dataset_name))

df_items = pd.DataFrame(df_items)
dataset = Dataset.from_pandas(df_items, features=dataset_features)
hub_dataset_path = "echo840/OCRBench"
dataset.push_to_hub(repo_id=hub_dataset_path, split="test")
```
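Once pushed, the dataset resolves exactly as the YAML above expects (`dataset_path: echo840/OCRBench`, `test_split: test`). A quick sanity check, assuming `datasets` is installed and the Hub repo is accessible:

```python
from datasets import load_dataset

ds = load_dataset("echo840/OCRBench", split="test")
sample = ds[0]
print(sample["dataset"], sample["question_type"])
print(sample["question"], sample["answer"])  # "answer" is a list of strings
print(sample["image"].size)  # decoded back into a PIL image by datasets.Image()
```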

lmms_eval/tasks/ocrbench/utils.py (new file, 103 additions):

```python
import logging

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

logger = logging.getLogger("lmms-eval")

# Running point tally per OCRBench question type, filled in during aggregation.
OCRBench_score = {
    "Regular Text Recognition": 0,
    "Irregular Text Recognition": 0,
    "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0,
    "Digit String Recognition": 0,
    "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0,
    "Doc-oriented VQA": 0,
    "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0,
}


def ocrbench_doc_to_visual(doc):
    # The 'doc' dictionary has a key 'image' with the image data.
    return [doc["image"].convert("RGB")]


def ocrbench_doc_to_text(doc):
    # The 'doc' dictionary has a key 'question' with the question text.
    return doc["question"].strip()


def ocrbench_process_results(doc, results):
    pred = results[0].lower().strip()
    gt_ans = doc["answer"]
    dataset_name = doc["dataset"]

    score = 0
    if dataset_name == "HME100k":
        # Handwritten math expressions: compare with all whitespace removed.
        if isinstance(gt_ans, list):
            for j in range(len(gt_ans)):
                answer = gt_ans[j].strip().replace("\n", " ").replace(" ", "")
                predict = pred.strip().replace("\n", " ").replace(" ", "")
                if answer in predict:
                    score = 1
        else:
            answer = gt_ans.strip().replace("\n", " ").replace(" ", "")
            predict = pred.strip().replace("\n", " ").replace(" ", "")
            if answer in predict:
                score = 1
    else:
        # All other subsets: case-insensitive substring match.
        if isinstance(gt_ans, list):
            for j in range(len(gt_ans)):
                answer = gt_ans[j].lower().strip().replace("\n", " ")
                predict = pred.lower().strip().replace("\n", " ")
                if answer in predict:
                    score = 1
        else:
            answer = gt_ans.lower().strip().replace("\n", " ")
            predict = pred.lower().strip().replace("\n", " ")
            if answer in predict:
                score = 1
    return {
        "ocrbench_accuracy": {"question_type": doc["question_type"], "score": score, "prediction": pred, "ground_truth": gt_ans},
    }


def ocrbench_aggregate_accuracy(results, args):
    for result in results:
        OCRBench_score[result["question_type"]] += result["score"]
    recognition_score = (
        OCRBench_score["Regular Text Recognition"]
        + OCRBench_score["Irregular Text Recognition"]
        + OCRBench_score["Artistic Text Recognition"]
        + OCRBench_score["Handwriting Recognition"]
        + OCRBench_score["Digit String Recognition"]
        + OCRBench_score["Non-Semantic Text Recognition"]
    )
    Final_score = (
        recognition_score
        + OCRBench_score["Scene Text-centric VQA"]
        + OCRBench_score["Doc-oriented VQA"]
        + OCRBench_score["Key Information Extraction"]
        + OCRBench_score["Handwritten Mathematical Expression Recognition"]
    )
    file_name = generate_submission_file("ocrbench_results.txt", args, subpath="results")
    with open(file_name, "w") as f:
        print("######################### OCRBench #############################", file=f)
        print(f"Text Recognition(Total 300): {recognition_score}", file=f)
        print("---------------- Details of Recognition Score ------------------", file=f)
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}", file=f)
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}", file=f)
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}", file=f)
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}", file=f)
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}", file=f)
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}", file=f)
        print("--------------------- Final Score ------------------------------", file=f)
        print(f"Final Score(Total 1000): {Final_score}", file=f)
    logger.info(f"OCR Bench results saved to {file_name}")
    return Final_score
```
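The printed totals (300 for recognition, 200 each for scene-text VQA, doc-oriented VQA, and key information extraction, 100 for handwritten math) sum to the 1000-point final score. A self-contained check of the substring scoring rule, using a made-up doc/prediction pair:

```python
# Made-up example; field names follow the dataset schema above.
doc = {
    "answer": ["Paris"],
    "dataset": "IIIT5K",  # any name other than "HME100k" takes the case-insensitive branch
    "question_type": "Regular Text Recognition",
}
result = ocrbench_process_results(doc, ["The word in the image is PARIS."])
assert result["ocrbench_accuracy"]["score"] == 1  # "paris" appears in the lowered prediction
```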
