Skip to content

Commit

Permalink
calculate scores from eval set, todo: audioread deprecated in librosa
Browse files Browse the repository at this point in the history
  • Loading branch information
Long Nguyen-Vu committed Nov 12, 2023
1 parent eb251a6 commit 13dcd99
Showing 1 changed file with 82 additions and 22 deletions.
104 changes: 82 additions & 22 deletions oc_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
LA_0079 LA_T_1138215 - - bonafide
Protocol files for DF eval
/datab/Dataset/ASVspoof/LA/ASVspoof_DF_cm_protocols/ASVspoof2021.DF.cm.eval.trl.txt
the file only contains the file names
metadata file for DF eval
eval-package/keys/DF/CM/trial_metadata.txt
Example
LA_0043 DF_E_2000026 mp3m4a asvspoof A09 spoof notrim eval traditional_vocoder - - - -
Args:
dataset_dir (str): wav file directory
extract_func (none): raw audio file
Expand All @@ -58,9 +61,8 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
for line in lines:
line = line.strip()
line = line.split(" ")
self.file_list.append(line[1])
self.label_list.append(line[5]) # bonafide or spoof
self._length = len(self.file_list)
self.file_list.append(line[0])
self.label_list.append("unknown") # bonafide or spoof
else:
# collect bona fide list only
# for calculating `reference embedding`
Expand All @@ -72,8 +74,8 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
if line[4] == "bonafide":
self.file_list.append(line[1])
self.label_list.append(line[4]) # bonafide only
self._length = len(self.file_list)

self._length = len(self.file_list)

def __len__(self):
    """Return the cached dataset length stored in ``self._length``."""
    return self._length
Expand All @@ -84,15 +86,19 @@ def __getitem__(self, idx):
audio_file = self.file_list[idx]
file_path = os.path.join(self.dataset_dir, audio_file + ".flac")
feature, _ = librosa.load(file_path, sr=None)
label = [1 if self.label_list[idx] == "spoof" else 0]

# Convert the list of features and labels to tensors
feature_tensors = torch.tensor(feature, dtype=torch.float32)
label_tensors = torch.tensor(label, dtype=torch.int64)
# print(f"feature_tensors.shape = {feature_tensors.shape}")
# print(f"label_tensors.shape = {label_tensors.shape}")

if self.eval == False:
label = [1 if self.label_list[idx] == "spoof" else 0]
# Convert the list of features and labels to tensors
label_tensors = torch.tensor(label, dtype=torch.int64)
return feature_tensors, label_tensors

if self.eval:
label = [1 if self.label_list[idx] == "spoof" else 0] # fake label
label_tensors = torch.tensor(label, dtype=torch.int64)
return feature_tensors, label_tensors

return feature_tensors, label_tensors

def collate_fn(self, batch):
"""pad the time series 1D"""
Expand Down Expand Up @@ -134,12 +140,68 @@ def create_reference_embedding(extractor, encoder, dataloader, device):

return reference_embedding, threshold

def score_eval_set(
    extractor,
    encoder,
    dataloader,
    device,
    reference_embedding,
    threshold,
    score_file="scores.txt",
):
    """Score the evaluation set and save the scores to a file.

    Every batch is passed through ``extractor`` and then ``encoder``; the score
    of an utterance is the L2 distance between its embedding and
    ``reference_embedding``. One line per utterance is written to
    ``score_file`` in the form ``"<distance>, <decision> "``, where the
    decision is 1 when the distance is greater than ``threshold`` and 0
    otherwise. These scores will be used to calculate the EER.

    Args:
        extractor, encoder (nn.Module): pretrained models (e.g., XLSR, SE-ResNet34)
        dataloader (DataLoader): yields ``(data, target)`` batches; the target
            is ignored because it is not needed for scoring
        device: device the input batches and the reference embedding are moved to
        reference_embedding (torch.Tensor): reference embedding
        threshold (float): decision threshold applied to the distance
        score_file (str): path of the output file (default ``"scores.txt"``)

    Returns:
        None. The scores are written to ``score_file``.
    """
    extractor.eval()
    encoder.eval()
    threshold = float(threshold)
    reference_embedding = reference_embedding.to(device)
    scores = []

    with torch.no_grad():
        for data, _ in dataloader:
            emb = extractor(data.to(device))
            emb = encoder(emb.unsqueeze(1))
            # Reduce each batch to plain Python floats right away so that the
            # embeddings of the whole evaluation set are never kept in memory.
            distance = F.pairwise_distance(reference_embedding, emb, p=2)
            # reshape(-1) handles both a 0-dim result and batches larger than 1.
            scores.extend(distance.reshape(-1).tolist())

    # Each line contains a score and the thresholded decision.
    with open(score_file, "w") as f:
        f.writelines(
            f"{score}, {1 if score > threshold else 0} \n" for score in scores
        )

    # TODO: calculate the EER once ground-truth labels are available for the
    # eval set, e.g. with sklearn's roc_curve(labels, scores, pos_label=1) and
    # brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.).

if __name__== "__main__":
parser = argparse.ArgumentParser(description='One-class classifier')
parser.add_argument('--protocol_file', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt",
help='Path to the protocol file')
parser.add_argument('--dataset_dir', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof2019_LA_train/flac",
help='Path to the dataset directory')
parser.add_argument('--eval_protocol_file', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt",
help='Path to the protocol file')
parser.add_argument('--eval_dataset_dir', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof2019_LA_eval/flac",
help='Path to the dataset directory')
args = parser.parse_args()

# initialize xlsr and lcnn models
Expand All @@ -148,8 +210,8 @@ def create_reference_embedding(extractor, encoder, dataloader, device):
senet = se_resnet34().to(device)

# load pretrained weights
ssl.load_state_dict(torch.load("/datac/longnv/occm/ssl_0.pt"))
senet.load_state_dict(torch.load("/datac/longnv/occm/senet34_0.pt"))
ssl.load_state_dict(torch.load("/datac/longnv/occm/ssl_1.pt"))
senet.load_state_dict(torch.load("/datac/longnv/occm/senet34_1.pt"))
senet = DataParallel(senet)
ssl = DataParallel(ssl)
print("Pretrained weights loaded")
Expand All @@ -163,10 +225,8 @@ def create_reference_embedding(extractor, encoder, dataloader, device):
print(f"reference_embedding.shape = {reference_embedding.shape}")
print(f"threshold = {threshold}")

# audio_file = "/datac/longnv/audio_samples/ADD2023_T2_T_00000000.wav"
# audio_data, _ = librosa.load(audio_file, sr=None)
# emb = ssl(torch.Tensor(audio_data).unsqueeze(0).to("cuda"))
# emb = emb.unsqueeze(1)
# emb = senet(emb)


# score the evaluation set
print("Scoring the evaluation set...")
eval_dataset = ASVDataset(args.eval_protocol_file, args.eval_dataset_dir, eval=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=0)
score_eval_set(ssl, senet, eval_dataloader, device, reference_embedding, threshold)

0 comments on commit 13dcd99

Please sign in to comment.