1
- import json
2
1
import os
3
2
import requests
4
3
from pathlib import Path
5
4
import urllib .request
6
5
from skema .rest .proxies import SKEMA_MATHJAX_ADDRESS
7
6
from skema .img2mml .translate import convert_to_torch_tensor , render_mml
7
+ from skema .img2mml .models .image2mml_xfmer import Image2MathML_Xfmer
8
+ import torch
9
+ from typing import Tuple , List , Any , Dict
10
+ from logging import info
11
+ from skema .img2mml .translate import define_model
12
+ import json
8
13
9
14
10
- def retrieve_model (model_path = None ):
15
+ def retrieve_model (model_path = None ) -> str :
11
16
"""
12
17
Retrieve the img2mml model from the specified path or download it if not found.
13
18
@@ -34,27 +39,177 @@ def retrieve_model(model_path=None):
34
39
return str (model_path )
35
40
36
41
37
- def get_mathml_from_bytes (data : bytes ):
38
- # read config file
42
def check_gpu_availability() -> torch.device:
    """
    Check if GPU is available and return the appropriate device.

    Returns:
        torch.device: The "cuda" device when `torch.cuda.is_available()` is
        true, otherwise the "cpu" device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    # Report the fallback through logging (the same call load_model uses for
    # this message) instead of writing directly to stdout.
    info("CUDA is not available, falling back to using the CPU.")
    return torch.device("cpu")
56
+
57
+
58
def load_model(
    model_path: str,
    config: dict,
    vocab: List[str],
    device: torch.device = torch.device("cpu"),
) -> Image2MathML_Xfmer:
    """
    Build the model with `define_model` and load its state dictionary from a file.

    Args:
        model_path: The path to the model state dictionary file. When None,
            "arxiv_im2mml_with_fonts_with_boldface_best.pt" under this
            module's "trained_models" directory is used.
        config: The configuration setting. The "clean_state_dict" entry is
            read here.
        vocab: The vocabulary of the img2mml model.
        device: The device (GPU or CPU) to be used for computation.

    Returns:
        The model with loaded state dictionary.

    Raises:
        FileNotFoundError: If the model state dictionary file does not exist.
        RuntimeError: If there is any other error while loading the state
            dictionary (including a missing "clean_state_dict" config entry).

    Note:
        The "module." prefix is removed from state_dict keys (only from keys
        that actually carry it) when `config["clean_state_dict"]` is true, or
        when CUDA is not available. In those cases the state dictionary is
        loaded with `strict=False`; otherwise it is loaded strictly and
        unmodified.
    """
    model: Image2MathML_Xfmer = define_model(config, vocab, device).to(device)

    if model_path is None:
        model_path = (
            Path(__file__).parents[0]
            / "trained_models"
            / "arxiv_im2mml_with_fonts_with_boldface_best.pt"
        )

    prefix = "module."
    try:
        strip_prefix = bool(config["clean_state_dict"])
        if not strip_prefix and not torch.cuda.is_available():
            info("CUDA is not available, falling back to using the CPU.")
            strip_prefix = True

        # map_location keeps the tensors on the requested device regardless of
        # the device they were saved from.
        state_dict = torch.load(model_path, map_location=device)

        if strip_prefix:
            # Only drop the prefix where it is present: blindly slicing the
            # first seven characters would mangle un-prefixed keys, and with
            # strict=False those mangled keys would be ignored silently.
            cleaned = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in state_dict.items()
            }
            model.load_state_dict(cleaned, strict=False)
        else:
            model.load_state_dict(state_dict)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Model state dictionary file not found: {model_path}"
        ) from e
    except Exception as e:
        raise RuntimeError(
            f"Error loading state dictionary from file: {model_path}\n{e}"
        ) from e

    return model
118
+
119
+
120
def load_vocab(vocab_path: str = None) -> Tuple[List[str], dict, dict]:
    """
    Load the vocabulary file and build forward and backward lookup dictionaries.

    Every non-blank line must consist of exactly two whitespace-separated
    fields. Blank lines are ignored when building the dictionaries.

    Args:
        vocab_path: The vocabulary path. When None,
            "arxiv_im2mml_with_fonts_with_boldface_vocab.txt" under this
            module's "trained_models" directory is used.

    Returns:
        Tuple[List[str], dict, dict]: A tuple containing three items:
            - vocab (List[str]): The raw lines of the vocabulary file, as
              returned by `readlines()`.
            - vocab_itos (dict): Maps the second field of each line to the
              first field (index to token); both are kept as strings.
            - vocab_stoi (dict): Maps the first field of each line to the
              second field (token to index); both are kept as strings.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields.
    """
    if vocab_path is None:
        vocab_path = (
            Path(__file__).parents[0]
            / "trained_models"
            / "arxiv_im2mml_with_fonts_with_boldface_vocab.txt"
        )

    # read vocab.txt
    with open(vocab_path) as f:
        vocab = f.readlines()

    vocab_itos = dict()
    vocab_stoi = dict()

    for line_number, line in enumerate(vocab, start=1):
        fields = line.split()
        if not fields:
            # e.g. a trailing empty line at the end of the file
            continue
        if len(fields) != 2:
            raise ValueError(
                f"Malformed vocabulary entry at line {line_number} of "
                f"{vocab_path}: {line!r}"
            )
        # split() already removes surrounding whitespace from both fields
        token, index = fields
        vocab_itos[index] = token
        vocab_stoi[token] = index

    return vocab, vocab_itos, vocab_stoi
152
+
153
+
154
class Image2MathML:
    """
    Bundle of everything needed to run the img2mml model.

    On construction the configuration, the vocabulary (plus its two lookup
    dictionaries), the computation device and the model are loaded, in that
    order, and stored on the instance as `config`, `vocab`, `vocab_itos`,
    `vocab_stoi`, `device` and `model`.
    """

    def __init__(self, config_path: str, vocab_path: str, model_path: str) -> None:
        # Order matters: load_model reads self.config, self.vocab and self.device.
        self.config = self.load_config(config_path)
        vocab_bundle = self.load_vocab(vocab_path)
        self.vocab, self.vocab_itos, self.vocab_stoi = vocab_bundle
        self.device = self.check_gpu_availability()
        self.model = self.load_model(model_path)

    def load_config(self, config_path: str) -> Dict[str, Any]:
        """Read the JSON configuration file at `config_path` and return its content."""
        with open(config_path, "r") as cfg:
            return json.load(cfg)

    def load_vocab(self, vocab_path: str) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
        """Load the image2mathml vocabulary through the module-level `load_vocab`."""
        tokens, index_to_token, token_to_index = load_vocab(vocab_path=vocab_path)
        return tokens, index_to_token, token_to_index

    def check_gpu_availability(self) -> torch.device:
        """Return the "cuda" device when CUDA is available, else the "cpu" device."""
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        return torch.device(device_name)

    def load_model(self, model_path: str) -> Image2MathML_Xfmer:
        """
        Load the image2mathml model.

        The path is first passed through `retrieve_model`, then the
        module-level `load_model` is called with this instance's config,
        vocab and device.
        """
        resolved_path = retrieve_model(model_path=model_path)
        return load_model(
            model_path=resolved_path,
            config=self.config,
            vocab=self.vocab,
            device=self.device,
        )
186
+
187
def get_mathml_from_bytes(
    data: bytes,
    image2mathml_db: Image2MathML,
) -> str:
    """
    Convert an image in bytes format to its MathML representation.

    Args:
        data (bytes): The image data in bytes format.
        image2mathml_db (Image2MathML): Loaded img2mml bundle. Its `config`,
            `model`, `vocab_itos`, `vocab_stoi` and `device` attributes are
            used here.

    Returns:
        str: The MathML representation of the input image.
    """
    db = image2mathml_db

    # Convert the image bytes to a tensor, then prepend a batch dimension of
    # size 1: (C_in, H, W) -> (1, C_in, H, W).
    single_image = convert_to_torch_tensor(data, db.config)
    batch = single_image.unsqueeze(0)

    return render_mml(
        db.model,
        db.vocab_itos,
        db.vocab_stoi,
        batch,
        db.device,
    )
58
213
59
214
60
215
def get_mathml_from_file (filepath ) -> str :
@@ -94,3 +249,5 @@ def get_mathml_from_latex(eqn: str) -> str:
94
249
return f"An error occurred: { e } "
95
250
finally :
96
251
return "Conversion Failed."
252
+
253
+
0 commit comments