Skip to content

Commit 18c87f6

Browse files
authored
Add files via upload
1 parent bd4017e commit 18c87f6

File tree

2 files changed

+155
-0
lines changed

2 files changed

+155
-0
lines changed

model.cpp

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
//
2+
// Created by icys on 25-2-10.
3+
//
4+
5+
#include "model.h"
6+
7+
namespace NBCapture {
8+
// Loads the decoder and encoder networks and the token vocabulary.
//
// All asset paths are relative to the current working directory. The return
// codes of load_param()/load_model() are not inspected here, and a missing
// vocabulary file simply leaves token_list empty.
LatexOcr::LatexOcr() {
    // TODO: Replace ugly lower op with InstanceNorm to speed up

    decoder.load_param("./assets/Simple-LaTeX-OCR_Decoder_fp16.param");
    decoder.load_model("./assets/Simple-LaTeX-OCR_Decoder_fp16.bin");

    encoder.load_param("./assets/Simple-LaTeX-OCR_Encoder_fp16.param");
    encoder.load_model("./assets/Simple-LaTeX-OCR_Encoder_fp16.bin");

    // simple_latex_ocr_vocab.txt: one token per line; the zero-based line
    // number is the token id that forward() uses to index token_list.
    std::ifstream ifs("./assets/simple_latex_ocr_vocab.txt");
    std::string line;
    while (std::getline(ifs, line)) {
        // A vocabulary file saved with CRLF line endings would otherwise leave
        // a stray '\r' at the end of every token and leak it into the output.
        if (!line.empty() && line.back() == '\r') {
            line.pop_back();
        }
        // Empty lines are kept so that line numbers stay aligned with ids.
        token_list.push_back(line);
    }
}
24+
25+
26+
// Replaces every occurrence of the markers "Ġ" and "Ċ" in `str` with a single
// space and returns the result.
// NOTE(review): these look like the tokenizer's space/newline markers — confirm
// against the vocabulary file.
std::string post_process(std::string str) {
    // Plain substring replacement: the markers are fixed literals, so there is
    // no need to compile two regular expressions on every call (and no
    // dependence on how the regex traits treat non-ASCII bytes).
    const auto replace_all = [&str](const std::string& marker) {
        std::string::size_type at = 0;
        while ((at = str.find(marker, at)) != std::string::npos) {
            str.replace(at, marker.size(), " ");
            at += 1;  // resume right after the space that was just inserted
        }
    };

    replace_all("Ġ");
    replace_all("Ċ");

    return str;
}
34+
35+
// Runs the encoder/decoder pair on `image` and returns the decoded text.
//
// Steps: convert to grayscale, resize with preserved aspect ratio into a
// 640x128 canvas padded with gray value 114, normalize, run the encoder once,
// then run the decoder repeatedly (at most 1024 steps), picking the highest
// scoring token per position until the stop token (id 2) is produced. The
// token ids are mapped through token_list and passed to post_process().
//
// image: 3-channel input is converted as BGR, 4-channel as BGRA, and
//        single-channel input is used as is.
//        NOTE(review): 8-bit pixel depth is assumed by the pixel upload below
//        — confirm that callers never pass other depths.
// Returns the decoded string; an empty string when the encoder yields no
// output. Token ids outside token_list are skipped.
std::string LatexOcr::forward(const cv::Mat& image) {
    const int canvas_w = 640;
    const int canvas_h = 128;

    cv::Mat gray;
    if (!image.empty() && image.channels() == 1) {
        // Already grayscale. This shares the caller's pixels, which are only
        // read from below (the resize writes into a separate matrix).
        gray = image;
    } else if (image.channels() == 4) {
        cv::cvtColor(image, gray, cv::COLOR_BGRA2GRAY);
    } else {
        // Empty or otherwise unsupported input is left to cvtColor to reject,
        // as it was before.
        cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY);
    }
    // Canny
    // cv::Canny(gray, gray, 50, 150, 3);
    // Convert to black text on a white background
    // gray = 255 - gray;

    // Fit the image inside the canvas: whichever side hits its limit first
    // determines the scale.
    int target_width = canvas_w;
    int target_height = canvas_h;
    if (image.cols * 128.0f / 640.0f > image.rows) {
        const float scale = canvas_w / static_cast<float>(image.cols);
        target_height = static_cast<int>(image.rows * scale);
    } else {
        const float scale = canvas_h / static_cast<float>(image.rows);
        target_width = static_cast<int>(image.cols * scale);
    }
    // Very elongated images can truncate to 0 pixels, which cv::resize does
    // not accept; keep both dimensions inside [1, canvas size].
    if (target_width < 1) target_width = 1;
    if (target_height < 1) target_height = 1;
    if (target_width > canvas_w) target_width = canvas_w;
    if (target_height > canvas_h) target_height = canvas_h;

    const int wpad = (canvas_w - target_width) / 2;
    const int hpad = (canvas_h - target_height) / 2;

    // resize, then paste centered onto the padded canvas
    cv::Mat resized;
    cv::resize(gray, resized, cv::Size(target_width, target_height));
    const int pad_color = 114;

    cv::Mat pad_img(canvas_h, canvas_w, CV_8UC1, cv::Scalar(pad_color));
    resized.copyTo(pad_img(cv::Rect(wpad, hpad, target_width, target_height)));

    auto in = ncnn::Mat::from_pixels(pad_img.data, ncnn::Mat::PIXEL_GRAY, canvas_w, canvas_h);
    const float mean_vals[1] = {0.7931 * 255};
    const float norm_vals[1] = {1.0 / 0.1738 / 255.0};
    in.substract_mean_normalize(mean_vals, norm_vals);

    auto ex = encoder.create_extractor();
    ex.input("in0", in.clone());
    ncnn::Mat feat;
    if (ex.extract("out0", feat) != 0 || feat.empty()) {
        // Encoder not loaded or inference failed: nothing to decode.
        return std::string();
    }

    const int max_step = 1024;
    const int len_token = 1200;  // upper bound on the scores scanned per position
    const int stop_id = 2;       // decoding ends when this id is predicted

    // Decoder inputs: output, feat, mask, pos

    std::vector<int32_t> output;
    output.push_back(1); // <sos>
    std::vector<int32_t> pos;

    bool finished = false;
    for (int step = 0; step < max_step && !finished; step++) {
        const int seq_len = step + 1;

        auto ex2 = decoder.create_extractor();
        pos.push_back(step);

        // Square mask with -1e30 strictly above the diagonal: position i only
        // sees positions 0..i.
        ncnn::Mat mask(seq_len, seq_len, 1);
        mask.fill(0.0f);
        for (int i = 0; i < seq_len; i++) {
            float* mask_row = mask.row(i);
            for (int j = i + 1; j < seq_len; j++) {
                mask_row[j] = -1e30f;
            }
        }

        ex2.input("in0", ncnn::Mat(static_cast<int>(output.size()), output.data()).clone());
        ex2.input("in1", feat.clone());
        ex2.input("in2", mask.clone());
        ex2.input("in3", ncnn::Mat(static_cast<int>(pos.size()), pos.data()).clone());

        ncnn::Mat out0;
        if (ex2.extract("out0", out0) != 0 || out0.empty()) {
            // Decoder failure: keep whatever has been decoded so far.
            break;
        }
        // Never scan past the end of a row if the output is narrower than
        // len_token.
        const int cols = out0.w < len_token ? out0.w : len_token;

        output.resize(seq_len + 1);
        for (int i = 0; i < seq_len; i++) {
            const float* scores = out0.row(i);
            int maxarg = 0;
            float maxval = -1e30f;
            for (int j = 0; j < cols; j++) {
                if (scores[j] > maxval) {
                    maxval = scores[j];
                    maxarg = j;
                }
            }

            if (maxarg == stop_id) {
                // Keep <sos> plus the i tokens before the stop token and drop
                // everything from the stop position on (including any stale
                // tokens left over from earlier steps).
                output.resize(i + 1);
                finished = true;
                break;
            }
            output[i + 1] = maxarg;
        }
    }

    // Map ids to text, skipping <sos> at index 0. Ids outside the vocabulary
    // (e.g. when the vocabulary file failed to load) are ignored instead of
    // indexing out of bounds.
    const int vocab_size = static_cast<int>(token_list.size());
    std::string ret;
    for (auto it = output.begin() + 1; it != output.end(); ++it) {
        const int id = *it;
        if (id >= 0 && id < vocab_size) {
            ret += token_list[id];
        }
    }
    return post_process(ret);
}
133+
} // NBCapture

model.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//
2+
// Created by icys on 25-2-10.
3+
//
4+
5+
#ifndef MODEL_Latex_H
6+
#define MODEL_Latex_H
7+
8+
namespace NBCapture {
9+
10+
class LatexOcr {
11+
public:
12+
LatexOcr();
13+
std::string forward(const cv::Mat& image);
14+
15+
ncnn::Net decoder;
16+
ncnn::Net encoder;
17+
std::vector<std::string> token_list;
18+
};
19+
20+
} // NBCapture
21+
22+
#endif //MODEL_Latex_H

0 commit comments

Comments
 (0)