Skip to content

Commit 446fe07

Browse files
committed
实现det
1 parent f6cafd7 commit 446fe07

File tree

16 files changed

+693
-8
lines changed

16 files changed

+693
-8
lines changed

OcrLibrary/build.gradle

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
plugins {
22
id 'com.android.library'
33
id 'org.jetbrains.kotlin.android'
4+
id 'kotlin-parcelize'
45
}
56

67
android {
@@ -47,8 +48,13 @@ dependencies {
4748
implementation "androidx.core:core-ktx:$core_version"
4849
implementation 'androidx.appcompat:appcompat:1.5.1'
4950
implementation 'com.google.android.material:material:1.7.0'
51+
//Logger
52+
implementation 'com.orhanobut:logger:2.2.0'
53+
//Clipper
54+
//implementation 'de.lighti:Clipper:6.4.2'
5055
//onnxruntime
51-
implementation 'com.microsoft.onnxruntime:onnxruntime:1.13.1'
56+
//implementation 'com.microsoft.onnxruntime:onnxruntime:1.13.1'
57+
implementation 'com.microsoft.onnxruntime:onnxruntime-android:1.13.1'
5258
//opencv
5359
implementation project(':opencv')
5460
}

OcrLibrary/libs/Clipper-6.4.2.jar

69.2 KB
Binary file not shown.
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
package com.benjaminwan.ocrlibrary
2+
3+
import ai.onnxruntime.OrtEnvironment
4+
import android.content.res.AssetManager
5+
6+
/**
 * Holder for the angle-classifier ONNX session.
 *
 * Only the lazily created session exists so far; no inference entry point is
 * implemented in this class yet.
 *
 * @param ortEnv ONNX Runtime environment used to create the session.
 * @param assetManager asset manager the model file is read from.
 * @param modelName asset path of the ONNX model.
 */
class Cls(private val ortEnv: OrtEnvironment, assetManager: AssetManager, modelName: String) {

    // Created on first access. The asset stream is closed as soon as the model
    // bytes have been read, so no file descriptor stays open afterwards.
    private val clsSession by lazy {
        val model = assetManager.open(modelName, AssetManager.ACCESS_UNKNOWN).use { it.readBytes() }
        ortEnv.createSession(model)
    }

}
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package com.benjaminwan.ocrlibrary
2+
3+
import ai.onnxruntime.OnnxTensor
4+
import ai.onnxruntime.OrtEnvironment
5+
import ai.onnxruntime.TensorInfo
6+
import android.content.res.AssetManager
7+
import com.benjaminwan.ocrlibrary.models.DetPoint
8+
import com.benjaminwan.ocrlibrary.models.DetResult
9+
import com.benjaminwan.ocrlibrary.models.ScaleParam
10+
import org.opencv.core.*
11+
import org.opencv.imgproc.Imgproc
12+
import java.util.*
13+
14+
/**
 * Text detector: runs the detection ONNX model on an image and converts the
 * model's output map into text boxes.
 *
 * @param ortEnv ONNX Runtime environment used to create the session and input tensors.
 * @param assetManager asset manager the model file is read from.
 * @param modelName asset path of the ONNX model.
 */
class Det(private val ortEnv: OrtEnvironment, assetManager: AssetManager, modelName: String) {

    // Created on first access. The asset stream is closed as soon as the model
    // bytes have been read.
    private val detSession by lazy {
        val model = assetManager.open(modelName, AssetManager.ACCESS_UNKNOWN).use { it.readBytes() }
        ortEnv.createSession(model)
    }

    // Per-channel mean handed to substractMeanNormalize.
    private val meanValues = floatArrayOf(0.485F * 255, 0.456F * 255, 0.406F * 255)

    // Per-channel scale handed to substractMeanNormalize.
    private val normValues = floatArrayOf(1.0F / 0.229F / 255.0F, 1.0F / 0.224F / 255.0F, 1.0F / 0.225F / 255.0F)

    /**
     * Detects text boxes in [src].
     *
     * @param src input image; it is resized to `s.dstWidth` x `s.dstHeight` before inference.
     * @param s scale parameters; also used to map box corners back to source coordinates.
     * @param boxScoreThresh boxes whose score (from `boxScoreFast`) is below this value are dropped.
     * @param boxThresh binarization threshold applied to the output map (scaled by 255).
     * @param unClipRatio ratio passed to `unClip` when expanding each box.
     * @return the accepted boxes, in reverse contour order.
     */
    fun getDetResults(src: Mat, s: ScaleParam, boxScoreThresh: Float, boxThresh: Float, unClipRatio: Float): List<DetResult> {
        val srcResize = Mat()
        try {
            Imgproc.resize(src, srcResize, Size(s.dstWidth.toDouble(), s.dstHeight.toDouble()))

            val inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues)
            val inputShape = longArrayOf(1, srcResize.channels().toLong(), srcResize.rows().toLong(), srcResize.cols().toLong())
            val inputName = detSession.inputNames.iterator().next()

            return OnnxTensor.createTensor(ortEnv, inputTensorValues, inputShape).use { inputTensor ->
                detSession.run(Collections.singletonMap(inputName, inputTensor)).use { output ->
                    val onnxValue = output.first().value
                    val tensorInfo = onnxValue.info as TensorInfo
                    @Suppress("UNCHECKED_CAST")
                    val values = onnxValue.value as Array<Array<Array<FloatArray>>>
                    val outHeight: Int = tensorInfo.shape[2].toInt()
                    val outWidth: Int = tensorInfo.shape[3].toInt()
                    boxesFromOutput(flatten(values), outHeight, outWidth, s, boxScoreThresh, boxThresh, unClipRatio)
                }
            }
        } finally {
            srcResize.release()
        }
    }

    /** Copies a 4-D float output into one flat array without boxing each element. */
    private fun flatten(values: Array<Array<Array<FloatArray>>>): FloatArray {
        var total = 0
        for (batch in values) for (channel in batch) for (row in channel) total += row.size
        val flat = FloatArray(total)
        var offset = 0
        for (batch in values) for (channel in batch) for (row in channel) {
            row.copyInto(flat, offset)
            offset += row.size
        }
        return flat
    }

    /** Thresholds and dilates the flattened output map, then extracts boxes from it. */
    private fun boxesFromOutput(
        predData: FloatArray,
        outHeight: Int,
        outWidth: Int,
        s: ScaleParam,
        boxScoreThresh: Float,
        boxThresh: Float,
        unClipRatio: Float
    ): List<DetResult> {
        //-----Data preparation-----
        // toByte() keeps the same low 8 bits the previous toUByte() conversion produced.
        val cbufData = ByteArray(predData.size) { (predData[it] * 255).toInt().toByte() }

        val predMat = Mat(outHeight, outWidth, CvType.CV_32F)
        val cBufMat = Mat(outHeight, outWidth, CvType.CV_8UC1)
        val thresholdMat = Mat()
        val dilateMat = Mat()
        val dilateElement = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, Size(2.0, 2.0))
        try {
            predMat.put(0, 0, predData)
            cBufMat.put(0, 0, cbufData)

            //-----boxThresh-----
            Imgproc.threshold(cBufMat, thresholdMat, boxThresh * 255.0, 255.0, Imgproc.THRESH_BINARY)

            //-----dilate-----
            Imgproc.dilate(thresholdMat, dilateMat, dilateElement)

            return findRsBoxes(predMat, dilateMat, s, boxScoreThresh, unClipRatio)
        } finally {
            // Free native buffers now instead of waiting for garbage collection.
            dilateElement.release()
            dilateMat.release()
            thresholdMat.release()
            cBufMat.release()
            predMat.release()
        }
    }

    /**
     * Finds contours in [dilateMat] and turns each acceptable one into a [DetResult].
     *
     * At most 1000 contours are examined. A contour is skipped when it has two or
     * fewer points, when its box is too short, when its score is below
     * [boxScoreThresh], or when the expanded box is degenerate.
     */
    private fun findRsBoxes(predMat: Mat, dilateMat: Mat, s: ScaleParam, boxScoreThresh: Float, unClipRatio: Float): List<DetResult> {
        val longSideThresh = 3 // long-side threshold for minBox
        val maxCandidates = 1000

        val contours: MutableList<MatOfPoint> = mutableListOf()
        val hierarchy = Mat()
        try {
            Imgproc.findContours(dilateMat, contours, hierarchy, Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE)
        } finally {
            hierarchy.release()
        }

        val rsBoxes: MutableList<DetResult> = mutableListOf()
        try {
            for (contour in contours.take(maxCandidates)) {
                // total() is the number of points. elemSize() is the byte size of one
                // element, so testing it never filtered anything.
                if (contour.total() <= 2) {
                    continue
                }
                val contour2f = MatOfPoint2f(*contour.toArray())
                val minAreaRect = try {
                    Imgproc.minAreaRect(contour2f)
                } finally {
                    contour2f.release()
                }
                val minBoxes: Array<Point> = Array(4) { Point() }
                val longSide: Float = getMinBoxes(minAreaRect, minBoxes)
                if (longSide < longSideThresh) {
                    continue
                }

                val boxScore: Float = boxScoreFast(minBoxes, predMat)
                if (boxScore < boxScoreThresh) {
                    continue
                }

                //-----unClip-----
                val clipRect: RotatedRect = unClip(minBoxes, unClipRatio)
                if (clipRect.size.height < 1.001 && clipRect.size.width < 1.001) {
                    continue
                }

                val clipMinBoxes: Array<Point> = Array(4) { Point() }
                val clipLongSide = getMinBoxes(clipRect, clipMinBoxes)
                if (clipLongSide < longSideThresh + 2) {
                    continue
                }

                // Map corners back to source coordinates and clamp them into the image.
                val intClipMinBoxes = clipMinBoxes.map { point ->
                    val x = point.x / s.ratioWidth
                    val y = point.y / s.ratioHeight
                    val ptX = maxOf(x.toInt(), 0).coerceAtMost(s.srcWidth - 1)
                    val ptY = maxOf(y.toInt(), 0).coerceAtMost(s.srcHeight - 1)
                    DetPoint(ptX, ptY)
                }
                rsBoxes.add(DetResult(intClipMinBoxes, boxScore))
            }
        } finally {
            contours.forEach { it.release() }
        }
        return rsBoxes.asReversed()
    }

}
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
package com.benjaminwan.ocrlibrary
2+
3+
import ai.onnxruntime.OrtEnvironment
4+
import android.content.Context
5+
import android.content.res.AssetManager
6+
import android.graphics.Bitmap
7+
import com.benjaminwan.ocrlibrary.models.OcrResult
8+
import com.benjaminwan.ocrlibrary.models.ScaleParam
9+
import com.orhanobut.logger.Logger
10+
import org.opencv.android.OpenCVLoader
11+
import org.opencv.android.Utils
12+
import org.opencv.android.Utils.matToBitmap
13+
import org.opencv.core.CvType
14+
import org.opencv.core.Mat
15+
import org.opencv.core.Rect
16+
import org.opencv.imgproc.Imgproc.*
17+
import java.io.Closeable
18+
import java.lang.Integer.max
19+
20+
/**
 * Entry point of the OCR library. Only the detection step is implemented so
 * far; the classification and recognition steps are logged placeholders.
 *
 * @param context used only to obtain the [AssetManager] the models are loaded from.
 * @throws UnsatisfiedLinkError from the constructor when the bundled OpenCV library cannot be loaded.
 */
class OcrEngine(context: Context) : Closeable {

    private val assetManager: AssetManager = context.assets

    private val ortEnv by lazy { OrtEnvironment.getEnvironment() }

    private val det by lazy {
        Det(ortEnv, assetManager, DET_NAME)
    }

    private val cls by lazy {
        Cls(ortEnv, assetManager, CLS_NAME)
    }

    private val rec by lazy {
        Rec(ortEnv, assetManager, REC_NAME)
    }

    init {
        if (OpenCVLoader.initDebug()) {
            Logger.i("OpenCV library found inside package.")
        } else {
            Logger.e("Internal OpenCV library not found.")
            throw UnsatisfiedLinkError("Internal OpenCV library not found.")
        }
    }

    /** Closes the ONNX Runtime environment. */
    override fun close() {
        ortEnv.close()
    }

    /**
     * Runs detection on [bmp] and returns the boxes plus an image with them drawn.
     *
     * @param bmp input bitmap.
     * @param maxSideLen target length of the longest side; a value `<= 0` or larger
     *   than the image's longest side falls back to the image's longest side.
     * @param padding border size passed to `makePadding`; twice this value is added to the target side length.
     * @param boxScoreThresh forwarded to the detector.
     * @param boxThresh forwarded to the detector.
     * @param unClipRatio forwarded to the detector.
     * @param doAngle passed through; not used by any step implemented here yet.
     * @param mostAngle passed through; not used by any step implemented here yet.
     * @return the detection result; its classification and recognition lists are empty.
     */
    fun detect(
        bmp: Bitmap,
        maxSideLen: Int,
        padding: Int,
        boxScoreThresh: Float,
        boxThresh: Float,
        unClipRatio: Float,
        doAngle: Boolean,
        mostAngle: Boolean
    ): OcrResult {
        Logger.i("padding($padding),maxSideLen($maxSideLen),boxScoreThresh($boxScoreThresh),boxThresh($boxThresh),unClipRatio($unClipRatio),doAngle($doAngle),mostAngle($mostAngle)")
        val imgBGR = Mat()
        try {
            // The Mat constructor takes (rows, cols), i.e. (height, width).
            val imgRGBA = Mat(bmp.height, bmp.width, CvType.CV_8UC4)
            try {
                Utils.bitmapToMat(bmp, imgRGBA)
                cvtColor(imgRGBA, imgBGR, COLOR_RGBA2BGR)
            } finally {
                imgRGBA.release()
            }

            val originMaxSide = max(imgBGR.cols(), imgBGR.rows())
            val baseSide = if (maxSideLen <= 0 || maxSideLen > originMaxSide) originMaxSide else maxSideLen
            val resize = baseSide + 2 * padding
            val paddingRect = Rect(padding, padding, imgBGR.cols(), imgBGR.rows())
            val paddingSrc = makePadding(imgBGR, padding)
            try {
                // Scale the image down proportionally to reduce text segmentation time.
                val s = getScaleParam(paddingSrc, resize)
                Logger.i("$s")
                return detect(paddingSrc, paddingRect, s, boxScoreThresh, boxThresh, unClipRatio, doAngle, mostAngle)
            } finally {
                // makePadding may hand back its input; avoid releasing the same Mat twice.
                if (paddingSrc !== imgBGR) paddingSrc.release()
            }
        } finally {
            imgBGR.release()
        }
    }

    /**
     * Detects text boxes on the padded image [src], draws them onto a copy of it,
     * and converts that copy to a bitmap for the result.
     */
    private fun detect(
        src: Mat,
        paddingRect: Rect,
        s: ScaleParam,
        boxScoreThresh: Float,
        boxThresh: Float,
        unClipRatio: Float,
        doAngle: Boolean,
        mostAngle: Boolean
    ): OcrResult {
        val textBoxPaddingImg = src.clone()
        val outRGBA = Mat()
        try {
            val thickness = getThickness(src)
            Logger.i("=====Start detect=====")

            Logger.i("---------- step: getDetResults ----------")
            val detResults = det.getDetResults(src, s, boxScoreThresh, boxThresh, unClipRatio)
            Logger.i("$detResults")

            Logger.i("---------- step: drawTextBoxes ----------")
            drawTextBoxes(textBoxPaddingImg, detResults, thickness)

            Logger.i("---------- step: getPartImages ----------")

            Logger.i("---------- step: getClsResults ----------")

            Logger.i("---------- step: Rotate partImages ----------")

            cvtColor(textBoxPaddingImg, outRGBA, COLOR_BGR2RGBA)
            val outputImg = Bitmap.createBitmap(
                outRGBA.cols(), outRGBA.rows(), Bitmap.Config.ARGB_8888
            )
            matToBitmap(outRGBA, outputImg)

            return OcrResult(detResults, emptyList(), emptyList(), outputImg)
        } finally {
            // The pixels now live in outputImg, so the native Mats can be freed.
            outRGBA.release()
            textBoxPaddingImg.release()
        }
    }

    companion object {
        private const val DET_NAME = "ch_PP-OCRv3_det_infer.onnx"
        private const val CLS_NAME = "ch_ppocr_mobile_v2.0_cls_infer.onnx"
        private const val REC_NAME = "ch_PP-OCRv3_rec_infer.onnx"
        private const val KEYS_NAME = "ppocr_keys_v1.txt"
    }

}

0 commit comments

Comments
 (0)