Some cleanup of the dataset filtering
davidsandberg committed Jan 21, 2017
1 parent 3c31c23 commit 571fad0
Showing 3 changed files with 32 additions and 38 deletions.
6 changes: 6 additions & 0 deletions data/learning_rate_schedule_classifier_casia.txt
@@ -0,0 +1,6 @@
+# Learning rate schedule
+# Maps an epoch number to a learning rate
+0: 0.1
+65: 0.01
+77: 0.001
+1000: 0.0001
2 changes: 1 addition & 1 deletion
@@ -3,4 +3,4 @@
 0: 0.1
 150: 0.01
 180: 0.001
-1000: 0.0001
+251: 0.0001
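
Both schedule files use the same plain-text format: each non-comment line maps the epoch at which a rate takes effect to the learning rate itself, and the training script reads the file when --learning_rate is set to -1. As a rough sketch of how such a file can be parsed (an illustration, not the repository's own loader):

# Illustrative sketch, not the repository's loader: return the learning rate
# in effect at a given epoch. Assumes entries appear in increasing epoch
# order, as in the schedule files above.
def get_learning_rate(schedule_filename, epoch):
    learning_rate = None
    with open(schedule_filename, 'r') as f:
        for line in f:
            line = line.split('#', 1)[0].strip()  # drop comments and blank lines
            if not line:
                continue
            start_epoch, lr = line.split(':', 1)
            if int(start_epoch) <= epoch:
                learning_rate = float(lr)  # last entry reached so far wins
    return learning_rate

With the CASIA schedule above this gives 0.1 for epochs 0-64, 0.01 for epochs 65-76, and 0.001 from epoch 77 on.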
62 changes: 25 additions & 37 deletions src/facenet_train_classifier.py
@@ -58,8 +58,9 @@ def main(args):
 
     np.random.seed(seed=args.seed)
     train_set = facenet.get_dataset(args.data_dir)
-    train_set = filter_dataset(train_set, '/home/david/msceleb_embeddings_tmp3.hdf', args.filter_method,
-        args.filter_percentile, args.filter_min_nrof_images_per_class)
+    if args.filter_filename:
+        train_set = filter_dataset(train_set, args.filter_filename,
+            args.filter_percentile, args.filter_min_nrof_images_per_class)
     nrof_classes = len(train_set)
 
     print('Model directory: %s' % model_dir)
@@ -192,40 +193,27 @@ def find_threshold(var, percentile):
     threshold = np.interp(percentile*0.01, cdf, bin_centers)
     return threshold
 
-def filter_dataset(dataset, data_filename, method, percentile, min_nrof_images_per_class):
-    if method=='':
+def filter_dataset(dataset, data_filename, percentile, min_nrof_images_per_class):
+    with h5py.File(data_filename,'r') as f:
+        distance_to_center = np.array(f.get('distance_to_center'))
+        label_list = np.array(f.get('label_list'))
+        image_list = np.array(f.get('image_list'))
+        distance_to_center_threshold = find_threshold(distance_to_center, percentile)
+        indices = np.where(distance_to_center>=distance_to_center_threshold)[0]
         filtered_dataset = dataset
-    if method=='intra_class_variance':
-        with h5py.File(data_filename,'r') as f:
-            # Keep the classes with the X% lowest intra-class variance
-            class_variance = np.array(f.get('class_variance'))
-            variance_threshold = find_threshold(class_variance, percentile)
-            indices = np.where(class_variance<variance_threshold)[1]
-            filtered_dataset = [ dataset[idx] for idx in indices ]
-    elif method=='distance_to_class_center':
-        with h5py.File(data_filename,'r') as f:
-            distance_to_center = np.array(f.get('distance_to_center'))
-            label_list = np.array(f.get('label_list'))
-            image_list = np.array(f.get('image_list'))
-            distance_to_center_threshold = find_threshold(distance_to_center, percentile)
-            indices = np.where(distance_to_center>=distance_to_center_threshold)[0]
-            filtered_dataset = dataset
-            removelist = []
-            for i in indices:
-                label = label_list[i]
-                image = image_list[i]
-                if image in filtered_dataset[label].image_paths:
-                    filtered_dataset[label].image_paths.remove(image)
-                if len(filtered_dataset[label].image_paths)<min_nrof_images_per_class:
-                    removelist.append(label)
+        removelist = []
+        for i in indices:
+            label = label_list[i]
+            image = image_list[i]
+            if image in filtered_dataset[label].image_paths:
+                filtered_dataset[label].image_paths.remove(image)
+            if len(filtered_dataset[label].image_paths)<min_nrof_images_per_class:
+                removelist.append(label)
 
-            ix = sorted(list(set(removelist)), reverse=True)
-            for i in ix:
-                del(filtered_dataset[i])
+    ix = sorted(list(set(removelist)), reverse=True)
+    for i in ix:
+        del(filtered_dataset[i])
 
-
-    else:
-        raise('Filtering method "%s" not implemented' % method)
     return filtered_dataset
 
 def train(args, sess, epoch, learning_rate_placeholder, global_step,
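
After the cleanup, filter_dataset() implements a single strategy, distance to class center: it reads three parallel datasets from the HDF5 file (image_list, label_list, distance_to_center), uses find_threshold() to interpolate the distance value at the requested percentile from a histogram CDF, removes every image whose distance is at or above that threshold, and finally drops any class left with fewer than min_nrof_images_per_class images. The HDF5 file is produced outside this script; below is a hedged sketch of how one could be generated from precomputed embeddings. Only the three dataset names come from the code above; the function name, its arguments, and the centroid computation are illustrative assumptions:

import h5py
import numpy as np

# Sketch under assumptions: write an HDF5 file in the layout filter_dataset()
# reads. image_paths: N path strings; labels: (N,) class indices that must
# follow the class ordering of facenet.get_dataset(); embeddings: (N, D) floats.
def write_filter_file(out_filename, image_paths, labels, embeddings):
    labels = np.asarray(labels)
    embeddings = np.asarray(embeddings)
    distance_to_center = np.zeros(len(labels), dtype=np.float32)
    for c in np.unique(labels):
        idx = np.where(labels == c)[0]
        center = embeddings[idx].mean(axis=0)  # per-class centroid
        distance_to_center[idx] = np.linalg.norm(embeddings[idx] - center, axis=1)
    with h5py.File(out_filename, 'w') as f:
        f.create_dataset('image_list', data=np.array(image_paths, dtype='S'))
        f.create_dataset('label_list', data=labels.astype(np.int32))
        f.create_dataset('distance_to_center', data=distance_to_center)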
@@ -374,12 +362,12 @@ def parse_arguments(argv):
         help='Enables logging of weight/bias histograms in tensorboard.', action='store_true')
     parser.add_argument('--learning_rate_schedule_file', type=str,
         help='File containing the learning rate schedule that is used when learning_rate is set to -1.', default='../data/learning_rate_schedule.txt')
-    parser.add_argument('--filter_method', type=str,
-        help='Type of dataset filtering to apply.', default='')
+    parser.add_argument('--filter_filename', type=str,
+        help='File containing image data used for dataset filtering', default='')
     parser.add_argument('--filter_percentile', type=float,
-        help='Keep only the percentile classes with the lowest intra-class variance.', default=100.0)
+        help='Keep only the percentile of images closest to their class center', default=100.0)
     parser.add_argument('--filter_min_nrof_images_per_class', type=int,
-        help='Keep only the classes with this number of examples or more.', default=60)
+        help='Keep only the classes with this number of examples or more', default=0)
 
     # Parameters for validation on LFW
     parser.add_argument('--lfw_pairs', type=str,
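
With these arguments the filtering is opt-in: --filter_filename defaults to the empty string, so main() skips filter_dataset() entirely unless a file is supplied, and the new min-images default of 0 removes no class on its own. A hypothetical invocation (the HDF5 path is made up and the other required training arguments are omitted):

python src/facenet_train_classifier.py \
    --filter_filename ~/embeddings_with_distances.hdf \
    --filter_percentile 90 \
    --filter_min_nrof_images_per_class 10 \
    ...

This would discard roughly the 10% of images farthest from their class centers, then drop any class left with fewer than 10 images.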
