@@ -193,35 +193,16 @@ def train(self, productName: str):
193
193
'''
194
194
195
195
train_dir , validation_dir = self .__imageReader (productName )
196
- # model = self.__createModel()
196
+ model = self .__createModel ()
197
197
198
- input = Input (shape = (IMG_HEIGHT , IMG_WIDTH , 3 ))
199
- vggModel = tensorflow .keras .applications .vgg16 .VGG16 (include_top = False , input_tensor = input )
200
- print (type (vggModel ))
201
- print (vggModel .summary ())
202
-
203
-
204
- model = tensorflow .keras .Sequential ()
205
-
206
- for layer in vggModel .layers [0 :- 1 ]:
207
- model .add (layer )
208
-
209
- for layer in model .layers :
210
- layer .trainable = False
211
-
212
-
213
- model .add (layers .Flatten (name = "capa" ))
214
- model .add (layers .Dense (10 , activation = "softmax" ))
215
- print (model .summary ())
216
- model .compile (optimizer = tensorflow .keras .optimizers .Adam (lr = 0.001 ), loss = 'categorical_crossentropy' , metrics = ['accuracy' ])
217
-
218
- preprocessInput = tensorflow .keras .applications .vgg16 .preprocess_input
219
-
220
- filepath = (MODEL_PATH + "/" + productName + ".h5" )
221
-
222
- train_datagen = ImageDataGenerator (preprocessing_function = preprocessInput )
223
-
224
- validation_datagen = ImageDataGenerator (preprocessing_function = preprocessInput )
198
+ train_datagen = ImageDataGenerator (rescale = 1. / 255 ,
199
+ rotation_range = 20 ,
200
+ horizontal_flip = True ,
201
+ width_shift_range = 0.2 ,
202
+ height_shift_range = 0.2 ,
203
+ shear_range = 0.2 ,
204
+ zoom_range = 0.2 )
205
+ validation_datagen = ImageDataGenerator (rescale = 1. / 255 )
225
206
226
207
training_set = train_datagen .flow_from_directory (
227
208
train_dir ,
@@ -266,8 +247,8 @@ def predict(self, product: str, img: str):
266
247
267
248
model = tensorflow .keras .models .load_model (MODEL_PATH + "/" + product + ".h5" )
268
249
npimg = np .array (np .expand_dims (npimg , axis = 0 ))
269
-
270
- return model . predict ( npimg )
250
+ predict = model . predict ( npimg )
251
+ return np . argmax ( predict )
271
252
272
253
273
254
def __accuracyGraph (self , history ):
0 commit comments