
Commit 3883cd4

committed Nov 23, 2020
half precision training, cleanup
1 parent 7182ea0 commit 3883cd4

17 files changed (+233 −233 lines)
 

‎.gitignore

+5
@@ -7,6 +7,11 @@ data/
 exps/
 core.*
 
+# NSML related
+.nsmlignore
+*.nsml.py
+setup.py
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]

‎DatasetLoader.py

+16-14
@@ -92,7 +92,6 @@ def additive_noise(self, noisecat, audio):
 
         return numpy.sum(numpy.concatenate(noises,axis=0),axis=0,keepdims=True) + audio
 
-
     def reverberate(self, audio):
 
         rir_file = random.choice(self.rir_files)
@@ -103,18 +102,6 @@ def reverberate(self, audio):
 
         return signal.convolve(audio, rir, mode='full')[:,:self.max_audio]
 
-    def speed_up(self, audio):
-
-        audio = audio[0].astype(numpy.int16)
-
-        return numpy.expand_dims(self.speedup.build_array(input_array=audio, sample_rate_in=16000),0).astype(numpy.float)[:,:self.max_audio]
-
-    def slow_down(self, audio):
-
-        audio = audio[0].astype(numpy.int16)
-
-        return numpy.expand_dims(self.slowdown.build_array(input_array=audio, sample_rate_in=16000),0).astype(numpy.float)[:,:self.max_audio]
-
 
 class voxceleb_loader(Dataset):
     def __init__(self, dataset_file_name, augment, musan_path, rir_path, max_frames, train_path):
@@ -182,6 +169,22 @@ def __len__(self):
         return len(self.data_list)
 
 
+
+class test_dataset_loader(Dataset):
+    def __init__(self, test_list, test_path, eval_frames, num_eval, **kwargs):
+        self.max_frames = eval_frames;
+        self.num_eval   = num_eval
+        self.test_path  = test_path
+        self.test_list  = test_list
+
+    def __getitem__(self, index):
+        audio = loadWAV(os.path.join(self.test_path,self.test_list[index]), self.max_frames, evalmode=True, num_eval=self.num_eval)
+        return torch.FloatTensor(audio), self.test_list[index]
+
+    def __len__(self):
+        return len(self.test_list)
+
+
 class voxceleb_sampler(torch.utils.data.Sampler):
     def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size):
 
@@ -228,7 +231,6 @@ def __len__(self):
         return len(self.data_source)
 
 
-
 def get_data_loader(dataset_file_name, batch_size, augment, musan_path, rir_path, max_frames, max_seg_per_spk, nDataLoaderThread, nPerSpeaker, train_path, **kwargs):
 
     train_dataset = voxceleb_loader(dataset_file_name, augment, musan_path, rir_path, max_frames, train_path)
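The new `test_dataset_loader` is consumed through a standard PyTorch `DataLoader`; `evaluateFromList` in `SpeakerNet.py` below does essentially this. A minimal usage sketch, assuming a hypothetical file list and data path:

```
import torch
from DatasetLoader import test_dataset_loader

# Hypothetical evaluation utterances, relative to the test path
setfiles = ['id00001/clip001/00001.wav', 'id00001/clip001/00002.wav']

test_dataset = test_dataset_loader(setfiles, 'data/voxceleb1', eval_frames=400, num_eval=10)
test_loader  = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

for audio, filename in test_loader:
    # after batching, audio is (1, num_eval, num_samples); filename is a list of one string
    print(filename[0], audio[0].shape)
```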

‎README.md

+43-33
@@ -1,25 +1,6 @@
 # VoxCeleb trainer
 
-This repository contains the framework for training speaker recognition models described in 'In defence of metric learning for speaker recognition.'
-
-
-### Distributed training
-
-This branch contains experimental code for distributed training. It will be merged into `master` in the future.
-
-- GPU indices should be set using the command `export CUDA_VISIBLE_DEVICES=0,1,2,3`.
-
-- Evaluation is not performed between epochs during training.
-
-- Use `--distributed` flag to enable distributed training.
-
-- At every epoch, the whole dataset is passed through **each** GPU once. Therefore `test_interval` and `max_epochs` must be divided by the number of GPUs for the same number of forward passes as single GPU training. For example, `--test_interval 10` using 1 GPU should be equivalent to `--test_interval 2` using 5 GPUs.
-
-- If you run more than one distributed training session, you need to change the port.
-
-- The code only works on Linux systems with CUDA 9.2 or later.
-
-If you have any suggestions for improvement, please raise it as an issue.
+This repository contains the framework for training speaker recognition models described in the paper '_In defence of metric learning for speaker recognition_'.
 
 ### Dependencies
 ```
@@ -47,32 +28,32 @@ In addition to the Python dependencies, `wget` and `ffmpeg` must be installed on
 
 - AM-Softmax:
 ```
-python ./trainSpeakerNet.py --model ResNetSE34L --log_input True --encoder_type SAP --trainfunc amsoftmax --save_path exps/exp1 --nClasses 5994 --batch_size 200 --scale 30 --margin 0.3 --train_list train_list.txt --test_list test_list.txt
+python ./trainSpeakerNet.py --model ResNetSE34L --log_input True --encoder_type SAP --trainfunc amsoftmax --save_path exps/exp1 --nClasses 5994 --batch_size 200 --scale 30 --margin 0.3
 ```
 
 - Angular prototypical:
 ```
-python ./trainSpeakerNet.py --model ResNetSE34L --log_input True --encoder_type SAP --trainfunc angleproto --save_path exps/exp2 --nPerSpeaker 2 --batch_size 200 --train_list train_list.txt --test_list test_list.txt
+python ./trainSpeakerNet.py --model ResNetSE34L --log_input True --encoder_type SAP --trainfunc angleproto --save_path exps/exp2 --nPerSpeaker 2 --batch_size 200
 ```
 
 The arguments can also be passed as `--config path_to_config.yaml`. Note that the configuration file overrides the arguments passed via command line.
 
 ### Pretrained models
 
-A pretrained model can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_lite_ap.model).
+A pretrained model, described in [1], can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_lite_ap.model).
 
 You can check that the following script returns: `EER 2.1792`. You will be given an option to save the scores.
 
 ```
-python ./trainSpeakerNet.py --eval --model ResNetSE34L --log_input True --trainfunc angleproto --save_path exps/test --eval_frames 400 --test_list test_list.txt --initial_model baseline_lite_ap.model
+python ./trainSpeakerNet.py --eval --model ResNetSE34L --log_input True --trainfunc angleproto --save_path exps/test --eval_frames 400 --initial_model baseline_lite_ap.model
 ```
 
-A larger model trained with data augmentation can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_v2_ap.model).
+A larger model trained with online data augmentation, described in [2], can be downloaded from [here](http://www.robots.ox.ac.uk/~joon/data/baseline_v2_ap.model).
 
-The following script should return: `EER 1.1771`.
+The following script should return: `EER 1.1771`.
 
 ```
-python ./trainSpeakerNet.py --eval --model ResNetSE34V2 --log_input True --encoder_type ASP --n_mels 64 --trainfunc softmaxproto --save_path exps/test --eval_frames 400 --test_list test_list.txt --initial_model baseline_v2_ap.model
+python ./trainSpeakerNet.py --eval --model ResNetSE34V2 --log_input True --encoder_type ASP --n_mels 64 --trainfunc softmaxproto --save_path exps/test --eval_frames 400 --initial_model baseline_v2_ap.model
 ```
 
 ### Implemented loss functions
@@ -88,16 +69,33 @@ Angular Prototypical (angleproto)
 
 ### Implemented models and encoders
 ```
-ResNetSE34 (SAP)
 ResNetSE34L (SAP, ASP)
 ResNetSE34V2 (SAP, ASP)
 VGGVox40 (SAP, TAP, MAX)
 ```
 
+### Data augmentation
+
+`--augment True` enables online data augmentation, described in [2].
+
 ### Adding new models and loss functions
 
 You can add new models and loss functions to `models` and `loss` directories respectively. See the existing definitions for examples.
 
+### Accelerating training
+
+- Use `--mixedprec` flag to enable mixed precision training. This is recommended for Tesla V100, GeForce RTX 20 series or later models.
+
+- Use `--distributed` flag to enable distributed training.
+
+- GPU indices should be set using the command `export CUDA_VISIBLE_DEVICES=0,1,2,3`.
+
+- Evaluation is not performed between epochs during training.
+
+- If you are running more than one distributed training session, you need to change the port.
+
+- At every epoch, the whole dataset is passed through **each** GPU once. Therefore `test_interval` and `max_epochs` must be divided by the number of GPUs for the same number of forward passes as single GPU training.
+
 ### Data
 
 The [VoxCeleb](http://www.robots.ox.ac.uk/~vgg/data/voxceleb/) datasets are used for these experiments.
@@ -114,9 +112,10 @@ test list for VoxCeleb1 from [here](http://www.robots.ox.ac.uk/~vgg/data/voxcele
 ### Replicating the results from the paper
 
 1. Model definitions
-  - `VGG-M-40` in the paper is `VGGVox` in the code.
-  - `Thin ResNet-34` is in the paper `ResNetSE34` in the code.
-  - `Fast ResNet-34` is in the paper `ResNetSE34L` in the code.
+  - `VGG-M-40` in [1] is `VGGVox` in the repository.
+  - `Thin ResNet-34` in [1] is `ResNetSE34` in the repository.
+  - `Fast ResNet-34` in [1] is `ResNetSE34L` in the repository.
+  - `H / ASP` in [2] is `ResNetSE34V2` in the repository.
 
 2. For metric learning objectives, the batch size in the paper is `nPerSpeaker` multiplied by `batch_size` in the code. For the batch size of 800 in the paper, use `--nPerSpeaker 2 --batch_size 400`, `--nPerSpeaker 3 --batch_size 266`, etc.
 
@@ -125,13 +124,14 @@ test list for VoxCeleb1 from [here](http://www.robots.ox.ac.uk/~vgg/data/voxcele
 4. You can get a good balance between speed and performance using the configuration below.
 
 ```
-python ./trainSpeakerNet.py --model ResNetSE34L --trainfunc angleproto --batch_size 400 --nPerSpeaker 2 --train_list train_list.txt --test_list test_list.txt
+python ./trainSpeakerNet.py --model ResNetSE34L --trainfunc angleproto --batch_size 400 --nPerSpeaker 2
 ```
 
 ### Citation
 
-Please cite the following if you make use of the code. Please see [here](References.md) for the full list of methods used in this trainer.
+Please cite [1] if you make use of the code. Please see [here](References.md) for the full list of methods used in this trainer.
 
+[1] _In defence of metric learning for speaker recognition_
 ```
 @inproceedings{chung2020in,
   title={In defence of metric learning for speaker recognition},
@@ -141,6 +141,16 @@ Please cite the following if you make use of the code. Please see [here](Referen
 }
 ```
 
+[2] _Clova baseline system for the VoxCeleb Speaker Recognition Challenge 2020_
+```
+@article{heo2020clova,
+  title={Clova baseline system for the {VoxCeleb} Speaker Recognition Challenge 2020},
+  author={Heo, Hee Soo and Lee, Bong-Jin and Huh, Jaesung and Chung, Joon Son},
+  journal={arXiv preprint arXiv:2009.14153},
+  year={2020}
+}
+```
+
 ### License
 ```
 Copyright (c) 2020-present NAVER Corp.

‎SpeakerNet.py

+45-32
@@ -7,7 +7,9 @@
 import numpy, math, pdb, sys, random
 import time, os, itertools, shutil, importlib
 from tuneThreshold import tuneThresholdfromScore
-from DatasetLoader import loadWAV
+from DatasetLoader import test_dataset_loader
+
+from torch.cuda.amp import autocast, GradScaler
 
 class WrappedModel(nn.Module):
 
@@ -43,15 +45,17 @@ def forward(self, data, label=None):
             return outp
 
         else:
+
             outp = outp.reshape(self.nPerSpeaker,-1,outp.size()[-1]).transpose(1,0).squeeze(1)
+
             nloss, prec1 = self.__L__.forward(outp,label)
 
             return nloss, prec1
 
 
 class ModelTrainer(object):
 
-    def __init__(self, speaker_model, optimizer, scheduler, gpu, **kwargs):
+    def __init__(self, speaker_model, optimizer, scheduler, gpu, mixedprec, **kwargs):
 
         self.__model__ = speaker_model
 
@@ -61,8 +65,12 @@ def __init__(self, speaker_model, optimizer, scheduler, gpu, **kwargs):
         Scheduler = importlib.import_module('scheduler.'+scheduler).__getattribute__('Scheduler')
         self.__scheduler__, self.lr_step = Scheduler(self.__optimizer__, **kwargs)
 
+        self.scaler = GradScaler()
+
         self.gpu = gpu
 
+        self.mixedprec = mixedprec
+
         assert self.lr_step in ['epoch', 'iteration']
 
 # ## ===== ===== ===== ===== ===== ===== ===== =====
@@ -90,15 +98,24 @@ def train_network(self, loader, verbose):
 
             label = torch.LongTensor(data_label).cuda()
 
-            nloss, prec1 = self.__model__(data, label)
+            if self.mixedprec:
+                with autocast():
+                    nloss, prec1 = self.__model__(data, label)
+                self.scaler.scale(nloss).backward();
+                self.scaler.step(self.__optimizer__);
+                self.scaler.update();
+            else:
+                nloss, prec1 = self.__model__(data, label)
+                nloss.backward();
+                self.__optimizer__.step();
+
 
             loss    += nloss.detach().cpu();
-            top1    += prec1
+            top1    += prec1.detach().cpu()
             counter += 1;
             index   += stepsize;
 
-            nloss.backward();
-            self.__optimizer__.step();
+
             telapsed = time.time() - tstart
             tstart = time.time()
@@ -121,7 +138,7 @@ def train_network(self, loader, verbose):
     ## Evaluate from list
     ## ===== ===== ===== ===== ===== ===== ===== =====
 
-    def evaluateFromList(self, listfilename, print_interval=100, test_path='', num_eval=10, eval_frames=None):
+    def evaluateFromList(self, test_list, test_path, nDataLoaderThread, print_interval=100, num_eval=10, **kwargs):
 
         self.__model__.eval();
 
@@ -131,34 +148,30 @@ def evaluateFromList(self, listfilename, print_interval=100, test_path='', num_e
         tstart = time.time()
 
         ## Read all lines
-        with open(listfilename) as listfile:
-            while True:
-                line = listfile.readline();
-                if (not line):
-                    break;
-
-                data = line.split();
-
-                ## Append random label if missing
-                if len(data) == 2: data = [random.randint(0,1)] + data
-
-                files.append(data[1])
-                files.append(data[2])
-                lines.append(line)
+        with open(test_list) as f:
+            lines = f.readlines()
 
+        ## Get a list of unique file names
+        files = sum([x.strip().split()[-2:] for x in lines],[])
         setfiles = list(set(files))
         setfiles.sort()
 
-        ## Save all features to file
-        for idx, file in enumerate(setfiles):
-
-            inp1 = torch.FloatTensor(loadWAV(os.path.join(test_path,file), eval_frames, evalmode=True, num_eval=num_eval)).cuda()
-
-            ref_feat = self.__model__(inp1).detach().cpu()
-
-            feats[file] = ref_feat
-
-            telapsed = time.time() - tstart
+        ## Define test data loader
+        test_dataset = test_dataset_loader(setfiles, test_path, num_eval=num_eval, **kwargs)
+        test_loader = torch.utils.data.DataLoader(
+            test_dataset,
+            batch_size=1,
+            shuffle=False,
+            num_workers=nDataLoaderThread,
+            drop_last=False,
+        )
+
+        ## Extract features for every image
+        for idx, data in enumerate(test_loader):
+            inp1 = data[0][0].cuda()
+            ref_feat = self.__model__(inp1).detach().cpu()
+            feats[data[1][0]] = ref_feat
+            telapsed = time.time() - tstart
 
             if idx % print_interval == 0:
                 sys.stdout.write("\rReading %d of %d: %.2f Hz, embedding size %d"%(idx,len(setfiles),idx/telapsed,ref_feat.size()[1]));
@@ -197,7 +210,7 @@ def evaluateFromList(self, listfilename, print_interval=100, test_path='', num_e
                 sys.stdout.write("\rComputing %d of %d: %.2f Hz"%(idx,len(lines),idx/telapsed));
                 sys.stdout.flush();
 
-        print('\n')
+        print('')
 
         return (all_scores, all_labels, all_trials);
 
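The training-loop change above follows the standard `torch.cuda.amp` recipe introduced in PyTorch 1.6. A minimal self-contained sketch of the same pattern, with a placeholder model and optimizer rather than the trainer's actual objects (requires a CUDA device):

```
import torch
from torch.cuda.amp import autocast, GradScaler

model     = torch.nn.Linear(40, 10).cuda()    # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler    = GradScaler()

data  = torch.randn(8, 40).cuda()
label = torch.randint(0, 10, (8,)).cuda()

optimizer.zero_grad()
with autocast():                      # forward pass runs in half precision where safe
    loss = torch.nn.functional.cross_entropy(model(data), label)
scaler.scale(loss).backward()         # scale the loss to avoid fp16 gradient underflow
scaler.step(optimizer)                # unscales gradients; skips the step if they overflowed
scaler.update()                       # adjusts the loss scale for the next iteration
```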

‎loss/aamsoftmax.py

+3-3
@@ -39,7 +39,7 @@ def forward(self, x, label=None):
         # cos(theta)
         cosine = F.linear(F.normalize(x), F.normalize(self.weight))
         # cos(theta + m)
-        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
+        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
         phi = cosine * self.cos_m - sine * self.sin_m
 
         if self.easy_margin:
@@ -53,6 +53,6 @@ def forward(self, x, label=None):
         output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
         output = output * self.s
 
-        loss = self.ce(output, label)
-        prec1, _ = accuracy(output.detach().cpu(), label.detach().cpu(), topk=(1, 5))
+        loss = self.ce(output, label)
+        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
         return loss, prec1
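For reference, the `sine` and `phi` lines above implement the angle-addition identity behind the additive angular margin: sin(theta) = sqrt(1 - cos^2(theta)) and cos(theta + m) = cos(theta)*cos(m) - sin(theta)*sin(m). The switch from `torch.pow(cosine, 2)` to `torch.mul(cosine, cosine)` computes the same cos^2(theta); a plausible motivation, consistent with this commit's theme, is that autocast runs `pow` in float32 while an elementwise multiply can stay in half precision.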

‎loss/amsoftmax.py

+2-2
@@ -39,7 +39,7 @@ def forward(self, x, label=None):
         if x.is_cuda: delt_costh = delt_costh.cuda()
         costh_m = costh - delt_costh
         costh_m_s = self.s * costh_m
-        loss = self.ce(costh_m_s, label)
-        prec1, _ = accuracy(costh_m_s.detach().cpu(), label.detach().cpu(), topk=(1, 5))
+        loss = self.ce(costh_m_s, label)
+        prec1 = accuracy(costh_m_s.detach(), label.detach(), topk=(1,))[0]
         return loss, prec1
 

‎loss/angleproto.py

+3-3
@@ -32,8 +32,8 @@ def forward(self, x, label=None):
         torch.clamp(self.w, 1e-6)
         cos_sim_matrix = cos_sim_matrix * self.w + self.b
 
-        label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda()
-        nloss = self.criterion(cos_sim_matrix, label)
-        prec1, _ = accuracy(cos_sim_matrix.detach().cpu(), label.detach().cpu(), topk=(1, 5))
+        label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda()
+        nloss = self.criterion(cos_sim_matrix, label)
+        prec1 = accuracy(cos_sim_matrix.detach(), label.detach(), topk=(1,))[0]
 
         return nloss, prec1

‎loss/ge2e.py

+1-1
@@ -48,6 +48,6 @@ def forward(self, x, label=None):
 
         label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda()
         nloss = self.criterion(cos_sim_matrix.view(-1,stepsize), torch.repeat_interleave(label,repeats=gsize,dim=0).cuda())
-        prec1, _ = accuracy(cos_sim_matrix.view(-1,stepsize).detach().cpu(), torch.repeat_interleave(label,repeats=gsize,dim=0).detach().cpu(), topk=(1, 5))
+        prec1 = accuracy(cos_sim_matrix.view(-1,stepsize).detach(), torch.repeat_interleave(label,repeats=gsize,dim=0).detach(), topk=(1,))[0]
 
         return nloss, prec1

‎loss/proto.py

+4-4
@@ -28,9 +28,9 @@ def forward(self, x, label=None):
         out_positive = x[:,0,:]
         stepsize = out_anchor.size()[0]
 
-        output = -1 * (F.pairwise_distance(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))**2)
-        label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda()
-        nloss = self.criterion(output, label)
-        prec1, _ = accuracy(output.detach().cpu(), label.detach().cpu(), topk=(1, 5))
+        output = -1 * (F.pairwise_distance(out_positive.unsqueeze(-1),out_anchor.unsqueeze(-1).transpose(0,2))**2)
+        label = torch.from_numpy(numpy.asarray(range(0,stepsize))).cuda()
+        nloss = self.criterion(output, label)
+        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
 
         return nloss, prec1

‎loss/softmax.py

+1-1
@@ -22,6 +22,6 @@ def forward(self, x, label=None):
 
         x = self.fc(x)
         nloss = self.criterion(x, label)
-        prec1, _ = accuracy(x.detach().cpu(), label.detach().cpu(), topk=(1, 5))
+        prec1 = accuracy(x.detach(), label.detach(), topk=(1,))[0]
 
         return nloss, prec1

‎models/ResNetSE34.py

-112
This file was deleted.

‎models/ResNetSE34L.py

+5-3
@@ -76,9 +76,11 @@ def new_parameter(self, *size):
 
     def forward(self, x):
 
-        x = self.torchfb(x)+1e-6
-        if self.log_input: x = x.log()
-        x = self.instancenorm(x).unsqueeze(1).detach()
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self.torchfb(x)+1e-6
+                if self.log_input: x = x.log()
+                x = self.instancenorm(x).unsqueeze(1).detach()
 
         x = self.conv1(x)
         x = self.bn1(x)
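The same wrapping is applied to `ResNetSE34V2` and `VGGVox` below: the filterbank, log and instance-norm front-end is excluded from autocast so that `log()` never sees small half-precision values, which could underflow. A minimal sketch of the mechanism itself, unrelated to the speaker models (requires a CUDA device; the linear layer is a stand-in for the front-end):

```
import torch

layer = torch.nn.Linear(4, 4).cuda()
x = torch.randn(2, 4).cuda()

with torch.cuda.amp.autocast():
    y = layer(x)
    print(y.dtype)                         # torch.float16 inside the autocast region
    with torch.cuda.amp.autocast(enabled=False):
        z = (layer(x).abs() + 1e-6).log()  # nested region: ops run in full precision
        print(z.dtype)                     # torch.float32
```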

‎models/ResNetSE34V2.py

old mode 100755
new mode 100644
+4-3

@@ -87,9 +87,10 @@ def new_parameter(self, *size):
     def forward(self, x):
 
         with torch.no_grad():
-            x = self.torchfb(x)+1e-6
-            if self.log_input: x = x.log()
-            x = self.instancenorm(x).unsqueeze(1)
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self.torchfb(x)+1e-6
+                if self.log_input: x = x.log()
+                x = self.instancenorm(x).unsqueeze(1)
 
         x = self.conv1(x)
         x = self.relu(x)

‎models/VGGVox.py

+5-3
@@ -71,9 +71,11 @@ def new_parameter(self, *size):
 
     def forward(self, x):
 
-        x = self.torchfb(x)+1e-6
-        if self.log_input: x = x.log()
-        x = self.instancenorm(x).unsqueeze(1).detach()
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                x = self.torchfb(x)+1e-6
+                if self.log_input: x = x.log()
+                x = self.instancenorm(x).unsqueeze(1)
 
         x = self.netcnn(x);
 
‎requirements.txt

+3-3
@@ -1,7 +1,7 @@
-torch>=1.5.0
-torchaudio>=0.5.0
+torch>=1.6.0
+torchaudio>=0.6.0
 numpy
 scipy
 scikit-learn
 tqdm
-pyyaml
+pyyaml

‎trainSpeakerNet.py

+35-12
@@ -15,6 +15,10 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 
+# ## ===== ===== ===== ===== ===== ===== ===== =====
+# ## Parse arguments
+# ## ===== ===== ===== ===== ===== ===== ===== =====
+
 parser = argparse.ArgumentParser(description = "SpeakerNet");
 
 parser.add_argument('--config', type=str, default=None, help='Config YAML file');
@@ -42,18 +46,18 @@
 ## Loss functions
 parser.add_argument("--hard_prob", type=float, default=0.5, help='Hard negative mining probability, otherwise random, only for some loss functions');
 parser.add_argument("--hard_rank", type=int, default=10, help='Hard negative mining rank in the batch, only for some loss functions');
-parser.add_argument('--margin', type=float, default=1, help='Loss margin, only for some loss functions');
-parser.add_argument('--scale', type=float, default=15, help='Loss scale, only for some loss functions');
+parser.add_argument('--margin', type=float, default=0.1, help='Loss margin, only for some loss functions');
+parser.add_argument('--scale', type=float, default=30, help='Loss scale, only for some loss functions');
 parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of utterances per speaker per batch, only for metric learning based losses');
 parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses');
 
 ## Load and save
 parser.add_argument('--initial_model', type=str, default="", help='Initial model weights');
-parser.add_argument('--save_path', type=str, default="./data/exp1", help='Path for model and logs');
+parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs');
 
 ## Training and test data
-parser.add_argument('--train_list', type=str, default="", help='Train list');
-parser.add_argument('--test_list', type=str, default="", help='Evaluation list');
+parser.add_argument('--train_list', type=str, default="data/train_list.txt", help='Train list');
+parser.add_argument('--test_list', type=str, default="data/test_list.txt", help='Evaluation list');
 parser.add_argument('--train_path', type=str, default="data/voxceleb2", help='Absolute path to the train set');
 parser.add_argument('--test_path', type=str, default="data/voxceleb1", help='Absolute path to the test set');
 parser.add_argument('--musan_path', type=str, default="data/musan_split", help='Absolute path to the test set');
@@ -69,9 +73,10 @@
 ## For test only
 parser.add_argument('--eval', dest='eval', action='store_true', help='Eval only')
 
-## Distributed training
+## Distributed and mixed precision training
 parser.add_argument('--port', type=str, default="8888", help='Port for distributed training, input as text');
 parser.add_argument('--distributed', dest='distributed', action='store_true', help='Enable distributed training')
+parser.add_argument('--mixedprec', dest='mixedprec', action='store_true', help='Enable mixed precision training')
 
 args = parser.parse_args();
 
@@ -120,13 +125,10 @@ def main_worker(gpu, ngpus_per_node, args):
     else:
         s = WrappedModel(s).cuda(args.gpu)
 
-    prevloss = float("inf");
-    sumloss = 0;
-    min_eer = [100];
     it = 1
 
     ## Write args to scorefile
-    scorefile = open(args.result_save_path+"/scores.txt", "a+");
+    scorefile = open(args.result_save_path+"/scores.txt", "a+");
 
     ## Initialise trainer and data loader
     trainLoader = get_data_loader(args.train_list, **vars(args));
@@ -150,12 +152,24 @@ def main_worker(gpu, ngpus_per_node, args):
     ## Evaluation code - must run on single GPU
     if args.eval == True:
 
+        pytorch_total_params = sum(p.numel() for p in s.module.__S__.parameters())
+
+        print('Total parameters: ',pytorch_total_params)
+        print('Test list',args.test_list)
+
         assert args.distributed == False
 
-        sc, lab, _ = trainer.evaluateFromList(args.test_list, print_interval=100, test_path=args.test_path, eval_frames=args.eval_frames)
+        sc, lab, _ = trainer.evaluateFromList(**vars(args))
         result = tuneThresholdfromScore(sc, lab, [1, 0.1]);
 
-        print(result[1])
+        p_target = 0.05
+        c_miss = 1
+        c_fa = 1
+
+        fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
+        mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
+
+        print('EER %2.4f MinDCF %.5f'%(result[1],mindcf))
         quit();
 
     ## Save training code and params
@@ -182,6 +196,14 @@ def main_worker(gpu, ngpus_per_node, args):
 
         if it % args.test_interval == 0 and args.gpu == 0:
 
+            ## Perform evaluation only in single GPU training
+            if not args.distributed:
+                sc, lab, _ = trainer.evaluateFromList(**vars(args))
+                result = tuneThresholdfromScore(sc, lab, [1, 0.1]);
+
+                print("IT %d, VEER %2.4f"%(it, result[1]));
+                scorefile.write("IT %d, VEER %2.4f\n"%(it, result[1]));
+
             trainer.saveParameters(args.model_save_path+"/model%09d.model"%it);
 
         print(time.strftime("%Y-%m-%d %H:%M:%S"), "TEER/TAcc %2.2f, TLOSS %f"%( traineer, loss));
@@ -214,6 +236,7 @@ def main():
     print('Python Version:', sys.version)
    print('PyTorch Version:', torch.__version__)
     print('Number of GPUs:', torch.cuda.device_count())
+    print('Save path:',args.save_path)
 
     if args.distributed:
         mp.spawn(main_worker, nprocs=n_gpus, args=(n_gpus, args))

‎tuneThreshold.py

+58-4
@@ -8,14 +8,12 @@
 from sklearn import metrics
 import numpy
 import pdb
+from operator import itemgetter
 
 def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None):
 
     fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)
     fnr = 1 - tpr
-
-    fnr = fnr*100
-    fpr = fpr*100
 
     tunedThreshold = [];
     if target_fr:
@@ -28,6 +26,62 @@ def tuneThresholdfromScore(scores, labels, target_fa, target_fr = None):
         tunedThreshold.append([thresholds[idx], fpr[idx], fnr[idx]]);
 
     idxE = numpy.nanargmin(numpy.absolute((fnr - fpr)))
-    eer = max(fpr[idxE],fnr[idxE])
+    eer = max(fpr[idxE],fnr[idxE])*100
 
     return (tunedThreshold, eer, fpr, fnr);
+
+# Creates a list of false-negative rates, a list of false-positive rates
+# and a list of decision thresholds that give those error-rates.
+def ComputeErrorRates(scores, labels):
+
+    # Sort the scores from smallest to largest, and also get the corresponding
+    # indexes of the sorted scores. We will treat the sorted scores as the
+    # thresholds at which the error-rates are evaluated.
+    sorted_indexes, thresholds = zip(*sorted(
+        [(index, threshold) for index, threshold in enumerate(scores)],
+        key=itemgetter(1)))
+    sorted_labels = []
+    labels = [labels[i] for i in sorted_indexes]
+    fnrs = []
+    fprs = []
+
+    # At the end of this loop, fnrs[i] is the number of errors made by
+    # incorrectly rejecting scores less than thresholds[i]. And, fprs[i]
+    # is the total number of times that we have correctly accepted scores
+    # greater than thresholds[i].
+    for i in range(0, len(labels)):
+        if i == 0:
+            fnrs.append(labels[i])
+            fprs.append(1 - labels[i])
+        else:
+            fnrs.append(fnrs[i-1] + labels[i])
+            fprs.append(fprs[i-1] + 1 - labels[i])
+    fnrs_norm = sum(labels)
+    fprs_norm = len(labels) - fnrs_norm
+
+    # Now divide by the total number of false negative errors to
+    # obtain the false negative rates across all thresholds
+    fnrs = [x / float(fnrs_norm) for x in fnrs]
+
+    # Divide by the total number of correct positives to get the
+    # true positive rate. Subtract these quantities from 1 to
+    # get the false positive rates.
+    fprs = [1 - x / float(fprs_norm) for x in fprs]
+    return fnrs, fprs, thresholds
+
+# Computes the minimum of the detection cost function. The comments refer to
+# equations in Section 3 of the NIST 2016 Speaker Recognition Evaluation Plan.
+def ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa):
+    min_c_det = float("inf")
+    min_c_det_threshold = thresholds[0]
+    for i in range(0, len(fnrs)):
+        # See Equation (2). It is a weighted sum of false negative
+        # and false positive errors.
+        c_det = c_miss * fnrs[i] * p_target + c_fa * fprs[i] * (1 - p_target)
+        if c_det < min_c_det:
+            min_c_det = c_det
+            min_c_det_threshold = thresholds[i]
+    # See Equations (3) and (4). Now we normalize the cost.
+    c_def = min(c_miss * p_target, c_fa * (1 - p_target))
+    min_dcf = min_c_det / c_def
+    return min_dcf, min_c_det_threshold
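A usage sketch for the two functions added above, with made-up scores and labels; `p_target=0.05`, `c_miss=1` and `c_fa=1` are the values wired into `trainSpeakerNet.py`:

```
from tuneThreshold import ComputeErrorRates, ComputeMinDcf

# Toy verification trial scores (higher = same speaker) and ground-truth labels
scores = [0.9, 0.8, 0.6, 0.4, 0.3, 0.1]
labels = [1,   1,   0,   1,   0,   0  ]

fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target=0.05, c_miss=1, c_fa=1)
print('MinDCF %.5f at threshold %.2f' % (mindcf, threshold))
```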
