Skip to content

Commit e2211dc

Browse files
committed
Update ia_engine.py
1 parent 5b5bc07 commit e2211dc

File tree

1 file changed

+11
-30
lines changed

1 file changed

+11
-30
lines changed

ia/ia_engine.py

Lines changed: 11 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -193,35 +193,16 @@ def train(self, productName: str):
193193
'''
194194

195195
train_dir, validation_dir = self.__imageReader(productName)
196-
#model = self.__createModel()
196+
model = self.__createModel()
197197

198-
input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, 3))
199-
vggModel = tensorflow.keras.applications.vgg16.VGG16(include_top=False, input_tensor=input)
200-
print(type(vggModel))
201-
print(vggModel.summary())
202-
203-
204-
model = tensorflow.keras.Sequential()
205-
206-
for layer in vggModel.layers[0:-1]:
207-
model.add(layer)
208-
209-
for layer in model.layers:
210-
layer.trainable = False
211-
212-
213-
model.add(layers.Flatten(name="capa"))
214-
model.add(layers.Dense(10, activation="softmax"))
215-
print(model.summary())
216-
model.compile(optimizer=tensorflow.keras.optimizers.Adam(lr=0.001), loss='categorical_crossentropy', metrics=['accuracy'])
217-
218-
preprocessInput = tensorflow.keras.applications.vgg16.preprocess_input
219-
220-
filepath = (MODEL_PATH +"/"+ productName + ".h5")
221-
222-
train_datagen = ImageDataGenerator(preprocessing_function=preprocessInput)
223-
224-
validation_datagen = ImageDataGenerator(preprocessing_function=preprocessInput)
198+
train_datagen = ImageDataGenerator(rescale=1./255,
199+
rotation_range=20,
200+
horizontal_flip=True,
201+
width_shift_range=0.2,
202+
height_shift_range=0.2,
203+
shear_range=0.2,
204+
zoom_range=0.2)
205+
validation_datagen = ImageDataGenerator(rescale=1./255)
225206

226207
training_set = train_datagen.flow_from_directory(
227208
train_dir,
@@ -266,8 +247,8 @@ def predict(self, product: str, img: str):
266247

267248
model = tensorflow.keras.models.load_model(MODEL_PATH +"/"+ product +".h5")
268249
npimg = np.array(np.expand_dims(npimg, axis=0))
269-
270-
return model.predict( npimg )
250+
predict = model.predict( npimg )
251+
return np.argmax(predict)
271252

272253

273254
def __accuracyGraph (self, history):

0 commit comments

Comments
 (0)