Commit 4948971

Feat: major functions
1. Build Siamese NIMA model network
2. Train and predict with NIMA model network

1 parent d3ede10, commit 4948971

File tree: 8 files changed, +597 -0 lines changed
demo_predict.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
'''Predict with NIMA model network.'''

from model.siamese_nima import SiameseNIMA


if __name__ == '__main__':
    # dirs and paths to load data
    predict_image_dir = './assets/demo/predict_images'
    predict_data_path = './assets/demo/predict_data.csv'

    # load data and predict with the pre-trained NIMA weights
    siamese = SiameseNIMA()
    predict_raw = siamese.load_data(predict_image_dir, predict_data_path,
                                    columns=['file_name'])
    results = siamese.predict(predict_raw,
                              nima_weight_path='./assets/weights/nima_weights_pre_trained.h5')
    print(results)

demo_train.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
'''Train Siamese NIMA model networks.'''

from model.siamese_nima import SiameseNIMA


if __name__ == '__main__':
    # dirs and paths to load data
    train_image_dir = './assets/demo/train_images'
    train_data_path = './assets/demo/train_data.csv'

    # load data and train model
    siamese = SiameseNIMA(output_dir='./assets')
    train_raw = siamese.load_data(train_image_dir, train_data_path)
    siamese.train(train_raw,
                  epochs=5,
                  batch_size=16,
                  nima_weight_path='./assets/weights/nima_weights_pre_trained.h5')

installation.sh

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
# 0. Check python and pip3
(python3 -V && pip3 -V) &>/dev/null;
if [[ $? -ne 0 ]]; then
    echo "Error: please check whether 'python3' and 'pip3' are valid";
    exit 1;
fi

# 1. Create virtual environment
conda -V &>/dev/null;
if [[ $? -eq 0 ]]; then
    echo "1. Creating conda environment: image_quality";
    conda create --name image_quality python=3.6;
    source activate image_quality;
else
    echo "1. Creating virtual environment: .env"
    python3 -m venv .env
    source .env/bin/activate
fi

# 2. Install tensorflow backend
nvidia-smi &>/dev/null;
if [[ $? -eq 0 ]]; then
    echo "2. Installing backend: tensorflow-gpu==1.12.0";
    pip3 install numpy==1.15.4 Keras==2.2.4 tensorflow-gpu==1.12.0;
else
    echo "2. Installing backend: tensorflow==1.12.0";
    pip3 install numpy==1.15.4 Keras==2.2.4 tensorflow==1.12.0;
fi

# 3. Install dependencies
echo "3. Installing the remaining dependencies";
pip3 install -r requirements.txt;

# 4. Deactivate
deactivate;

model/__init__.py

Whitespace-only changes.

model/data_sequence.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
'''DataSequence for model training.
Inherits from keras.utils.Sequence;
see https://keras.io/utils/#sequence for more details.
'''

import random
import numpy as np

from keras.utils import Sequence
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.preprocessing.image import load_img, img_to_array


class DataSequence(Sequence):
    def __init__(self, x_raw, y_raw, batch_size, num_classes, target_size=(224, 224)):
        '''Create train / validate generator.

        Args:
            x_raw (np.ndarray): an array of image paths, not image data.
            y_raw (np.ndarray): an array of image classes.
            batch_size (int): the sample size of each batch.
            num_classes (int): the number of classes in the dataset.
            target_size (tuple, optional): the size (width, height) of the image data. Defaults to (224, 224).
        '''
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.target_size = target_size
        self.x, self.y = self.__create_pairs(x_raw, y_raw)

    def __create_pairs(self, x_raw, y_raw):
        '''Create positive and negative pairs.
        Does not load the image data itself; it only builds a list of sample path pairs,
        so the model can be trained with a generator that loads image data asynchronously.

        Args:
            x_raw (np.ndarray): an array of image paths, not image data.
            y_raw (np.ndarray): an array of image classes.

        Returns:
            pairs (list): the samples of image paths combined into pairs, known as `x`.
            labels (list): the labels of the pairs (1 = same class, 0 = different class), known as `y`.
        '''
        pairs = []
        labels = []
        digit_indices = [np.where(y_raw == i)[0] for i in range(self.num_classes)]
        n = min([len(digit_indices[d]) for d in range(self.num_classes)]) - 1
        for d in range(self.num_classes):
            for i in range(n):
                # positive pair: two consecutive samples from the same class
                z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
                pairs += [[x_raw[z1], x_raw[z2]]]
                # negative pair: a sample from class d and one from a random other class
                inc = random.randrange(1, self.num_classes)
                dn = (d + inc) % self.num_classes
                z1, z2 = digit_indices[d][i], digit_indices[dn][i]
                pairs += [[x_raw[z1], x_raw[z2]]]
                labels += [1, 0]
        return pairs, labels

    def __load_image_data(self, image_path):
        '''Load image data from a local image path.

        Args:
            image_path (str): the path of an image.

        Returns:
            x (np.ndarray): the image data, known as `x`.
        '''
        img = load_img(image_path, target_size=self.target_size)
        x = img_to_array(img)
        x = preprocess_input(x)  # normalization
        return x

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x_raw = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_y_raw = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x1, batch_x2 = [], []
        for x1, x2 in batch_x_raw:
            batch_x1.append(self.__load_image_data(x1))
            batch_x2.append(self.__load_image_data(x2))
        batch_x1 = np.array(batch_x1)
        batch_x2 = np.array(batch_x2)
        batch_x = [batch_x1, batch_x2]
        batch_y = np.array(batch_y_raw)
        return batch_x, batch_y
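
Not part of this commit, but as a rough illustration of how DataSequence might be consumed: the sketch below builds a tiny two-class pair sequence and hands it to a hypothetical two-input Siamese Keras model via fit_generator. The file names and the build_siamese_model helper are placeholders, not APIs from this repository.

import numpy as np
from model.data_sequence import DataSequence

# four placeholder image paths, two per quality class
image_paths = np.array(['a.jpg', 'b.jpg', 'c.jpg', 'd.jpg'])
image_classes = np.array([0, 0, 1, 1])

# with this data, __create_pairs yields one positive (label 1) and one negative (label 0) pair per class
train_seq = DataSequence(image_paths, image_classes, batch_size=16, num_classes=2)

# siamese_model = build_siamese_model()             # hypothetical: two image inputs, one similarity output
# siamese_model.fit_generator(train_seq, epochs=5)  # Keras Sequence instances can be used as generators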

model/data_sequence_pred.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
'''DataSequence for model predicting.
Inherits from keras.utils.Sequence;
see https://keras.io/utils/#sequence for more details.
'''

import numpy as np

from keras.utils import Sequence
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.preprocessing.image import load_img, img_to_array


class DataSequencePred(Sequence):
    def __init__(self, x_raw, batch_size=1, target_size=None):
        '''Create predict generator.

        Args:
            x_raw (np.ndarray | list): an array of image paths, not image data.
            batch_size (int, optional): the sample size of each batch. Defaults to 1.
            target_size (tuple, optional): the size (width, height) of the image data. Defaults to None.
        '''
        self.batch_size = batch_size
        self.target_size = target_size
        self.x = x_raw

    def __load_image_data(self, image_path):
        '''Load image data from a local image path.

        Args:
            image_path (str): the path of an image.

        Returns:
            x (np.ndarray): the image data, known as `x`.
        '''
        img = load_img(image_path, target_size=self.target_size)
        x = img_to_array(img)
        x = preprocess_input(x)  # normalization
        return x

    def __len__(self):
        return int(np.ceil(len(self.x) / float(self.batch_size)))

    def __getitem__(self, idx):
        batch_x_raw = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = [self.__load_image_data(x) for x in batch_x_raw]
        return np.array(batch_x)
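
Again not part of the commit: a minimal sketch of how DataSequencePred might feed predictions, assuming a single-input NIMA model is available. The image paths and build_nima_model are hypothetical placeholders.

from model.data_sequence_pred import DataSequencePred

# placeholder image paths; target_size matches the training resolution above
pred_seq = DataSequencePred(['img_0.jpg', 'img_1.jpg'], batch_size=1, target_size=(224, 224))

# nima_model = build_nima_model()                  # hypothetical: single image input
# scores = nima_model.predict_generator(pred_seq)  # one batch of image data per step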
