Skip to content

Commit c05a148

Browse files
committed
Repaired paths.
1 parent 878e449 commit c05a148

File tree

5 files changed

+25
-12
lines changed

5 files changed

+25
-12
lines changed

eval_results.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,6 @@ def euclidean_distance(vects):
6868

6969
model = Model([input_anchor, input_positive, input_negative], [net_anchor,net_positive, net_negative], name='gen')
7070

71-
model.load_weights('./model.h5')
71+
model.load_weights('./model_weights.h5')
7272

7373
evaluate.test(base_model)

instructions

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,25 @@ Then run
77
./train_test_split.sh
88
to split the dataset into training and testing parts
99

10-
Alternatively, you can skip previous steps and download prepared toy data from http://www.stud.fit.vutbr.cz/~xjonjo00/data_small.tar.gz
10+
Alternatively, you can skip previous steps and download prepared toy data (small size) from http://www.stud.fit.vutbr.cz/~xjonjo00/data_small.tar.gz
1111

12+
Now you can run training:
1213

13-
Now you can launch training with
14-
python3 resnet.py
14+
To train with hard mining (web api available):
1515

16-
after the model is trained and saved, you can run the sample website:
17-
FLASK_APP=web_pova.py flask run
16+
launch training with
17+
python3 resnet.py
18+
19+
after the model is trained and saved, you can run the sample website:
20+
FLASK_APP=web_pova.py flask run
21+
22+
To train without hard mining (web api not available):
23+
24+
launch training with
25+
python3 model_noHardMin.py
26+
27+
To evaluate saved trained model run:
28+
29+
python3 eval_results.py
1830

1931
Requirements: Keras, tensorflow, Flask

model_noHardMin.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from triplets_generator_noHardMin import DataGenerator
1212
import evaluate
1313

14+
datapath="./"
15+
1416
def l2Norm(x):
1517
return K.l2_normalize(x, axis=-1)
1618

@@ -87,14 +89,14 @@ def mean_neg_dist(_, y_pred):
8789

8890
""" Training """
8991
batch_size = 5
90-
training_generator = DataGenerator(dim_x=224, dim_y=224, batch_size=batch_size, dataset_path='./20_classes').generate()
92+
training_generator = DataGenerator(dim_x=224, dim_y=224, batch_size=batch_size, dataset_path=os.path.join(datapath,'data_train5')).generate()
9193
#validation_generator = DataGenerator(dim_x = 224, dim_y = 224, batch_size = batch_size, dataset_path = './places365-dataset/20_classes').generate()
9294

9395
opt = optimizers.Adam(lr=0.0005)
9496
model.compile(loss=triplet_loss, optimizer=opt, metrics=[accuracy, mean_pos_dist, mean_neg_dist])
9597

9698
model.fit_generator(generator = training_generator,
97-
steps_per_epoch = 248300//batch_size,
99+
steps_per_epoch = 24500//batch_size,
98100
epochs = 1)
99101

100102
model.save_weights("model_1epoch.h5")
@@ -105,14 +107,14 @@ def mean_neg_dist(_, y_pred):
105107
opt = optimizers.Adam(lr=0.00004)
106108
model.compile(loss=triplet_loss, optimizer=opt, metrics=[accuracy, mean_pos_dist, mean_neg_dist])
107109
model.fit_generator(generator = training_generator,
108-
steps_per_epoch = 248300//batch_size,
110+
steps_per_epoch = 24500//batch_size,
109111
epochs = 1)
110112

111113
# serialize model to JSON
112114
model_json = model.to_json()
113115
with open("model.json", "w") as json_file:
114116
json_file.write(model_json)
115117
# serialize weights to HDF5
116-
model.save_weights("model.h5")
118+
model.save_weights("model_weights.h5")
117119
print("Saved model to disk")
118120

resnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def fake_loss(__,_):
137137
# serialize weights to HDF5
138138
base_model.save(os.path.join(datapath,"model.h5"))
139139
model.save(os.path.join(datapath,"model_triplets.h5"))
140+
model.save_weights('model_weights.h5')
140141

141142
print("Saved model to disk")
142143

triplets_generator_noHardMin.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
from keras.preprocessing import image
44
from keras.applications.resnet50 import preprocess_input
55

6-
import matplotlib.pyplot as plt
7-
86
class DataGenerator(object):
97
'Generates data for Keras'
108
def __init__(self, dim_x = 224, dim_y = 224, batch_size = 10, dataset_path = './places365-dataset/20_classes'):

0 commit comments

Comments
 (0)