Skip to content

Commit 13d562a

Browse files
Merge pull request #140 from lukas-blecher/api
Add API
2 parents aa4093f + 63787f5 commit 13d562a

File tree

7 files changed

+213
-109
lines changed

7 files changed

+213
-109
lines changed

notebooks/LaTeX_OCR_test.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
"\n",
6262
"from pix2tex import cli as pix2tex\n",
6363
"from PIL import Image\n",
64-
"args = pix2tex.initialize()\n",
64+
"model = pix2tex.LatexOCR()\n",
6565
"\n",
6666
"from IPython.display import HTML, Math\n",
6767
"display(HTML(\"<script src='https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.3/\"\n",
@@ -76,7 +76,7 @@
7676
"predictions = []\n",
7777
"for name, f in imgs:\n",
7878
" img = Image.open(f)\n",
79-
" math = pix2tex.call_model(*args, img)\n",
79+
" math = model(img)\n",
8080
" print(math)\n",
8181
" predictions.append('\\\\mathrm{%s} & \\\\displaystyle{%s}'%(name, math))\n",
8282
"Math(table%'\\\\\\\\'.join(predictions))"

pix2tex/api/app.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Adapted from https://github.com/kingyiusuen/image-to-latex/blob/main/api/app.py
# NOTE(review): removed the erroneous `from ctypes import resize` — it was
# never used anywhere in this file (the `resize=False` keyword below is an
# unrelated parameter of LatexOCR.__call__).
from http import HTTPStatus

from fastapi import FastAPI, File, UploadFile, Form
from PIL import Image
from io import BytesIO

from pix2tex.cli import LatexOCR

# The OCR model is heavy, so it is loaded lazily on server startup
# (see the 'startup' event handler) instead of at import time.
model = None
app = FastAPI(title='pix2tex API')
12+
13+
14+
def read_imagefile(file) -> Image.Image:
    """Decode raw image bytes into a PIL image object."""
    return Image.open(BytesIO(file))
17+
18+
19+
@app.on_event('startup')
async def load_model():
    """Instantiate the global LatexOCR model once when the server boots."""
    global model
    if model is not None:
        return
    model = LatexOCR()
24+
25+
26+
@app.get('/')
def root():
    """Health check endpoint: reports that the API is up."""
    return {
        'message': HTTPStatus.OK.phrase,
        'status-code': HTTPStatus.OK,
        'data': {},
    }
35+
36+
37+
@app.post('/predict/')
async def predict(file: UploadFile = File(...)):
    """Predict the LaTeX code for an uploaded image file.

    The `global` statement of the original was dropped: the handler only
    reads the module-level model, so no declaration is needed.
    """
    image = Image.open(file.file)
    return model(image)
42+
43+
44+
@app.post('/bytes/')
async def predict_from_bytes(file: bytes = File(...)):
    """Predict the LaTeX code for raw image bytes.

    The learned image resizer is skipped (`resize=False`), so the caller is
    expected to submit an appropriately sized image.
    """
    image = Image.open(BytesIO(file))
    return model(image, resize=False)

pix2tex/api/run.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from multiprocessing import Process
2+
import subprocess
3+
import os
4+
5+
6+
def start_api(path='.'):
    """Run the FastAPI backend with uvicorn, blocking until it exits.

    Args:
        path: Working directory that contains ``app.py``.
    """
    # subprocess.run is the recommended replacement for the legacy
    # subprocess.call (same blocking semantics).
    subprocess.run(['uvicorn', 'app:app'], cwd=path)
8+
9+
10+
def start_frontend(path='.'):
    """Run the Streamlit frontend, blocking until it exits.

    Args:
        path: Working directory that contains ``streamlit.py``.
    """
    # subprocess.run is the recommended replacement for the legacy
    # subprocess.call (same blocking semantics).
    subprocess.run(['streamlit', 'run', 'streamlit.py'], cwd=path)
12+
13+
14+
if __name__ == '__main__':
    # Launch API and frontend side by side from this file's directory,
    # then wait for both to finish.
    here = os.path.realpath(os.path.dirname(__file__))
    workers = [
        Process(target=start_api, kwargs={'path': here}),
        Process(target=start_frontend, kwargs={'path': here}),
    ]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

pix2tex/api/streamlit.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# NOTE(review): removed `from msilib.schema import Icon` — it was completely
# unused, and `msilib` only exists in Windows builds of CPython, so importing
# it crashed this frontend with ImportError on Linux/macOS.
import requests
from PIL import Image
import streamlit

if __name__ == '__main__':
    streamlit.set_page_config(page_title='LaTeX-OCR')
    streamlit.title('LaTeX OCR')
    streamlit.markdown('Convert images of equations to corresponding LaTeX code.\n\nThis is based on the `pix2tex` module. Check it out [![github](https://img.shields.io/badge/LaTeX--OCR-visit-a?style=social&logo=github)](https://github.com/lukas-blecher/LaTeX-OCR)')

    # fix: label typo "an image an equation" -> "an image of an equation"
    uploaded_file = streamlit.file_uploader(
        'Upload an image of an equation',
        type=['png', 'jpg'],
    )

    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        streamlit.image(image)
    else:
        streamlit.text('\n')

    if streamlit.button('Convert'):
        # `image` is only bound when a file was uploaded; the short-circuit
        # on `uploaded_file` keeps the second test from raising NameError.
        if uploaded_file is not None and image is not None:
            with streamlit.spinner('Computing'):
                # Send the raw bytes to the locally running FastAPI backend.
                response = requests.post('http://127.0.0.1:8000/predict/', files={'file': uploaded_file.getvalue()})
            if response.ok:
                latex_code = response.json()
                streamlit.code(latex_code, language='latex')
                streamlit.markdown(f'$\\displaystyle {latex_code}$')
            else:
                streamlit.error(response.text)
        else:
            streamlit.error('Please upload an image.')

pix2tex/cli.py

Lines changed: 81 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@
2121
from pix2tex.utils import *
2222
from pix2tex.model.checkpoints.get_latest_checkpoint import download_checkpoints
2323

24-
last_pic = None
25-
2624

2725
def minmax_size(img, max_dimensions=None, min_dimensions=None):
2826
if max_dimensions is not None:
@@ -40,79 +38,77 @@ def minmax_size(img, max_dimensions=None, min_dimensions=None):
4038
return img
4139

4240

43-
@in_model_path()
44-
def initialize(arguments=None):
45-
if arguments is None:
46-
arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
47-
logging.getLogger().setLevel(logging.FATAL)
48-
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
49-
with open(arguments.config, 'r') as f:
50-
params = yaml.load(f, Loader=yaml.FullLoader)
51-
args = parse_args(Munch(params))
52-
args.update(**vars(arguments))
53-
args.wandb = False
54-
args.device = 'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu'
55-
if not os.path.exists(args.checkpoint):
56-
download_checkpoints()
57-
model = get_model(args)
58-
model.load_state_dict(torch.load(args.checkpoint, map_location=args.device))
59-
60-
if 'image_resizer.pth' in os.listdir(os.path.dirname(args.checkpoint)) and not arguments.no_resize:
61-
image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
62-
preact=True, stem_type='same', conv_layer=StdConv2dSame).to(args.device)
63-
image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(args.checkpoint), 'image_resizer.pth'), map_location=args.device))
64-
image_resizer.eval()
65-
else:
66-
image_resizer = None
67-
tokenizer = PreTrainedTokenizerFast(tokenizer_file=args.tokenizer)
68-
return args, model, image_resizer, tokenizer
69-
70-
71-
@in_model_path()
72-
def call_model(args, model, image_resizer, tokenizer, img=None):
73-
global last_pic
74-
encoder, decoder = model.encoder, model.decoder
75-
if type(img) is bool:
76-
img = None
77-
if img is None:
78-
if last_pic is None:
79-
print('Provide an image.')
80-
return ''
41+
class LatexOCR:
    """Image-to-LaTeX predictor.

    Construct once (loads weights, tokenizer and the optional learned image
    resizer), then call the instance with a PIL image to get LaTeX code back.
    """

    # Optional ResNet that predicts a better input width for the image;
    # stays None when no resizer checkpoint is available or resizing is off.
    image_resizer = None
    # Last predicted image; reused when the instance is called without one.
    last_pic = None

    @in_model_path()
    def __init__(self, arguments=None):
        """Load model weights, tokenizer and (optionally) the image resizer.

        Args:
            arguments (Munch, optional): config/checkpoint paths plus the
                `no_cuda` and `no_resize` flags. Defaults to the packaged
                settings with CUDA disabled.
        """
        if arguments is None:
            arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
        logging.getLogger().setLevel(logging.FATAL)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        with open(arguments.config, 'r') as f:
            params = yaml.load(f, Loader=yaml.FullLoader)
        self.args = parse_args(Munch(params))
        self.args.update(**vars(arguments))
        self.args.wandb = False
        self.args.device = 'cuda' if torch.cuda.is_available() and not self.args.no_cuda else 'cpu'
        # Fetch the weights on first use.
        if not os.path.exists(self.args.checkpoint):
            download_checkpoints()
        self.model = get_model(self.args)
        self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))

        # The resizer is optional: only load it when its checkpoint sits next
        # to the model weights and resizing was not disabled.
        if 'image_resizer.pth' in os.listdir(os.path.dirname(self.args.checkpoint)) and not arguments.no_resize:
            self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
                                          preact=True, stem_type='same', conv_layer=StdConv2dSame).to(self.args.device)
            self.image_resizer.load_state_dict(torch.load(os.path.join(os.path.dirname(self.args.checkpoint), 'image_resizer.pth'), map_location=self.args.device))
            self.image_resizer.eval()
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.args.tokenizer)

    @in_model_path()
    def __call__(self, img=None, resize=True):
        """Predict the LaTeX code for an image.

        Args:
            img (PIL.Image, optional): image to predict. When None (or a
                bool, which some GUI callbacks pass), the previously used
                image is reused.
            resize (bool): whether to run the learned image resizer first.

        Returns:
            str: predicted LaTeX code, or '' when no image is available.
        """
        if type(img) is bool:
            img = None
        if img is None:
            if self.last_pic is None:
                print('Provide an image.')
                return ''
            else:
                img = self.last_pic.copy()
        else:
            self.last_pic = img.copy()
        img = minmax_size(pad(img), self.args.max_dimensions, self.args.min_dimensions)
        if (self.image_resizer is not None and not self.args.no_resize) and resize:
            with torch.no_grad():
                input_image = img.convert('RGB').copy()
                r, w, h = 1, input_image.size[0], input_image.size[1]
                # Let the resizer network iteratively pick a better width;
                # stop as soon as it agrees with the current one.
                for _ in range(10):
                    h = int(h * r)  # height to resize
                    img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), self.args.max_dimensions, self.args.min_dimensions))
                    t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
                    w = (self.image_resizer(t.to(self.args.device)).argmax(-1).item()+1)*32
                    # fix: the original passed `r` (a float) as the logging
                    # message with extra positional args, which triggers a
                    # formatting error whenever INFO logging is enabled.
                    logging.info('resize ratio %s, size %s -> %s', r, img.size, (w, int(input_image.size[1]*r)))
                    if (w == img.size[0]):
                        break
                    r = w/img.size[0]
        else:
            img = np.array(pad(img).convert('RGB'))
            t = test_transform(image=img)['image'][:1].unsqueeze(0)
        im = t.to(self.args.device)

        with torch.no_grad():
            self.model.eval()
            device = self.args.device
            encoded = self.model.encoder(im.to(device))
            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
            pred = post_process(token2str(dec, self.tokenizer)[0])
        try:
            # Best effort: clipboard access can fail in headless environments.
            # fix: bare `except:` also swallowed KeyboardInterrupt/SystemExit.
            clipboard.copy(pred)
        except Exception:
            pass
        return pred
116112

117113

118114
def output_prediction(pred, args):
@@ -144,7 +140,8 @@ def main():
144140
parser.add_argument('--no-resize', action='store_true', help='Resize the image beforehand')
145141
arguments = parser.parse_args()
146142
with in_model_path():
147-
args, *objs = initialize(arguments)
143+
model = LatexOCR(arguments)
144+
file = None
148145
while True:
149146
instructions = input('Predict LaTeX code for image ("?"/"h" for help). ')
150147
possible_file = instructions.strip()
@@ -176,32 +173,32 @@ def main():
176173
''')
177174
continue
178175
elif ins in ['show', 'katex', 'no_resize']:
179-
setattr(args, ins, not getattr(args, ins, False))
180-
print('set %s to %s' % (ins, getattr(args, ins)))
176+
setattr(model.args, ins, not getattr(model.args, ins, False))
177+
print('set %s to %s' % (ins, getattr(model.args, ins)))
181178
continue
182179
elif os.path.isfile(os.path.realpath(possible_file)):
183-
args.file = possible_file
180+
file = possible_file
184181
else:
185182
t = re.match(r't=([\.\d]+)', ins)
186183
if t is not None:
187184
t = t.groups()[0]
188-
args.temperature = float(t)+1e-8
189-
print('new temperature: T=%.3f' % args.temperature)
185+
model.args.temperature = float(t)+1e-8
186+
print('new temperature: T=%.3f' % model.args.temperature)
190187
continue
191188
try:
192189
img = None
193-
if args.file:
194-
img = Image.open(args.file)
190+
if file:
191+
img = Image.open(file)
195192
else:
196193
try:
197194
img = ImageGrab.grabclipboard()
198195
except:
199196
pass
200-
pred = call_model(args, *objs, img=img)
201-
output_prediction(pred, args)
197+
pred = model(img)
198+
output_prediction(pred, model.args)
202199
except KeyboardInterrupt:
203200
pass
204-
args.file = None
201+
file = None
205202

206203

207204
if __name__ == "__main__":

0 commit comments

Comments
 (0)