Skip to content
This repository was archived by the owner on Jan 8, 2025. It is now read-only.

Commit 1b85d5c

Browse files
authored
Merge branch 'ml4ai:main' into main
2 parents 3a86205 + 76d62f1 commit 1b85d5c

File tree

7 files changed

+325
-241
lines changed

7 files changed

+325
-241
lines changed

skema/img2mml/api.py

Lines changed: 173 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
1-
import json
21
import os
32
import requests
43
from pathlib import Path
54
import urllib.request
65
from skema.rest.proxies import SKEMA_MATHJAX_ADDRESS
76
from skema.img2mml.translate import convert_to_torch_tensor, render_mml
7+
from skema.img2mml.models.image2mml_xfmer import Image2MathML_Xfmer
8+
import torch
9+
from typing import Tuple, List, Any, Dict
10+
from logging import info
11+
from skema.img2mml.translate import define_model
12+
import json
813

914

10-
def retrieve_model(model_path=None):
15+
def retrieve_model(model_path=None) -> str:
1116
"""
1217
Retrieve the img2mml model from the specified path or download it if not found.
1318
@@ -34,27 +39,177 @@ def retrieve_model(model_path=None):
3439
return str(model_path)
3540

3641

37-
def get_mathml_from_bytes(data: bytes):
38-
# read config file
42+
def check_gpu_availability() -> torch.device:
    """
    Check if GPU is available and return the appropriate device.

    The CPU fallback is reported through ``logging.info`` (the same channel
    ``load_model`` uses for this message) instead of being printed to stdout,
    so the embedding application decides whether and where it is shown.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` is true,
        otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    info("CUDA is not available, falling back to using the CPU.")
    return torch.device("cpu")
56+
57+
58+
def _strip_module_prefix(state_dict: dict) -> dict:
    """Return a copy of *state_dict* with a leading "module." removed from keys that have it."""
    prefix = "module."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def load_model(
    model_path: str,
    config: dict,
    vocab: List[str],
    device: torch.device = torch.device("cpu"),
) -> Image2MathML_Xfmer:
    """
    Build the img2mml model and load its state dictionary from a file.

    Args:
        model_path: The path to the model state dictionary file. When None,
            ``trained_models/arxiv_im2mml_with_fonts_with_boldface_best.pt``
            next to this module is used.
        config: The configuration setting. ``config["clean_state_dict"]``
            selects whether state_dict keys are cleaned before loading.
        vocab: The vocabulary of the img2mml model.
        device: The device (GPU or CPU) to be used for computation.

    Returns:
        The model with loaded state dictionary.

    Raises:
        FileNotFoundError: If the model state dictionary file does not exist.
        RuntimeError: If there is any other error while loading the state
            dictionary (including a missing ``clean_state_dict`` config key).

    Note:
        Keys are cleaned when ``config["clean_state_dict"]`` is true, and also
        whenever CUDA is not available. Cleaning removes the "module." prefix
        only from keys that actually carry it; other keys are left untouched.
        The cleaned dictionary is loaded with ``strict=False``, so missing or
        unexpected keys are not reported on that path.
    """
    model: Image2MathML_Xfmer = define_model(config, vocab, device).to(device)

    if model_path is None:
        cwd = Path(__file__).parents[0]
        model_path = (
            cwd / "trained_models" / "arxiv_im2mml_with_fonts_with_boldface_best.pt"
        )

    try:
        clean_keys = config["clean_state_dict"]
        if not clean_keys and not torch.cuda.is_available():
            info("CUDA is not available, falling back to using the CPU.")
            # Without CUDA the checkpoint keys are cleaned as well.
            clean_keys = True

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source (consider weights_only=True where supported).
        # map_location keeps tensors off devices this process cannot reach.
        state_dict = torch.load(model_path, map_location=device)

        if clean_keys:
            model.load_state_dict(_strip_module_prefix(state_dict), strict=False)
        else:
            model.load_state_dict(state_dict)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Model state dictionary file not found: {model_path}"
        ) from e
    except Exception as e:
        raise RuntimeError(
            f"Error loading state dictionary from file: {model_path}\n{e}"
        ) from e

    return model
118+
119+
120+
def load_vocab(vocab_path: str = None) -> Tuple[List[str], dict, dict]:
    """
    Load the vocabulary file and create dictionaries for forward and backward mapping.

    Every non-blank line of the file must hold exactly two
    whitespace-separated fields. The first field is used as the token and the
    second as its index; both are kept as strings.

    Args:
        vocab_path: The vocabulary path. When None,
            ``trained_models/arxiv_im2mml_with_fonts_with_boldface_vocab.txt``
            next to this module is used.

    Returns:
        Tuple[List[str], dict, dict]: A tuple of three items:
            - vocab (List[str]): The raw lines of the file, exactly as read
              (newlines kept, blank lines included).
            - vocab_itos (dict): A dictionary mapping index to token.
            - vocab_stoi (dict): A dictionary mapping token to index.

    Raises:
        ValueError: If a non-blank line does not contain exactly two fields.
    """
    if vocab_path is None:
        cwd = Path(__file__).parents[0]
        vocab_path = (
            cwd / "trained_models" / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
        )

    # Explicit encoding so decoding does not depend on the platform locale.
    with open(vocab_path, encoding="utf-8") as f:
        vocab = f.readlines()

    vocab_itos = dict()
    vocab_stoi = dict()

    for line in vocab:
        fields = line.split()
        if not fields:
            # A blank (e.g. trailing) line carries no entry; skip it instead
            # of failing on the unpacking below.
            continue
        # split() with no argument already discards surrounding whitespace.
        token, index = fields
        vocab_itos[index] = token
        vocab_stoi[token] = index

    return vocab, vocab_itos, vocab_stoi
152+
153+
154+
class Image2MathML:
    """
    Holder for everything needed to run image-to-MathML inference.

    Construction loads, in order, the JSON config, the vocabulary, the compute
    device and the model, and keeps them as attributes: ``config``, ``vocab``,
    ``vocab_itos``, ``vocab_stoi``, ``device`` and ``model``.
    """

    def __init__(self, config_path: str, vocab_path: str, model_path: str) -> None:
        """
        Load the config, vocabulary, device and model.

        Args:
            config_path: Path to the JSON configuration file.
            vocab_path: Path to the vocabulary file.
            model_path: Path handed to ``retrieve_model`` to locate the model.
        """
        # Order matters: load_model reads self.config, self.vocab and
        # self.device, so those must be set first.
        self.config = self.load_config(config_path)
        self.vocab, self.vocab_itos, self.vocab_stoi = self.load_vocab(vocab_path)
        self.device = self.check_gpu_availability()
        self.model = self.load_model(model_path)

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Read the JSON configuration file at *config_path* and return its content."""
        # JSON is read as UTF-8 so decoding does not depend on the platform
        # locale.
        with open(config_path, "r", encoding="utf-8") as cfg:
            return json.load(cfg)

    def load_vocab(self, vocab_path: str) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
        """Load the image2mathml vocabulary via the module-level ``load_vocab``."""
        vocab, vocab_itos, vocab_stoi = load_vocab(vocab_path=vocab_path)
        return vocab, vocab_itos, vocab_stoi

    def check_gpu_availability(self) -> torch.device:
        """Return the ``cuda`` device when CUDA is available, otherwise ``cpu``."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def load_model(self, model_path: str) -> Image2MathML_Xfmer:
        """
        Resolve the model file with ``retrieve_model`` and load the image2mathml model.

        Reads ``self.config``, ``self.vocab`` and ``self.device``.
        """
        resolved_path = retrieve_model(model_path=model_path)
        img2mml_model: Image2MathML_Xfmer = load_model(
            model_path=resolved_path,
            config=self.config,
            vocab=self.vocab,
            device=self.device,
        )
        return img2mml_model
186+
187+
def get_mathml_from_bytes(
    data: bytes,
    image2mathml_db: Image2MathML,
) -> str:
    """
    Convert an image in bytes format to MathML using a preloaded Image2MathML holder.

    Args:
        data (bytes): The image data in bytes format.
        image2mathml_db (Image2MathML): Supplies the ``config`` used for the
            image-to-tensor conversion, and the ``model``, ``vocab_itos``,
            ``vocab_stoi`` and ``device`` passed on to ``render_mml``.

    Returns:
        str: The MathML representation of the input image.
    """
    # Turn the image bytes into a tensor, then add a leading batch axis of
    # size 1: (C_in, H, W) -> (1, C_in, H, W).
    single_image = convert_to_torch_tensor(data, image2mathml_db.config)
    batch = single_image.unsqueeze(0)

    return render_mml(
        image2mathml_db.model,
        image2mathml_db.vocab_itos,
        image2mathml_db.vocab_stoi,
        batch,
        image2mathml_db.device,
    )
58213

59214

60215
def get_mathml_from_file(filepath) -> str:
@@ -94,3 +249,5 @@ def get_mathml_from_latex(eqn: str) -> str:
94249
return f"An error occurred: {e}"
95250
finally:
96251
return "Conversion Failed."
252+
253+

0 commit comments

Comments
 (0)