2
2
'''
3
3
预处理数据, 封装成方便使用的数据集
4
4
提供随机batch功能(采用生产者消费者模式, 进行数据预取, 随机出队列)
5
- 提供统一高度的图像, 作为crnn的输入
6
- 构建字库, 对label进行编码
7
- 记录log
5
+ 提供统一高度的图像, 作为crnn的输入; 图像标准化(暂时不确定, 没有进行标准化)
6
+ 构建字库, 对label进行编码(未实现)
7
+ 记录log(未实现)
8
8
'''
9
9
# import pandas as pd
10
10
import numpy as np
11
11
# import codecs
12
12
import os
13
13
import queue
14
14
import threading
15
- # import json
15
+ import random
16
+ import glob
16
17
18
+ from PIL import Image
17
19
18
- from utils import myThread , log
20
+ from utils import myThread , log , chdir
19
21
from parameters import RECORD_PATH , IMAGE_TRAIN_PATH , TXT_TRAIN_PATH , BATCH_SIZE
20
22
from record import recQueue , recQueueLock , divide_conquer , get_cropThreadCount
21
23
24
26
# fileQueue = queue.Queue()
25
27
# fileQueueLock = threading.Lock()
26
28
27
- class chdir ():
28
- def __init__ (self , newdir ):
29
- self ._olddir = os .getcwd ()
30
- self ._newdir = newdir
31
- def __enter__ (self ):
32
- os .chdir (self ._newdir )
33
- # print("enter work dir", self._newdir)
34
- def __exit__ (self , a , b , c ):
35
- os .chdir (self ._olddir )
36
- # print("exit work dir ", self._newdir)
29
+ # class chdir():
30
+ # def __init__(self, newdir):
31
+ # self._olddir = os.getcwd()
32
+ # self._newdir = newdir
33
+ # def __enter__(self):
34
+ # os.chdir(self._newdir)
35
+ # # print("enter work dir", self._newdir)
36
+ # def __exit__(self, a, b, c):
37
+ # os.chdir(self._olddir)
38
+ # # print("exit work dir ", self._newdir)
37
39
38
40
39
- class DataSet (object ):
41
+ class Consumer (object ):
40
42
@log ('call: ' )
41
43
def __init__ (self , recQueue , recQueueLock , epochs = 1 ):
42
44
# self._recFilePath = recFilePath
@@ -89,26 +91,66 @@ def read_record(self):
89
91
90
92
91
93
class DataSets(object):
    """Batch provider that rescales images to one common height for a CRNN.

    Constructing an instance immediately calls ``divide_conquer`` with the
    given file names (the original comment describes this as starting the
    image-cropping threads). A consumer object exposing ``read_record()``
    must be attached as ``self.train`` before ``next_batch`` is called;
    ``read_data_sets`` does that.
    """

    def __init__(self, filenames):
        """Store the target geometry and the file list, then start producing.

        Args:
            filenames: image file names passed on to ``divide_conquer``.
        """
        self._height = 32   # every image is rescaled to this height (pixels)
        self._width = 128   # nominal width (pixels); not referenced by this class
        self._image_files = filenames
        self.__start_produce()

    def __start_produce(self):
        """Kick off the producers that fill the shared record queue."""
        divide_conquer(self._image_files)

    def next_batch(self):
        """Return the next batch as ``(images, labels)`` with resized images.

        Keeps calling ``self.train.read_record()`` until it yields data.
        If a read comes back empty and ``get_cropThreadCount()`` is 0,
        ``({}, {})`` is returned so callers can stop on a falsy result.
        """
        # NOTE(review): a producer could enqueue its last item and exit
        # between the empty read and the thread-count check; whether that can
        # drop a final batch depends on record.py - confirm.
        self._images, self._labels = self.train.read_record()
        while not self._images and not self._labels:
            if 0 == get_cropThreadCount():  # producers report being finished
                return {}, {}
            self._images, self._labels = self.train.read_record()
        return self.resize_with_crop_pad(self._images, self._labels)

    def resize_with_crop_pad(self, images, labels):
        """Rescale each image to ``self._height`` keeping its aspect ratio.

        Despite the name, nothing is cropped or padded here: the width is
        simply scaled by the same factor as the height. An image that cannot
        be resized is reported on stdout and skipped together with its label,
        so one bad sample never aborts the whole batch.

        Args:
            images: sequence of array-like images exposing ``shape`` and
                ``astype`` (height first, width second).
            labels: sequence of labels, index-aligned with ``images``.

        Returns:
            ``(result_images, result_labels)``: two lists of equal length
            holding the resized RGB arrays and their labels.
        """
        target_height = self._height
        result_images = []
        result_labels = []
        for index, image in enumerate(images):
            try:
                height, width = image.shape[0], image.shape[1]
                if height <= 0 or width <= 0:
                    raise ValueError("image has an empty dimension")
                ratio = target_height / height
                im = Image.fromarray(image.astype('uint8')).convert('RGB')
                im = im.resize((int(width * ratio), target_height),
                               Image.BILINEAR)
                resized = np.array(im)
                label = labels[index]
            except Exception as exc:
                # Best effort: report and drop the sample. The handler must
                # not touch ``im`` - it may be unbound or belong to an
                # earlier image.
                print("failed resize", getattr(image, 'shape', None), exc)
                continue
            # Append only after everything succeeded so both lists stay
            # aligned.
            result_images.append(resized)
            result_labels.append(label)
        return result_images, result_labels

    def writeimage(self, images, labels):
        """Debug helper: save every image as ``./test/origin/<label>-<index>.jpg``.

        Args:
            images: sequence of array-like images exposing ``astype``.
            labels: sequence of labels, index-aligned with ``images``.
        """
        path = './test/origin/%s-%.4d.jpg'
        # Make sure the target directory exists before the first save.
        os.makedirs('./test/origin', exist_ok=True)
        for index, image in enumerate(images):
            im = Image.fromarray(image.astype('uint8')).convert('RGB')
            im.save(path % (labels[index], index))
108
150
@log()
def read_data_sets(filenames):
    """Build a ``DataSets`` for *filenames* and attach its record consumer.

    Args:
        filenames: image file names forwarded to ``DataSets``.

    Returns:
        The ``DataSets`` instance, with a ``Consumer`` bound to the shared
        record queue and its lock stored on the ``train`` attribute.
    """
    bundle = DataSets(filenames)
    consumer = Consumer(recQueue, recQueueLock, epochs=1)
    bundle.train = consumer
    return bundle
113
155
114
156
# def next_batch(data_sets):
@@ -119,20 +161,54 @@ def read_data_sets():
119
161
# images, labels = data_sets.train.read_record()
120
162
# return images, labels
121
163
122
- if __name__ == "__main__" :
123
- # start_produce()
124
- data_sets = read_data_sets ()
164
+
165
def train_valid_split(datapath, ratio=0.8, shuffle=True):
    """Split the ``*.jpg`` files found in *datapath* into two lists.

    The directory is listed through a joined glob pattern instead of
    changing the process working directory, so threads that are running
    concurrently never observe a different cwd.

    Args:
        datapath: directory that holds the image files.
        ratio: fraction of the files that goes into the training list; the
            cut index is ``int(ratio * number_of_files)``.
        shuffle: when true (default) the files are shuffled with ``random``
            before splitting; when false they are split in sorted name order.

    Returns:
        ``(train_image_files, valid_image_files)``: two lists of bare file
        names (relative to *datapath*).

    Raises:
        FileNotFoundError: if *datapath* is not an existing directory.
    """
    if not os.path.isdir(datapath):
        raise FileNotFoundError("image directory not found: %r" % (datapath,))

    # glob.escape keeps special characters in the directory name literal.
    pattern = os.path.join(glob.escape(datapath), '*.jpg')
    # Sorted so the no-shuffle split does not depend on directory order.
    image_names = sorted(os.path.basename(found) for found in glob.glob(pattern))

    if shuffle:
        random.shuffle(image_names)

    mid = int(ratio * len(image_names))
    return image_names[:mid], image_names[mid:]
176
+
177
def demo():
    """Walk once through the training split, then once through validation.

    The images under ``IMAGE_TRAIN_PATH`` are split 70/30, and every batch
    of each split is pulled and its size printed. An empty batch marks the
    end of a pass.
    """

    def drain(batches, tag):
        # Pull batches until an empty one signals the data has been consumed.
        while True:
            images, labels = batches.next_batch()
            if not (images and labels):
                print("over")
                return
            # A real run would train / evaluate the model on this batch.
            print(tag, len(images), len(labels))

    # Split the file list first.
    train_image_files, valid_image_files = train_valid_split(
        IMAGE_TRAIN_PATH, ratio=0.7)

    # Training pass.
    print(len(train_image_files))
    print('start trainning')
    training_sets = read_data_sets(train_image_files)
    drain(training_sets, "train batch: ")

    # Validation pass.
    print('start validating')
    validation_sets = read_data_sets(valid_image_files)
    print(len(valid_image_files))
    drain(validation_sets, "valid batch: ")


if __name__ == "__main__":
    demo()
136
212
137
213
138
214
0 commit comments