21
21
from pix2tex .utils import *
22
22
from pix2tex .model .checkpoints .get_latest_checkpoint import download_checkpoints
23
23
24
- last_pic = None
25
-
26
24
27
25
def minmax_size (img , max_dimensions = None , min_dimensions = None ):
28
26
if max_dimensions is not None :
@@ -40,79 +38,77 @@ def minmax_size(img, max_dimensions=None, min_dimensions=None):
40
38
return img
41
39
42
40
43
- @in_model_path ()
44
- def initialize (arguments = None ):
45
- if arguments is None :
46
- arguments = Munch ({'config' : 'settings/config.yaml' , 'checkpoint' : 'checkpoints/weights.pth' , 'no_cuda' : True , 'no_resize' : False })
47
- logging .getLogger ().setLevel (logging .FATAL )
48
- os .environ ['TF_CPP_MIN_LOG_LEVEL' ] = '3'
49
- with open (arguments .config , 'r' ) as f :
50
- params = yaml .load (f , Loader = yaml .FullLoader )
51
- args = parse_args (Munch (params ))
52
- args .update (** vars (arguments ))
53
- args .wandb = False
54
- args .device = 'cuda' if torch .cuda .is_available () and not args .no_cuda else 'cpu'
55
- if not os .path .exists (args .checkpoint ):
56
- download_checkpoints ()
57
- model = get_model (args )
58
- model .load_state_dict (torch .load (args .checkpoint , map_location = args .device ))
59
-
60
- if 'image_resizer.pth' in os .listdir (os .path .dirname (args .checkpoint )) and not arguments .no_resize :
61
- image_resizer = ResNetV2 (layers = [2 , 3 , 3 ], num_classes = max (args .max_dimensions )// 32 , global_pool = 'avg' , in_chans = 1 , drop_rate = .05 ,
62
- preact = True , stem_type = 'same' , conv_layer = StdConv2dSame ).to (args .device )
63
- image_resizer .load_state_dict (torch .load (os .path .join (os .path .dirname (args .checkpoint ), 'image_resizer.pth' ), map_location = args .device ))
64
- image_resizer .eval ()
65
- else :
66
- image_resizer = None
67
- tokenizer = PreTrainedTokenizerFast (tokenizer_file = args .tokenizer )
68
- return args , model , image_resizer , tokenizer
69
-
70
-
71
- @in_model_path ()
72
- def call_model (args , model , image_resizer , tokenizer , img = None ):
73
- global last_pic
74
- encoder , decoder = model .encoder , model .decoder
75
- if type (img ) is bool :
76
- img = None
77
- if img is None :
78
- if last_pic is None :
79
- print ('Provide an image.' )
80
- return ''
41
class LatexOCR:
    """Callable wrapper bundling the OCR model, optional image resizer and tokenizer.

    Construct the object once (this loads config, weights and tokenizer), then
    call the instance with an image to get the predicted LaTeX string.
    """

    # Class-level defaults. ``image_resizer`` stays ``None`` unless a resizer
    # checkpoint is found in ``__init__``; ``last_pic`` is replaced by a copy
    # of the most recent image passed to ``__call__``.
    image_resizer = None
    last_pic = None

    @in_model_path()
    def __init__(self, arguments=None):
        """Load configuration, model weights, optional resizer and tokenizer.

        Args:
            arguments: Object exposing ``config``, ``checkpoint``, ``no_cuda``
                and ``no_resize`` attributes. When ``None``, a default
                ``Munch`` with CPU-only settings is used.
        """
        if arguments is None:
            arguments = Munch({'config': 'settings/config.yaml', 'checkpoint': 'checkpoints/weights.pth', 'no_cuda': True, 'no_resize': False})
        # NOTE: this changes process-wide state (root logger level and an
        # environment variable), not just this instance.
        logging.getLogger().setLevel(logging.FATAL)
        os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
        with open(arguments.config, 'r') as f:
            # NOTE(review): FullLoader is kept for compatibility; only load
            # trusted config files, or switch to yaml.safe_load if the config
            # holds plain data only -- confirm.
            params = yaml.load(f, Loader=yaml.FullLoader)
        self.args = parse_args(Munch(params))
        self.args.update(**vars(arguments))
        self.args.wandb = False
        self.args.device = 'cuda' if torch.cuda.is_available() and not self.args.no_cuda else 'cpu'
        if not os.path.exists(self.args.checkpoint):
            download_checkpoints()
        self.model = get_model(self.args)
        self.model.load_state_dict(torch.load(self.args.checkpoint, map_location=self.args.device))

        # The resizer is optional: it is only loaded when its weights sit next
        # to the main checkpoint and resizing was not disabled.
        checkpoint_dir = os.path.dirname(self.args.checkpoint)
        if 'image_resizer.pth' in os.listdir(checkpoint_dir) and not arguments.no_resize:
            self.image_resizer = ResNetV2(layers=[2, 3, 3], num_classes=max(self.args.max_dimensions)//32, global_pool='avg', in_chans=1, drop_rate=.05,
                                          preact=True, stem_type='same', conv_layer=StdConv2dSame).to(self.args.device)
            self.image_resizer.load_state_dict(torch.load(os.path.join(checkpoint_dir, 'image_resizer.pth'), map_location=self.args.device))
            self.image_resizer.eval()
        self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=self.args.tokenizer)

    @in_model_path()
    def __call__(self, img=None, resize=True):
        """Predict the LaTeX code for an image.

        Args:
            img: Image to process (must support ``.copy()``; presumably a PIL
                image -- confirm). ``None`` or a bool re-uses the last image.
            resize: When true and a resizer is loaded (and ``args.no_resize``
                is false), the image is iteratively rescaled before decoding.

        Returns:
            The post-processed prediction string, or ``''`` when no image was
            given and none was stored from an earlier call.
        """
        # A bool is treated exactly like "no image supplied".
        if isinstance(img, bool):
            img = None
        if img is None:
            if self.last_pic is None:
                print('Provide an image.')
                return ''
            img = self.last_pic.copy()
        else:
            self.last_pic = img.copy()
        img = minmax_size(pad(img), self.args.max_dimensions, self.args.min_dimensions)
        if (self.image_resizer is not None and not self.args.no_resize) and resize:
            with torch.no_grad():
                input_image = img.convert('RGB').copy()
                r, w, h = 1, input_image.size[0], input_image.size[1]
                # At most 10 rounds: resize, ask the resizer for a width
                # (class index + 1, times 32), stop once it matches.
                for _ in range(10):
                    h = int(h * r)  # height to resize
                    img = pad(minmax_size(input_image.resize((w, h), Image.BILINEAR if r > 1 else Image.LANCZOS), self.args.max_dimensions, self.args.min_dimensions))
                    t = test_transform(image=np.array(img.convert('RGB')))['image'][:1].unsqueeze(0)
                    w = (self.image_resizer(t.to(self.args.device)).argmax(-1).item()+1)*32
                    # Explicit format string: passing the values as the
                    # message itself breaks logging's %-formatting.
                    logging.info('%s %s %s', r, img.size, (w, int(input_image.size[1]*r)))
                    if w == img.size[0]:
                        break
                    r = w/img.size[0]
        else:
            img = np.array(pad(img).convert('RGB'))
            t = test_transform(image=img)['image'][:1].unsqueeze(0)
        im = t.to(self.args.device)

        with torch.no_grad():
            self.model.eval()
            device = self.args.device
            encoded = self.model.encoder(im.to(device))
            dec = self.model.decoder.generate(torch.LongTensor([self.args.bos_token])[:, None].to(device), self.args.max_seq_len,
                                              eos_token=self.args.eos_token, context=encoded.detach(), temperature=self.args.get('temperature', .25))
            pred = post_process(token2str(dec, self.tokenizer)[0])
        try:
            clipboard.copy(pred)
        except Exception:
            # Copying to the clipboard is best-effort; a failure must not
            # discard the prediction. KeyboardInterrupt/SystemExit still propagate.
            pass
        return pred
116
112
117
113
118
114
def output_prediction (pred , args ):
@@ -144,7 +140,8 @@ def main():
144
140
parser .add_argument ('--no-resize' , action = 'store_true' , help = 'Resize the image beforehand' )
145
141
arguments = parser .parse_args ()
146
142
with in_model_path ():
147
- args , * objs = initialize (arguments )
143
+ model = LatexOCR (arguments )
144
+ file = None
148
145
while True :
149
146
instructions = input ('Predict LaTeX code for image ("?"/"h" for help). ' )
150
147
possible_file = instructions .strip ()
@@ -176,32 +173,32 @@ def main():
176
173
''' )
177
174
continue
178
175
elif ins in ['show' , 'katex' , 'no_resize' ]:
179
- setattr (args , ins , not getattr (args , ins , False ))
180
- print ('set %s to %s' % (ins , getattr (args , ins )))
176
+ setattr (model . args , ins , not getattr (model . args , ins , False ))
177
+ print ('set %s to %s' % (ins , getattr (model . args , ins )))
181
178
continue
182
179
elif os .path .isfile (os .path .realpath (possible_file )):
183
- args . file = possible_file
180
+ file = possible_file
184
181
else :
185
182
t = re .match (r't=([\.\d]+)' , ins )
186
183
if t is not None :
187
184
t = t .groups ()[0 ]
188
- args .temperature = float (t )+ 1e-8
189
- print ('new temperature: T=%.3f' % args .temperature )
185
+ model . args .temperature = float (t )+ 1e-8
186
+ print ('new temperature: T=%.3f' % model . args .temperature )
190
187
continue
191
188
try :
192
189
img = None
193
- if args . file :
194
- img = Image .open (args . file )
190
+ if file :
191
+ img = Image .open (file )
195
192
else :
196
193
try :
197
194
img = ImageGrab .grabclipboard ()
198
195
except :
199
196
pass
200
- pred = call_model ( args , * objs , img = img )
201
- output_prediction (pred , args )
197
+ pred = model ( img )
198
+ output_prediction (pred , model . args )
202
199
except KeyboardInterrupt :
203
200
pass
204
- args . file = None
201
+ file = None
205
202
206
203
207
204
if __name__ == "__main__" :
0 commit comments