Commit 9b55e00: fix recipe tests
Parent: 62bb185

File tree: 11 files changed (+280, -95 lines)


benchmarks/MP3S/IEMOCAP/ecapa_tdnn/hparams/ssl.yaml
Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ test_spk_id: 1
 train_annotation: !ref <output_folder>/train.json
 valid_annotation: !ref <output_folder>/valid.json
 test_annotation: !ref <output_folder>/test.json
+skip_prep: False

 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger

benchmarks/MP3S/IEMOCAP/ecapa_tdnn/train.py
Lines changed: 14 additions & 13 deletions

@@ -224,19 +224,20 @@ def label_pipeline(emo):
     from iemocap_prepare import prepare_data  # noqa E402

     # Data preparation, to be run on only one process.
-    sb.utils.distributed.run_on_main(
-        prepare_data,
-        kwargs={
-            "data_original": hparams["data_folder"],
-            "save_json_train": hparams["train_annotation"],
-            "save_json_valid": hparams["valid_annotation"],
-            "save_json_test": hparams["test_annotation"],
-            "split_ratio": [80, 10, 10],
-            "different_speakers": hparams["different_speakers"],
-            "test_spk_id": hparams["test_spk_id"],
-            "seed": hparams["seed"],
-        },
-    )
+    if not hparams["skip_prep"]:
+        sb.utils.distributed.run_on_main(
+            prepare_data,
+            kwargs={
+                "data_original": hparams["data_folder"],
+                "save_json_train": hparams["train_annotation"],
+                "save_json_valid": hparams["valid_annotation"],
+                "save_json_test": hparams["test_annotation"],
+                "split_ratio": [80, 10, 10],
+                "different_speakers": hparams["different_speakers"],
+                "test_spk_id": hparams["test_spk_id"],
+                "seed": hparams["seed"],
+            },
+        )

     # Data preparation, to be run on only one process.
     # Create dataset objects "train", "valid", and "test".
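
For context on why this guard is enough for the recipe tests: the --skip_prep=True flag in tests/recipes/MP3S.csv reaches hparams["skip_prep"] through the usual SpeechBrain entry point, which folds command-line overrides into the loaded YAML. A minimal sketch of that flow, assuming the standard sb.parse_arguments plus hyperpyyaml.load_hyperpyyaml pattern these recipes follow:

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# "python train.py hparams/ssl.yaml --skip_prep=True ...": everything after
# the YAML path is collected as an override string.
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# With skip_prep=True the guard above becomes a no-op, so the test run uses
# the small annotation files passed via --train_annotation and friends
# instead of preparing the real IEMOCAP corpus.
if not hparams["skip_prep"]:
    pass  # prepare_data(...) would run here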

benchmarks/MP3S/IEMOCAP/linear/hparams/ssl.yaml
Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ test_spk_id: 1
 train_annotation: !ref <output_folder>/train.json
 valid_annotation: !ref <output_folder>/valid.json
 test_annotation: !ref <output_folder>/test.json
+skip_prep: False

 # The train logger writes training statistics to a file, as well as stdout.
 train_logger: !new:speechbrain.utils.train_logger.FileTrainLogger

benchmarks/MP3S/IEMOCAP/linear/train.py
Lines changed: 14 additions & 13 deletions

@@ -232,19 +232,20 @@ def label_pipeline(emo):
     from iemocap_prepare import prepare_data  # noqa E402

     # Data preparation, to be run on only one process.
-    sb.utils.distributed.run_on_main(
-        prepare_data,
-        kwargs={
-            "data_original": hparams["data_folder"],
-            "save_json_train": hparams["train_annotation"],
-            "save_json_valid": hparams["valid_annotation"],
-            "save_json_test": hparams["test_annotation"],
-            "split_ratio": [80, 10, 10],
-            "different_speakers": hparams["different_speakers"],
-            "test_spk_id": hparams["test_spk_id"],
-            "seed": hparams["seed"],
-        },
-    )
+    if not hparams["skip_prep"]:
+        sb.utils.distributed.run_on_main(
+            prepare_data,
+            kwargs={
+                "data_original": hparams["data_folder"],
+                "save_json_train": hparams["train_annotation"],
+                "save_json_valid": hparams["valid_annotation"],
+                "save_json_test": hparams["test_annotation"],
+                "split_ratio": [80, 10, 10],
+                "different_speakers": hparams["different_speakers"],
+                "test_spk_id": hparams["test_spk_id"],
+                "seed": hparams["seed"],
+            },
+        )

     # Data preparation, to be run on only one process.
     # Create dataset objects "train", "valid", and "test".
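
Both IEMOCAP recipes keep the preparation call wrapped in sb.utils.distributed.run_on_main so that, under multi-GPU DDP, only the main process writes the JSON annotation files. A minimal sketch of that pattern in isolation; prepare_fn and its arguments are placeholders, not the recipe's actual prepare_data:

import speechbrain as sb

def prepare_fn(save_json_train, save_json_valid, save_json_test):
    # Placeholder: write the annotation files here (executed once, on the
    # main process only; other ranks wait until it has finished).
    pass

sb.utils.distributed.run_on_main(
    prepare_fn,
    kwargs={
        "save_json_train": "train.json",
        "save_json_valid": "valid.json",
        "save_json_test": "test.json",
    },
)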

benchmarks/MP3S/VoxCeleb1/Xvectors/hparams/ssl.yaml
Lines changed: 2 additions & 1 deletion

@@ -24,7 +24,7 @@ ssl_hub: facebook/wav2vec2-base
 # Use the following links for the official voxceleb splits:
 # Therefore you cannot use any files in VoxCeleb1 for training
 # if you are using these lists for testing.
-verification_file: !PLACEHOLDER #path/to/veri_test2.txt
+verification_file: https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt

 train_data: !ref <save_folder>/train.csv
 enrol_data: !ref <save_folder>/enrol.csv
@@ -47,6 +47,7 @@ test_dataloader_opts:
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 pretrain: True
+do_verification: True

 # Training parameters
 precision: fp32
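
With verification_file pointing at a URL instead of a !PLACEHOLDER, the recipe no longer needs a manually downloaded trial list. A minimal sketch of resolving it before reading, assuming no particular SpeechBrain helper (the recipe's own voxceleb_prepare.py may fetch it differently, and the destination path here is hypothetical):

import os
import urllib.request

veri_url = "https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt"
veri_file_path = "results/save/veri_test2.txt"  # hypothetical destination

os.makedirs(os.path.dirname(veri_file_path), exist_ok=True)
if not os.path.exists(veri_file_path):
    # Download the official verification pairs list once and cache it.
    urllib.request.urlretrieve(veri_url, veri_file_path)

with open(veri_file_path) as f:
    veri_test = [line.rstrip() for line in f]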

benchmarks/MP3S/VoxCeleb1/Xvectors/train.py
Lines changed: 32 additions & 27 deletions

@@ -390,6 +390,7 @@ def label_pipeline(spk_id):
         splits=["train", "dev", "test"],
         split_ratio=[90, 10],
         seg_dur=hparams["sentence_len"],
+        skip_prep=hparams["skip_prep"],
         source=hparams["voxceleb_source"]
         if "voxceleb_source" in hparams
         else None,
@@ -427,37 +428,41 @@ def label_pipeline(spk_id):
         valid_loader_kwargs=hparams["dataloader_options"],
     )

-    # Now preparing for test :
-    hparams["device"] = speaker_brain.device
+    if hparams["do_verification"]:

-    speaker_brain.modules.eval()
-    train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif(
-        hparams
-    )
-    # Computing enrollment and test embeddings
-    logger.info("Computing enroll/test embeddings...")
+        # Now preparing for test :
+        hparams["device"] = speaker_brain.device
+
+        speaker_brain.modules.eval()
+        train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif(
+            hparams
+        )
+        # Computing enrollment and test embeddings
+        logger.info("Computing enroll/test embeddings...")

-    # First run
-    enrol_dict = compute_embedding_loop(enrol_dataloader)
-    test_dict = compute_embedding_loop(test_dataloader)
+        # First run
+        enrol_dict = compute_embedding_loop(enrol_dataloader)
+        test_dict = compute_embedding_loop(test_dataloader)

-    if "score_norm" in hparams:
-        train_dict = compute_embedding_loop(train_dataloader)
+        if "score_norm" in hparams:
+            train_dict = compute_embedding_loop(train_dataloader)

-    # Compute the EER
-    logger.info("Computing EER..")
-    # Reading standard verification split
-    with open(veri_file_path) as f:
-        veri_test = [line.rstrip() for line in f]
+        # Compute the EER
+        logger.info("Computing EER..")
+        # Reading standard verification split
+        with open(veri_file_path) as f:
+            veri_test = [line.rstrip() for line in f]

-    positive_scores, negative_scores = get_verification_scores(veri_test)
-    del enrol_dict, test_dict
+        positive_scores, negative_scores = get_verification_scores(veri_test)
+        del enrol_dict, test_dict

-    eer, th = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))
-    logger.info("EER(%%)=%f", eer * 100)
+        eer, th = EER(
+            torch.tensor(positive_scores), torch.tensor(negative_scores)
+        )
+        logger.info("EER(%%)=%f", eer * 100)

-    min_dcf, th = minDCF(
-        torch.tensor(positive_scores), torch.tensor(negative_scores)
-    )
-    # Testing
-    logger.info("minDCF=%f", min_dcf * 100)
+        min_dcf, th = minDCF(
+            torch.tensor(positive_scores), torch.tensor(negative_scores)
+        )
+        # Testing
+        logger.info("minDCF=%f", min_dcf * 100)

benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml
Lines changed: 2 additions & 1 deletion

@@ -21,7 +21,7 @@ valid_annotation: !ref <output_folder>/dev.csv
 ssl_folder: !ref <output_folder>/ssl_checkpoints
 ssl_hub: facebook/wav2vec2-base

-verification_file: !PLACEHOLDER #path/to/veri_test2.txt
+verification_file: https://www.robots.ox.ac.uk/~vgg/data/voxceleb/meta/veri_test2.txt #path/to/veri_test2.txt

 train_data: !ref <save_folder>/train.csv
 enrol_data: !ref <save_folder>/enrol.csv
@@ -43,6 +43,7 @@ test_dataloader_opts:
 skip_prep: False
 ckpt_interval_minutes: 15 # save checkpoint every N min
 pretrain: True
+do_verification: True

 # Training parameters
 precision: fp32

benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py
Lines changed: 31 additions & 27 deletions

@@ -389,6 +389,7 @@ def label_pipeline(spk_id):
         splits=["train", "dev", "test"],
         split_ratio=[90, 10],
         seg_dur=hparams["sentence_len"],
+        skip_prep=hparams["skip_prep"],
         source=hparams["voxceleb_source"]
         if "voxceleb_source" in hparams
         else None,
@@ -426,37 +427,40 @@ def label_pipeline(spk_id):
         valid_loader_kwargs=hparams["dataloader_options"],
     )

-    # Now preparing for test :
-    hparams["device"] = speaker_brain.device
+    if hparams["do_verification"]:
+        # Now preparing for test :
+        hparams["device"] = speaker_brain.device

-    speaker_brain.modules.eval()
-    train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif(
-        hparams
-    )
-    # Computing enrollment and test embeddings
-    logger.info("Computing enroll/test embeddings...")
+        speaker_brain.modules.eval()
+        train_dataloader, enrol_dataloader, test_dataloader = dataio_prep_verif(
+            hparams
+        )
+        # Computing enrollment and test embeddings
+        logger.info("Computing enroll/test embeddings...")

-    # First run
-    enrol_dict = compute_embedding_loop(enrol_dataloader)
-    test_dict = compute_embedding_loop(test_dataloader)
+        # First run
+        enrol_dict = compute_embedding_loop(enrol_dataloader)
+        test_dict = compute_embedding_loop(test_dataloader)

-    if "score_norm" in hparams:
-        train_dict = compute_embedding_loop(train_dataloader)
+        if "score_norm" in hparams:
+            train_dict = compute_embedding_loop(train_dataloader)

-    # Compute the EER
-    logger.info("Computing EER..")
-    # Reading standard verification split
-    with open(veri_file_path) as f:
-        veri_test = [line.rstrip() for line in f]
+        # Compute the EER
+        logger.info("Computing EER..")
+        # Reading standard verification split
+        with open(veri_file_path) as f:
+            veri_test = [line.rstrip() for line in f]

-    positive_scores, negative_scores = get_verification_scores(veri_test)
-    del enrol_dict, test_dict
+        positive_scores, negative_scores = get_verification_scores(veri_test)
+        del enrol_dict, test_dict

-    eer, th = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))
-    logger.info("EER(%%)=%f", eer * 100)
+        eer, th = EER(
+            torch.tensor(positive_scores), torch.tensor(negative_scores)
+        )
+        logger.info("EER(%%)=%f", eer * 100)

-    min_dcf, th = minDCF(
-        torch.tensor(positive_scores), torch.tensor(negative_scores)
-    )
-    # Testing
-    logger.info("minDCF=%f", min_dcf * 100)
+        min_dcf, th = minDCF(
+            torch.tensor(positive_scores), torch.tensor(negative_scores)
+        )
+        # Testing
+        logger.info("minDCF=%f", min_dcf * 100)

tests/recipes/MP3S.csv
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+Task,Dataset,Script_file,Hparam_file,Data_prep_file,Readme_file,Result_url,HF_repo,test_debug_flags,test_debug_checks
+ASR,LibriSpeech,benchmarks/MP3S/Buckeye/LSTM/train.py,benchmarks/MP3S/Buckeye/LSTM/hparams/ssl.yaml,,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_Buckeye.csv --valid_csv=tests/samples/annotation/ASR_Buckeye.csv --test_csv=[tests/samples/annotation/ASR_Buckeye.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=22,
+ASR,LibriSpeech,benchmarks/MP3S/Buckeye/contextnet/train.py,benchmarks/MP3S/Buckeye/contextnet/hparams/ssl.yaml,,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_Buckeye.csv --valid_csv=tests/samples/annotation/ASR_Buckeye.csv --test_csv=[tests/samples/annotation/ASR_Buckeye.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=22,
+ASR,LibriSpeech,benchmarks/MP3S/LibriSpeech/contextnet/train.py,benchmarks/MP3S/LibriSpeech/contextnet/hparams/ssl.yaml,benchmarks/MP3S/LibriSpeech/contextnet/librispeech_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=22,
+ASR,LibriSpeech,benchmarks/MP3S/LibriSpeech/LSTM/train.py,benchmarks/MP3S/LibriSpeech/LSTM/hparams/ssl.yaml,benchmarks/MP3S/LibriSpeech/LSTM/librispeech_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=[tests/samples/annotation/ASR_train.csv] --number_of_epochs=2 --skip_prep=True --output_neurons=22,
+SLU,SLURP,benchmarks/MP3S/SLURP/LSTM_linear/train.py,benchmarks/MP3S/SLURP/LSTM_linear/hparams/ssl.yaml,benchmarks/MP3S/SLURP/LSTM_linear/prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,
+SLU,SLURP,benchmarks/MP3S/SLURP/linear/train.py,benchmarks/MP3S/SLURP/linear/hparams/ssl.yaml,benchmarks/MP3S/SLURP/linear/prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --csv_train=tests/samples/annotation/ASR_train.csv --csv_valid=tests/samples/annotation/ASR_train.csv --csv_test=tests/samples/annotation/ASR_train.csv --skip_prep=True --number_of_epochs=2,
+Emotion_recognition,IEMOCAP,benchmarks/MP3S/IEMOCAP/ecapa_tdnn/train.py,benchmarks/MP3S/IEMOCAP/ecapa_tdnn/hparams/ssl.yaml,benchmarks/MP3S/IEMOCAP/ecapa_tdnn/iemocap_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --number_of_epochs=2 --skip_prep=True,
+Emotion_recognition,IEMOCAP,benchmarks/MP3S/IEMOCAP/linear/train.py,benchmarks/MP3S/IEMOCAP/linear/hparams/ssl.yaml,benchmarks/MP3S/IEMOCAP/linear/iemocap_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.json --valid_annotation=tests/samples/annotation/ASR_train.json --test_annotation=tests/samples/annotation/ASR_train.json --number_of_epochs=2 --skip_prep=True,
+Speaker_recognition,VoxCeleb,benchmarks/MP3S/VoxCeleb1/Xvectors/train.py,benchmarks/MP3S/VoxCeleb1/Xvectors/hparams/ssl.yaml,benchmarks/MP3S/VoxCeleb1/Xvectors/voxceleb_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5 --do_verification=False,
+Speaker_recognition,VoxCeleb,benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/train.py,benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/hparams/ssl.yaml,benchmarks/MP3S/VoxCeleb1/ecapa_tdnn/voxceleb_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_annotation=tests/samples/annotation/ASR_train.csv --valid_annotation=tests/samples/annotation/ASR_train.csv --number_of_epochs=2 --skip_prep=True --sentence_len=0.5 --do_verification=False,
+ASR,CommonVoice,benchmarks/MP3S/CommonVoice/linear/train.py, benchmarks/MP3S/CommonVoice/linear/hparams/ssl.yaml,benchmarks/MP3S/CommonVoice/linear/common_voice_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
+ASR,CommonVoice,benchmarks/MP3S/CommonVoice/LSTM/train.py, benchmarks/MP3S/CommonVoice/LSTM/hparams/ssl.yaml,benchmarks/MP3S/CommonVoice/LSTM/common_voice_prepare.py,benchmarks/MP3S/README.md,,,--data_folder=tests/samples/ASR/ --train_csv=tests/samples/annotation/ASR_train.csv --valid_csv=tests/samples/annotation/ASR_train.csv --test_csv=tests/samples/annotation/ASR_train.csv --number_of_epochs=1 --skip_prep=True,
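
Each row pairs a Script_file and Hparam_file with test_debug_flags that point at tiny sample annotations, which is why the recipes above needed the skip_prep and do_verification switches. A hypothetical sketch of how a recipe-test runner could consume these rows; the real SpeechBrain test harness lives under tests/ and may differ in detail:

import csv
import shlex
import subprocess

with open("tests/recipes/MP3S.csv", newline="") as f:
    for row in csv.DictReader(f):
        # Build "python <train.py> <ssl.yaml> <debug flags>" for a quick run.
        cmd = ["python", row["Script_file"].strip(), row["Hparam_file"].strip()]
        cmd += shlex.split(row["test_debug_flags"])
        print("Running:", " ".join(cmd))
        subprocess.run(cmd, check=True)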
