1
+ package com.benjaminwan.ocrlibrary
2
+
3
+ import ai.onnxruntime.OnnxTensor
4
+ import ai.onnxruntime.OrtEnvironment
5
+ import ai.onnxruntime.TensorInfo
6
+ import android.content.res.AssetManager
7
+ import com.benjaminwan.ocrlibrary.models.DetPoint
8
+ import com.benjaminwan.ocrlibrary.models.DetResult
9
+ import com.benjaminwan.ocrlibrary.models.ScaleParam
10
+ import org.opencv.core.*
11
+ import org.opencv.imgproc.Imgproc
12
+ import java.util.*
13
+
14
/**
 * Text detector: runs an ONNX detection model on an image and converts the
 * model's output map into a list of boxes expressed in source-image coordinates.
 *
 * @param ortEnv ONNX Runtime environment used to create the session and the input tensor.
 * @param assetManager asset manager the model file is read from.
 * @param modelName asset path of the model file.
 */
class Det(private val ortEnv: OrtEnvironment, assetManager: AssetManager, modelName: String) {

    /** Inference session, created on first use from the model bytes stored in the assets. */
    private val detSession by lazy {
        // Close the asset stream once the bytes have been read.
        val model = assetManager.open(modelName, AssetManager.ACCESS_UNKNOWN).use { it.readBytes() }
        ortEnv.createSession(model)
    }

    /** Per-channel means (scaled by 255) handed to substractMeanNormalize. */
    private val meanValues = floatArrayOf(0.485F * 255, 0.456F * 255, 0.406F * 255)

    /** Per-channel scale factors (1 / std / 255) handed to substractMeanNormalize. */
    private val normValues = floatArrayOf(1.0F / 0.229F / 255.0F, 1.0F / 0.224F / 255.0F, 1.0F / 0.225F / 255.0F)

    /**
     * Detects boxes in [src].
     *
     * The image is resized to `s.dstWidth` x `s.dstHeight`, normalized, fed to the model,
     * and the model output is thresholded, dilated and turned into boxes.
     *
     * @param src input image; it is not modified or released here.
     * @param s scale parameters (target size, ratios and source size).
     * @param boxScoreThresh boxes whose score is below this value are dropped.
     * @param boxThresh binarization threshold applied to the model output (multiplied by 255).
     * @param unClipRatio ratio forwarded to `unClip` when expanding each box.
     * @return detected boxes, in reverse order of the contours they were built from.
     */
    fun getDetResults(src: Mat, s: ScaleParam, boxScoreThresh: Float, boxThresh: Float, unClipRatio: Float): List<DetResult> {
        val srcResize = Mat()
        try {
            Imgproc.resize(src, srcResize, Size(s.dstWidth.toDouble(), s.dstHeight.toDouble()))

            val inputTensorValues = substractMeanNormalize(srcResize, meanValues, normValues)
            val inputShape = longArrayOf(1, srcResize.channels().toLong(), srcResize.rows().toLong(), srcResize.cols().toLong())
            val inputName = detSession.inputNames.iterator().next()

            return OnnxTensor.createTensor(ortEnv, inputTensorValues, inputShape).use { inputTensor ->
                detSession.run(Collections.singletonMap(inputName, inputTensor)).use { output ->
                    val onnxValue = output.first().value
                    val tensorInfo = onnxValue.info as TensorInfo
                    @Suppress("UNCHECKED_CAST")
                    val values = onnxValue.value as Array<Array<Array<FloatArray>>>
                    val outHeight: Int = tensorInfo.shape[2].toInt()
                    val outWidth: Int = tensorInfo.shape[3].toInt()
                    boxesFromPrediction(flatten(values), outHeight, outWidth, s, boxScoreThresh, boxThresh, unClipRatio)
                }
            }
        } finally {
            // Native memory is not tracked by the JVM heap, so free it eagerly.
            srcResize.release()
        }
    }

    /** Copies a 4-level nested float array into one flat array, preserving element order. */
    private fun flatten(values: Array<Array<Array<FloatArray>>>): FloatArray {
        var total = 0
        for (batch in values) for (channel in batch) for (row in channel) total += row.size

        val flat = FloatArray(total)
        var offset = 0
        for (batch in values) {
            for (channel in batch) {
                for (row in channel) {
                    System.arraycopy(row, 0, flat, offset, row.size)
                    offset += row.size
                }
            }
        }
        return flat
    }

    /** Builds the score map and the binarized, dilated mask from [predData] and extracts boxes from them. */
    private fun boxesFromPrediction(
        predData: FloatArray,
        outHeight: Int,
        outWidth: Int,
        s: ScaleParam,
        boxScoreThresh: Float,
        boxThresh: Float,
        unClipRatio: Float
    ): List<DetResult> {
        // -----Data preparation-----
        val cbufData = UByteArray(predData.size) { (predData[it] * 255).toInt().toUByte() }

        val predMat = Mat(outHeight, outWidth, CvType.CV_32F)
        val cBufMat = Mat(outHeight, outWidth, CvType.CV_8UC1)
        val thresholdMat = Mat()
        val dilateMat = Mat()
        val dilateElement = Imgproc.getStructuringElement(Imgproc.MORPH_RECT, Size(2.0, 2.0))
        try {
            predMat.put(0, 0, predData)
            cBufMat.put(0, 0, cbufData)

            // -----boxThresh-----
            Imgproc.threshold(cBufMat, thresholdMat, boxThresh * 255.0, 255.0, Imgproc.THRESH_BINARY)

            // -----dilate-----
            Imgproc.dilate(thresholdMat, dilateMat, dilateElement)

            return findRsBoxes(predMat, dilateMat, s, boxScoreThresh, unClipRatio)
        } finally {
            // The returned results only hold plain points and scores, so every Mat can be freed.
            listOf(predMat, cBufMat, thresholdMat, dilateMat, dilateElement).forEach { it.release() }
        }
    }

    /**
     * Finds contours in [dilateMat] and converts at most [MAX_CANDIDATES] of them into boxes.
     *
     * @param predMat score map used to compute each box score.
     * @param dilateMat binarized and dilated mask the contours are extracted from.
     * @return accepted boxes, in reverse contour order.
     */
    private fun findRsBoxes(predMat: Mat, dilateMat: Mat, s: ScaleParam, boxScoreThresh: Float, unClipRatio: Float): List<DetResult> {
        val contours: MutableList<MatOfPoint> = mutableListOf()
        val hierarchy = Mat()
        try {
            Imgproc.findContours(dilateMat, contours, hierarchy, Imgproc.RETR_LIST, Imgproc.CHAIN_APPROX_SIMPLE)

            val rsBoxes = contours.take(MAX_CANDIDATES).mapNotNull { contour ->
                toDetResult(contour, predMat, s, boxScoreThresh, unClipRatio)
            }
            return rsBoxes.asReversed()
        } finally {
            hierarchy.release()
            contours.forEach { it.release() }
        }
    }

    /** Converts one contour into a [DetResult], or returns null when the contour is rejected. */
    private fun toDetResult(contour: MatOfPoint, predMat: Mat, s: ScaleParam, boxScoreThresh: Float, unClipRatio: Float): DetResult? {
        // total() is the number of points; elemSize() is the byte size of one point and
        // therefore never identifies a degenerate contour.
        if (contour.total() <= 2) {
            return null
        }

        val contour2f = MatOfPoint2f(*contour.toArray())
        val minAreaRect = try {
            Imgproc.minAreaRect(contour2f)
        } finally {
            contour2f.release()
        }

        val minBoxes: Array<Point> = Array(4) { Point() }
        val longSide: Float = getMinBoxes(minAreaRect, minBoxes)
        if (longSide < LONG_SIDE_THRESH) {
            return null
        }

        val boxScore: Float = boxScoreFast(minBoxes, predMat)
        if (boxScore < boxScoreThresh) {
            return null
        }

        // -----unClip-----
        val clipRect: RotatedRect = unClip(minBoxes, unClipRatio)
        if (clipRect.size.height < 1.001 && clipRect.size.width < 1.001) {
            return null
        }

        val clipMinBoxes: Array<Point> = Array(4) { Point() }
        val clipLongSide = getMinBoxes(clipRect, clipMinBoxes)
        if (clipLongSide < LONG_SIDE_THRESH + 2) {
            return null
        }

        // Divide by the resize ratios and clamp to the source image bounds.
        val intClipMinBoxes = clipMinBoxes.map { point ->
            val x = point.x / s.ratioWidth
            val y = point.y / s.ratioHeight
            val ptX = x.toInt().coerceAtLeast(0).coerceAtMost(s.srcWidth - 1)
            val ptY = y.toInt().coerceAtLeast(0).coerceAtMost(s.srcHeight - 1)
            DetPoint(ptX, ptY)
        }
        return DetResult(intClipMinBoxes, boxScore)
    }

    private companion object {
        /** Long-side threshold for the minimum box; shorter boxes are discarded. */
        const val LONG_SIDE_THRESH = 3

        /** Maximum number of contours that are examined. */
        const val MAX_CANDIDATES = 1000
    }
}
0 commit comments