import logging

from lmms_eval.tasks._task_utils.file_utils import generate_submission_file

logger = logging.getLogger("lmms-eval")

# Add the following functions to your existing utils.py file
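# These helpers are referenced from the task's YAML config. A minimal sketch
# (field names assumed from other lmms-eval task configs, not verified here):
#
#   doc_to_visual: !function utils.ocrbench_doc_to_visual
#   doc_to_text: !function utils.ocrbench_doc_to_text
#   process_results: !function utils.ocrbench_process_results
#   metric_list:
#     - metric: ocrbench_accuracy
#       aggregation: !function utils.ocrbench_aggregate_accuracy
#       higher_is_better: true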
OCRBench_score = {
    "Regular Text Recognition": 0,
    "Irregular Text Recognition": 0,
    "Artistic Text Recognition": 0,
    "Handwriting Recognition": 0,
    "Digit String Recognition": 0,
    "Non-Semantic Text Recognition": 0,
    "Scene Text-centric VQA": 0,
    "Doc-oriented VQA": 0,
    "Key Information Extraction": 0,
    "Handwritten Mathematical Expression Recognition": 0,
}
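# Mutable module-level accumulator, filled in by ocrbench_aggregate_accuracy
# below. Per the score report written there, each recognition sub-task is
# worth 50 points (300 total) and the full benchmark totals 1000.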


def ocrbench_doc_to_visual(doc):
    # Assuming 'doc' carries a PIL image under the 'image' key
    return [doc["image"].convert("RGB")]


def ocrbench_doc_to_text(doc):
    # Assuming 'doc' carries the question text under the 'question' key
    return doc["question"].strip()


def ocrbench_process_results(doc, results):
    pred = results[0].lower().strip()
    gt_ans = doc["answer"]
    dataset_name = doc["dataset"]

    # Normalize a single answer string to a one-element list so both cases
    # share the same matching loop.
    answers = gt_ans if isinstance(gt_ans, list) else [gt_ans]

    score = 0
    for ans in answers:
        if dataset_name == "HME100k":
            # Handwritten math expressions: strip all whitespace before
            # matching, and lowercase to match the lowercased prediction.
            answer = ans.lower().strip().replace("\n", " ").replace(" ", "")
            predict = pred.replace("\n", " ").replace(" ", "")
        else:
            answer = ans.lower().strip().replace("\n", " ")
            predict = pred.replace("\n", " ")
        # A prediction counts as correct if it contains any acceptable answer.
        if answer in predict:
            score = 1
            break
    return {
        "ocrbench_accuracy": {"question_type": doc["question_type"], "score": score, "prediction": pred, "ground_truth": gt_ans},
    }
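
# Worked example (hypothetical inputs, for illustration only): given
#   doc = {"answer": "Hello", "dataset": "IIIT5K", "question_type": "Regular Text Recognition"}
#   results = ["The sign says Hello."]
# the normalized answer "hello" is contained in the lowercased prediction,
# so the returned score is 1.

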
def ocrbench_aggregate_accuracy(results, args):
    # Tally per-category scores. OCRBench_score is module-level state, so this
    # assumes aggregation runs once per evaluation process.
    for result in results:
        OCRBench_score[result["question_type"]] += result["score"]
    recognition_score = (
        OCRBench_score["Regular Text Recognition"]
        + OCRBench_score["Irregular Text Recognition"]
        + OCRBench_score["Artistic Text Recognition"]
        + OCRBench_score["Handwriting Recognition"]
        + OCRBench_score["Digit String Recognition"]
        + OCRBench_score["Non-Semantic Text Recognition"]
    )
    Final_score = (
        recognition_score
        + OCRBench_score["Scene Text-centric VQA"]
        + OCRBench_score["Doc-oriented VQA"]
        + OCRBench_score["Key Information Extraction"]
        + OCRBench_score["Handwritten Mathematical Expression Recognition"]
    )
    file_name = generate_submission_file("ocrbench_results.txt", args, subpath="results")
    with open(file_name, "w") as f:
        print("######################### OCRBench #############################", file=f)
        print(f"Text Recognition(Total 300): {recognition_score}", file=f)
        print("---------------- Details of Recognition Score ------------------", file=f)
        print(f"Regular Text Recognition(Total 50): {OCRBench_score['Regular Text Recognition']}", file=f)
        print(f"Irregular Text Recognition(Total 50): {OCRBench_score['Irregular Text Recognition']}", file=f)
        print(f"Artistic Text Recognition(Total 50): {OCRBench_score['Artistic Text Recognition']}", file=f)
        print(f"Handwriting Recognition(Total 50): {OCRBench_score['Handwriting Recognition']}", file=f)
        print(f"Digit String Recognition(Total 50): {OCRBench_score['Digit String Recognition']}", file=f)
        print(f"Non-Semantic Text Recognition(Total 50): {OCRBench_score['Non-Semantic Text Recognition']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Scene Text-centric VQA(Total 200): {OCRBench_score['Scene Text-centric VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Doc-oriented VQA(Total 200): {OCRBench_score['Doc-oriented VQA']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Key Information Extraction(Total 200): {OCRBench_score['Key Information Extraction']}", file=f)
        print("----------------------------------------------------------------", file=f)
        print(f"Handwritten Mathematical Expression Recognition(Total 100): {OCRBench_score['Handwritten Mathematical Expression Recognition']}", file=f)
        print("--------------------- Final Score ------------------------------", file=f)
        print(f"Final Score(Total 1000): {Final_score}", file=f)
    logger.info(f"OCR Bench results saved to {file_name}")
    return Final_score
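
# Typical invocation (command sketch; flag names assumed to mirror
# lm-evaluation-harness, check `python -m lmms_eval --help`):
#
#   python -m lmms_eval --model llava --tasks ocrbench --output_path ./logs/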