1
+ //
2
+ // Created by icys on 25-2-10.
3
+ //
4
+
5
+ #include " model.h"
6
+
7
+ namespace NBCapture {
8
+ LatexOcr::LatexOcr () {
9
+ // TODO: Replace ugly lower op with InstanceNorm to speed up
10
+
11
+ decoder.load_param (" ./assets/Simple-LaTeX-OCR_Decoder_fp16.param" );
12
+ decoder.load_model (" ./assets/Simple-LaTeX-OCR_Decoder_fp16.bin" );
13
+
14
+ encoder.load_param (" ./assets/Simple-LaTeX-OCR_Encoder_fp16.param" );
15
+ encoder.load_model (" ./assets/Simple-LaTeX-OCR_Encoder_fp16.bin" );
16
+
17
+ // simple_latex_ocr_vocab.txt
18
+ std::ifstream ifs (" ./assets/simple_latex_ocr_vocab.txt" );
19
+ std::string line;
20
+ while (std::getline (ifs, line)) {
21
+ token_list.push_back (line);
22
+ }
23
+ }
24
+
25
+
26
+ std::string post_process (std::string str) {
27
+ // 替换 Ġ Ċ 为 ' '
28
+ // 正则替换
29
+ str = std::regex_replace (str, std::regex (" Ġ" ), " " );
30
+ str = std::regex_replace (str, std::regex (" Ċ" ), " " );
31
+
32
+ return str;
33
+ }
34
+
35
+ std::string LatexOcr::forward (const cv::Mat& image) {
36
+ cv::Mat gray;
37
+ cv::cvtColor (image, gray, cv::COLOR_BGR2GRAY);
38
+ // Canny
39
+ // cv::Canny(gray, gray, 50, 150, 3);
40
+ // 转白底黑字
41
+ // gray = 255 - gray;
42
+
43
+ float scale = 1 .f ;
44
+ int target_width, target_height;
45
+ if (image.cols * 128 .0f / 640 .0f > image.rows ) {
46
+ scale = 640 / static_cast <float >(image.cols );
47
+ target_width = 640 ;
48
+ target_height = static_cast <int >(image.rows * scale);
49
+ } else {
50
+ scale = 128 / static_cast <float >(image.rows );
51
+ target_height = 128 ;
52
+ target_width = static_cast <int >(image.cols * scale);
53
+ }
54
+
55
+ int wpad = (640 - target_width) / 2 ;
56
+ int hpad = (128 - target_height) / 2 ;
57
+
58
+ // resize
59
+ cv::resize (gray, gray, cv::Size (target_width, target_height));
60
+ const int pad_color = 114 ;
61
+
62
+ cv::Mat pad_img = cv::Mat (128 , 640 , CV_8UC1, cv::Scalar (pad_color));
63
+ gray.copyTo (pad_img (cv::Rect (wpad, hpad, target_width, target_height)));
64
+
65
+ auto in = ncnn::Mat::from_pixels (pad_img.data , ncnn::Mat::PIXEL_GRAY, 640 , 128 );
66
+ const float mean_vals[1 ] = {0.7931 * 255 };
67
+ const float norm_vals[1 ] = {1.0 / 0.1738 / 255.0 };
68
+ in.substract_mean_normalize (mean_vals, norm_vals);
69
+
70
+ auto ex = encoder.create_extractor ();
71
+ ex.input (" in0" , in.clone ());
72
+ ncnn::Mat feat;
73
+ ex.extract (" out0" , feat);
74
+
75
+ const int max_step = 1024 ;
76
+ int step = 0 ;
77
+
78
+ // output, feat, mask, pos
79
+
80
+ std::vector<int32_t > output;
81
+ output.push_back (1 ); // <sos>
82
+ std::vector<int32_t > pos;
83
+
84
+ while (step < max_step) {
85
+ auto ex2 = decoder.create_extractor ();
86
+ pos.push_back (step);
87
+
88
+ ncnn::Mat mask (step+1 ,step+1 ,1 );
89
+ mask.fill (0 .0f );
90
+ for (int i = 0 ; i < step+1 ; i++) {
91
+ for (int j = i + 1 ; j < step+1 ; j++) {
92
+ mask.row (i)[j] = -1e30f;
93
+ }
94
+ }
95
+
96
+ ex2.input (" in0" , ncnn::Mat (output.size (),output.data ()).clone ());
97
+ ex2.input (" in1" , feat.clone ());
98
+ ex2.input (" in2" , mask.clone ());
99
+ ex2.input (" in3" , ncnn::Mat (pos.size (),pos.data ()).clone ());
100
+
101
+ ncnn::Mat out0;
102
+ ex2.extract (" out0" , out0);
103
+
104
+ const int len_token = 1200 ;
105
+
106
+ output.resize (step+2 );
107
+ for (int i = 0 ; i < step+1 ; i++) {
108
+ int maxarg = 0 ;
109
+ float maxval = -1e30f;
110
+ for (int j = 0 ; j < len_token; j++) {
111
+ if (out0.row (i)[j] > maxval) {
112
+ maxval = out0.row (i)[j];
113
+ maxarg = j;
114
+ }
115
+ }
116
+ output[i+1 ] = maxarg;
117
+
118
+ if (maxarg == 2 ) {
119
+ output.pop_back ();
120
+ goto Over;
121
+ }
122
+ }
123
+ step++;
124
+ }
125
+ Over:
126
+ std::string ret;
127
+ for (int i = 1 ; i < output.size (); i++) {
128
+ ret += token_list[output[i]];
129
+ }
130
+ return post_process (ret);
131
+
132
+ }
133
+ } // NBCapture
0 commit comments