36
36
import lfw
37
37
import os
38
38
import sys
39
- import math
39
+ from tensorflow . python . ops import data_flow_ops
40
40
from sklearn import metrics
41
41
from scipy .optimize import brentq
42
42
from scipy import interpolate
@@ -52,44 +52,89 @@ def main(args):
52
52
53
53
# Get the paths for the corresponding images
54
54
paths , actual_issame = lfw .get_paths (os .path .expanduser (args .lfw_dir ), pairs )
55
-
56
- # Load the model
57
- facenet .load_model (args .model )
58
55
59
- # Get input and output tensors
60
- images_placeholder = tf .get_default_graph ().get_tensor_by_name ("input:0" )
56
+ image_paths_placeholder = tf .placeholder (tf .string , shape = (None ,1 ), name = 'image_paths' )
57
+ labels_placeholder = tf .placeholder (tf .int32 , shape = (None ,1 ), name = 'labels' )
58
+ batch_size_placeholder = tf .placeholder (tf .int32 , name = 'batch_size' )
59
+ control_placeholder = tf .placeholder (tf .int32 , shape = (None ,1 ), name = 'control' )
60
+ phase_train_placeholder = tf .placeholder (tf .bool , name = 'phase_train' )
61
+
62
+ nrof_preprocess_threads = 4
63
+ image_size = (args .image_size , args .image_size )
64
+ eval_input_queue = data_flow_ops .FIFOQueue (capacity = 2000000 ,
65
+ dtypes = [tf .string , tf .int32 , tf .int32 ],
66
+ shapes = [(1 ,), (1 ,), (1 ,)],
67
+ shared_name = None , name = None )
68
+ eval_enqueue_op = eval_input_queue .enqueue_many ([image_paths_placeholder , labels_placeholder , control_placeholder ], name = 'eval_enqueue_op' )
69
+ image_batch , label_batch = facenet .create_input_pipeline (eval_input_queue , image_size , nrof_preprocess_threads , batch_size_placeholder )
70
+
71
+ # Load the model
72
+ input_map = {'image_batch' : image_batch , 'label_batch' : label_batch , 'phase_train' : phase_train_placeholder }
73
+ facenet .load_model (args .model , input_map = input_map )
74
+
75
+ # Get output tensor
61
76
embeddings = tf .get_default_graph ().get_tensor_by_name ("embeddings:0" )
62
- phase_train_placeholder = tf .get_default_graph ().get_tensor_by_name ("phase_train:0" )
63
-
64
- #image_size = images_placeholder.get_shape()[1] # For some reason this doesn't work for frozen graphs
65
- image_size = args .image_size
66
- embedding_size = embeddings .get_shape ()[1 ]
67
-
68
- # Run forward pass to calculate embeddings
69
- print ('Runnning forward pass on LFW images' )
70
- batch_size = args .lfw_batch_size
71
- nrof_images = len (paths )
72
- nrof_batches = int (math .ceil (1.0 * nrof_images / batch_size ))
73
- emb_array = np .zeros ((nrof_images , embedding_size ))
74
- for i in range (nrof_batches ):
75
- start_index = i * batch_size
76
- end_index = min ((i + 1 )* batch_size , nrof_images )
77
- paths_batch = paths [start_index :end_index ]
78
- images = facenet .load_data (paths_batch , False , False , image_size )
79
- feed_dict = { images_placeholder :images , phase_train_placeholder :False }
80
- emb_array [start_index :end_index ,:] = sess .run (embeddings , feed_dict = feed_dict )
81
-
82
- tpr , fpr , accuracy , val , val_std , far = lfw .evaluate (emb_array ,
83
- actual_issame , nrof_folds = args .lfw_nrof_folds )
77
+ #
78
+ coord = tf .train .Coordinator ()
79
+ tf .train .start_queue_runners (coord = coord , sess = sess )
84
80
85
- print ('Accuracy: %1.3f+-%1.3f' % (np .mean (accuracy ), np .std (accuracy )))
86
- print ('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val , val_std , far ))
81
+ evaluate (sess , eval_enqueue_op , image_paths_placeholder , labels_placeholder , phase_train_placeholder , batch_size_placeholder , control_placeholder ,
82
+ embeddings , label_batch , paths , actual_issame , args .lfw_batch_size , args .lfw_nrof_folds , args .distance_metric , args .subtract_mean ,
83
+ args .use_flipped_images , args .use_fixed_image_standardization )
87
84
88
- auc = metrics .auc (fpr , tpr )
89
- print ('Area Under Curve (AUC): %1.3f' % auc )
90
- eer = brentq (lambda x : 1. - x - interpolate .interp1d (fpr , tpr )(x ), 0. , 1. )
91
- print ('Equal Error Rate (EER): %1.3f' % eer )
92
-
85
+
86
+ def evaluate (sess , enqueue_op , image_paths_placeholder , labels_placeholder , phase_train_placeholder , batch_size_placeholder , control_placeholder ,
87
+ embeddings , labels , image_paths , actual_issame , batch_size , nrof_folds , distance_metric , subtract_mean , use_flipped_images , use_fixed_image_standardization ):
88
+ # Run forward pass to calculate embeddings
89
+ print ('Runnning forward pass on LFW images' )
90
+
91
+ # Enqueue one epoch of image paths and labels
92
+ nrof_embeddings = len (actual_issame )* 2 # nrof_pairs * nrof_images_per_pair
93
+ nrof_flips = 2 if use_flipped_images else 1
94
+ nrof_images = nrof_embeddings * nrof_flips
95
+ labels_array = np .expand_dims (np .arange (0 ,nrof_images ),1 )
96
+ image_paths_array = np .expand_dims (np .repeat (np .array (image_paths ),nrof_flips ),1 )
97
+ control_array = np .zeros_like (labels_array , np .int32 )
98
+ if use_fixed_image_standardization :
99
+ control_array += np .ones_like (labels_array )* facenet .FIXED_STANDARDIZATION
100
+ if use_flipped_images :
101
+ # Flip every second image
102
+ control_array += (labels_array % 2 )* facenet .FLIP
103
+ sess .run (enqueue_op , {image_paths_placeholder : image_paths_array , labels_placeholder : labels_array , control_placeholder : control_array })
104
+
105
+ embedding_size = int (embeddings .get_shape ()[1 ])
106
+ assert nrof_images % batch_size == 0 , 'The number of LFW images must be an integer multiple of the LFW batch size'
107
+ nrof_batches = nrof_images // batch_size
108
+ emb_array = np .zeros ((nrof_images , embedding_size ))
109
+ lab_array = np .zeros ((nrof_images ,))
110
+ for i in range (nrof_batches ):
111
+ feed_dict = {phase_train_placeholder :False , batch_size_placeholder :batch_size }
112
+ emb , lab = sess .run ([embeddings , labels ], feed_dict = feed_dict )
113
+ lab_array [lab ] = lab
114
+ emb_array [lab , :] = emb
115
+ if i % 10 == 9 :
116
+ print ('.' , end = '' )
117
+ sys .stdout .flush ()
118
+ print ('' )
119
+ embeddings = np .zeros ((nrof_embeddings , embedding_size * nrof_flips ))
120
+ if use_flipped_images :
121
+ # Concatenate embeddings for flipped and non flipped iversion of the images
122
+ embeddings [:,:embedding_size ] = emb_array [0 ::2 ,:]
123
+ embeddings [:,embedding_size :] = emb_array [1 ::2 ,:]
124
+ else :
125
+ embeddings = emb_array
126
+
127
+ assert np .array_equal (lab_array , np .arange (nrof_images ))== True , 'Wrong labels used for evaluation, possibly caused by training examples left in the input pipeline'
128
+ tpr , fpr , accuracy , val , val_std , far = lfw .evaluate (embeddings , actual_issame , nrof_folds = nrof_folds , distance_metric = distance_metric , subtract_mean = subtract_mean )
129
+
130
+ print ('Accuracy: %2.5f+-%2.5f' % (np .mean (accuracy ), np .std (accuracy )))
131
+ print ('Validation rate: %2.5f+-%2.5f @ FAR=%2.5f' % (val , val_std , far ))
132
+
133
+ auc = metrics .auc (fpr , tpr )
134
+ print ('Area Under Curve (AUC): %1.3f' % auc )
135
+ eer = brentq (lambda x : 1. - x - interpolate .interp1d (fpr , tpr )(x ), 0. , 1. )
136
+ print ('Equal Error Rate (EER): %1.3f' % eer )
137
+
93
138
def parse_arguments (argv ):
94
139
parser = argparse .ArgumentParser ()
95
140
@@ -105,6 +150,14 @@ def parse_arguments(argv):
105
150
help = 'The file containing the pairs to use for validation.' , default = 'data/pairs.txt' )
106
151
parser .add_argument ('--lfw_nrof_folds' , type = int ,
107
152
help = 'Number of folds to use for cross validation. Mainly used for testing.' , default = 10 )
153
+ parser .add_argument ('--distance_metric' , type = int ,
154
+ help = 'Distance metric 0:euclidian, 1:cosine similarity.' , default = 0 )
155
+ parser .add_argument ('--use_flipped_images' ,
156
+ help = 'Concatenates embeddings for the image and its horizontally flipped counterpart.' , action = 'store_true' )
157
+ parser .add_argument ('--subtract_mean' ,
158
+ help = 'Subtract feature mean before calculating distance.' , action = 'store_true' )
159
+ parser .add_argument ('--use_fixed_image_standardization' ,
160
+ help = 'Performs fixed standardization of images.' , action = 'store_true' )
108
161
return parser .parse_args (argv )
109
162
110
163
if __name__ == '__main__' :
0 commit comments