diff --git a/RawBoost.py b/RawBoost.py new file mode 100644 index 0000000..0339f66 --- /dev/null +++ b/RawBoost.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- + +import numpy as np +from scipy import signal +import copy + +''' + Hemlata Tak, Madhu Kamble, Jose Patino, Massimiliano Todisco, Nicholas Evans. + RawBoost: A Raw Data Boosting and Augmentation Method applied to Automatic Speaker Verification Anti-Spoofing. + In Proc. ICASSP 2022, pp:6382--6386. +''' + +def randRange(x1, x2, integer): + y = np.random.uniform(low=x1, high=x2, size=(1,)) + if integer: + y = int(y) + return y + +def normWav(x,always): + if always: + x = x/np.amax(abs(x)) + elif np.amax(abs(x)) > 1: + x = x/np.amax(abs(x)) + return x + + +def genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): + b = 1 + for i in range(0, nBands): + fc = randRange(minF,maxF,0); + bw = randRange(minBW,maxBW,0); + c = randRange(minCoeff,maxCoeff,1); + + if c/2 == int(c/2): + c = c + 1 + f1 = fc - bw/2 + f2 = fc + bw/2 + if f1 <= 0: + f1 = 1/1000 + if f2 >= fs/2: + f2 = fs/2-1/1000 + b = np.convolve(signal.firwin(c, [float(f1), float(f2)], window='hamming', fs=fs),b) + + G = randRange(minG,maxG,0); + _, h = signal.freqz(b, 1, fs=fs) + b = pow(10, G/20)*b/np.amax(abs(h)) + return b + + +def filterFIR(x,b): + N = b.shape[0] + 1 + xpad = np.pad(x, (0, N), 'constant') + y = signal.lfilter(b, 1, xpad) + y = y[int(N/2):int(y.shape[0]-N/2)] + return y + +# Linear and non-linear convolutive noise +def LnL_convolutive_noise(x,N_f,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,minBiasLinNonLin,maxBiasLinNonLin,fs): + y = [0] * x.shape[0] + for i in range(0, N_f): + if i == 1: + minG = minG-minBiasLinNonLin; + maxG = maxG-maxBiasLinNonLin; + b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) + y = y + filterFIR(np.power(x, (i+1)), b) + y = y - np.mean(y) + y = normWav(y,0) + return y + + +# Impulsive signal dependent noise +def ISD_additive_noise(x, P, g_sd): + beta = randRange(0, P, 0) + + y = copy.deepcopy(x) + x_len = x.shape[0] + n = int(x_len*(beta/100)) + p = np.random.permutation(x_len)[:n] + f_r= np.multiply(((2*np.random.rand(p.shape[0]))-1),((2*np.random.rand(p.shape[0]))-1)) + r = g_sd * x[p] * f_r + y[p] = x[p] + r + y = normWav(y,0) + return y + + +# Stationary signal independent noise + +def SSI_additive_noise(x,SNRmin,SNRmax,nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs): + noise = np.random.normal(0, 1, x.shape[0]) + b = genNotchCoeffs(nBands,minF,maxF,minBW,maxBW,minCoeff,maxCoeff,minG,maxG,fs) + noise = filterFIR(noise, b) + noise = normWav(noise,1) + SNR = randRange(SNRmin, SNRmax, 0) + noise = noise / np.linalg.norm(noise,2) * np.linalg.norm(x,2) / 10.0**(0.05 * SNR) + x = x + noise + return x + diff --git a/data_utils_SSL.py b/data_utils_SSL.py new file mode 100644 index 0000000..b45600a --- /dev/null +++ b/data_utils_SSL.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import torch +import torch.nn as nn +from torch import Tensor +import librosa +from torch.utils.data import Dataset +from RawBoost import ISD_additive_noise,LnL_convolutive_noise,SSI_additive_noise,normWav +from random import randrange +import random + + +___author__ = "Hemlata Tak" +__email__ = "tak@eurecom.fr" + + +def genSpoof_list( dir_meta,is_train=False,is_eval=False): + + d_meta = {} + file_list=[] + with open(dir_meta, 'r') as f: + l_meta = f.readlines() + + if (is_train): + for line in l_meta: + _,key,_,_,label = line.strip().split() + + 
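+            # Protocol line format: <speaker_id> <utt_key> - <attack_id> <label>,
+            # e.g. "LA_0079 LA_T_1138215 - - bonafide"; only the key and label are used.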
file_list.append(key) + d_meta[key] = 1 if label == 'bonafide' else 0 + return d_meta,file_list + + elif(is_eval): + for line in l_meta: + key= line.strip() + file_list.append(key) + return file_list + else: + for line in l_meta: + _,key,_,_,label = line.strip().split() + + file_list.append(key) + d_meta[key] = 1 if label == 'bonafide' else 0 + return d_meta,file_list + + + +def pad(x, max_len=64600): + x_len = x.shape[0] + if x_len >= max_len: + return x[:max_len] + # need to pad + num_repeats = int(max_len / x_len)+1 + padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0] + return padded_x + + +class Dataset_ASVspoof2019_train(Dataset): + def __init__(self,args,list_IDs, labels, base_dir,algo): + '''self.list_IDs : list of strings (each string: utt key), + self.labels : dictionary (key: utt key, value: label integer)''' + + self.list_IDs = list_IDs + self.labels = labels + self.base_dir = base_dir + self.algo=algo + self.args=args + self.cut=64600 # take ~4 sec audio (64600 samples) + + def __len__(self): + return len(self.list_IDs) + + + def __getitem__(self, index): + + utt_id = self.list_IDs[index] + X,fs = librosa.load(self.base_dir+'flac/'+utt_id+'.flac', sr=16000) + Y=process_Rawboost_feature(X,fs,self.args,self.algo) + X_pad= pad(Y,self.cut) + x_inp= Tensor(X_pad) + target = self.labels[utt_id] + + return x_inp, target + + +class Dataset_ASVspoof2021_eval(Dataset): + def __init__(self, list_IDs, base_dir): + '''self.list_IDs : list of strings (each string: utt key), + ''' + + self.list_IDs = list_IDs + self.base_dir = base_dir + self.cut=64600 # take ~4 sec audio (64600 samples) + + def __len__(self): + return len(self.list_IDs) + + + def __getitem__(self, index): + + utt_id = self.list_IDs[index] + X, fs = librosa.load(self.base_dir+'flac/'+utt_id+'.flac', sr=16000) + X_pad = pad(X,self.cut) + x_inp = Tensor(X_pad) + return x_inp,utt_id + + + + +#--------------RawBoost data augmentation algorithms---------------------------## + +def process_Rawboost_feature(feature, sr,args,algo): + + # Data process by Convolutive noise (1st algo) + if algo==1: + + feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) + + # Data process by Impulsive noise (2nd algo) + elif algo==2: + + feature=ISD_additive_noise(feature, args.P, args.g_sd) + + # Data process by coloured additive noise (3rd algo) + elif algo==3: + + feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) + + # Data process by all 3 algo. together in series (1+2+3) + elif algo==4: + + feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, + args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) + feature=ISD_additive_noise(feature, args.P, args.g_sd) + feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF, + args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) + + # Data process by 1st two algo. together in series (1+2) + elif algo==5: + + feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, + args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) + feature=ISD_additive_noise(feature, args.P, args.g_sd) + + + # Data process by 1st and 3rd algo. 
together in series (1+3) + elif algo==6: + + feature =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, + args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) + feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) + + # Data process by 2nd and 3rd algo. together in series (2+3) + elif algo==7: + + feature=ISD_additive_noise(feature, args.P, args.g_sd) + feature=SSI_additive_noise(feature,args.SNRmin,args.SNRmax,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW,args.minCoeff,args.maxCoeff,args.minG,args.maxG,sr) + + # Data process by 1st two algo. together in Parallel (1||2) + elif algo==8: + + feature1 =LnL_convolutive_noise(feature,args.N_f,args.nBands,args.minF,args.maxF,args.minBW,args.maxBW, + args.minCoeff,args.maxCoeff,args.minG,args.maxG,args.minBiasLinNonLin,args.maxBiasLinNonLin,sr) + feature2=ISD_additive_noise(feature, args.P, args.g_sd) + + feature_para=feature1+feature2 + feature=normWav(feature_para,0) #normalized resultant waveform + + # original data without Rawboost processing + else: + + feature=feature + + return feature diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..528c01e --- /dev/null +++ b/environment.yml @@ -0,0 +1,169 @@ +name: occm +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - blas=1.0=mkl + - ca-certificates=2023.08.22=h06a4308_0 + - intel-openmp=2023.1.0=hdb19cb5_46305 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.4.4=h6a678d5_0 + - libgcc-ng=11.2.0=h1234567_1 + - libgfortran-ng=11.2.0=h00389a5_1 + - libgfortran5=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - mkl=2023.1.0=h213fc3f_46343 + - mkl-service=2.4.0=py39h5eee18b_1 + - mkl_fft=1.3.8=py39h5eee18b_0 + - mkl_random=1.2.4=py39hdb19cb5_0 + - ncurses=6.4=h6a678d5_0 + - openssl=3.0.11=h7f8727e_2 + - pip=23.2.1=py39h06a4308_0 + - python=3.9.16=h955ad1f_3 + - readline=8.2=h5eee18b_0 + - setuptools=68.0.0=py39h06a4308_0 + - sqlite=3.41.2=h5eee18b_0 + - tbb=2021.8.0=hdb19cb5_0 + - tk=8.6.12=h1ccaba5_0 + - wheel=0.41.2=py39h06a4308_0 + - xz=5.4.2=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - aiohttp==3.8.6 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.8 + - appdirs==1.4.4 + - asttokens==2.4.1 + - async-timeout==4.0.3 + - attrs==23.1.0 + - audioread==3.0.1 + - backcall==0.2.0 + - bitarray==2.8.2 + - certifi==2022.12.7 + - cffi==1.16.0 + - chardet==4.0.0 + - charset-normalizer==2.1.1 + - click==8.1.7 + - cloudpickle==3.0.0 + - cmake==3.25.0 + - colorama==0.4.6 + - coloredlogs==15.0.1 + - colorlog==6.7.0 + - contourpy==1.2.0 + - cycler==0.12.1 + - cython==3.0.4 + - decorator==5.1.1 + - demucs==4.0.0 + - denoiser==0.1.5 + - diffq==0.2.4 + - docker-pycreds==0.4.0 + - dora-search==0.1.12 + - einops==0.7.0 + - exceptiongroup==1.1.3 + - executing==2.0.1 + - fairseq==1.0.0a0+a540213 + - ffmpeg-python==0.2.0 + - filelock==3.9.0 + - flatbuffers==23.5.26 + - fonttools==4.44.0 + - frozenlist==1.4.0 + - fsspec==2023.10.0 + - future==0.18.3 + - gitdb==4.0.11 + - gitpython==3.1.40 + - humanfriendly==10.0 + - hydra-colorlog==1.2.0 + - hydra-core==1.0.7 + - idna==2.10 + - importlib-resources==6.1.1 + - iniconfig==2.0.0 + - ipython==8.10.0 + - jedi==0.19.1 + - jinja2==3.1.2 + - joblib==1.3.2 + - julius==0.2.7 + - kiwisolver==1.4.5 + - lameenc==1.6.3 + - lazy-loader==0.3 + - librosa==0.10.0.post2 
+ - lightning-utilities==0.9.0 + - lit==15.0.7 + - llvmlite==0.41.1 + - lxml==4.9.3 + - markupsafe==2.1.2 + - matplotlib==3.7.1 + - matplotlib-inline==0.1.6 + - mpmath==1.3.0 + - msgpack==1.0.7 + - multidict==6.0.4 + - networkx==3.0 + - numba==0.58.1 + - numpy==1.26.1 + - omegaconf==2.0.6 + - onnxruntime==1.16.1 + - openunmix==1.2.1 + - packaging==23.2 + - pandas==2.0.2 + - parso==0.8.3 + - pesq==0.0.4 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==9.3.0 + - platformdirs==3.11.0 + - pluggy==1.3.0 + - pooch==1.6.0 + - portalocker==2.8.2 + - prompt-toolkit==3.0.39 + - protobuf==4.25.0 + - psutil==5.9.6 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pycparser==2.21 + - pygments==2.16.1 + - pyparsing==3.1.1 + - pystoi==0.3.3 + - pytest==7.3.2 + - python-dateutil==2.8.2 + - pytorch-lightning==2.1.0 + - pytz==2023.3.post1 + - pywavelets==1.4.1 + - pyyaml==6.0.1 + - regex==2023.10.3 + - requests==2.25.1 + - retrying==1.3.4 + - sacrebleu==2.3.1 + - scikit-learn==1.3.1 + - scipy==1.11.3 + - seaborn==0.12.2 + - sentry-sdk==1.37.1 + - setproctitle==1.3.3 + - six==1.16.0 + - smmap==5.0.1 + - sounddevice==0.4.6 + - soundfile==0.12.1 + - soxr==0.3.7 + - stack-data==0.6.3 + - submitit==1.5.0 + - sympy==1.12 + - tabulate==0.9.0 + - threadpoolctl==3.2.0 + - tomli==2.0.1 + - torch==2.0.1+cu118 + - torchattacks==3.5.1 + - torchaudio==2.0.2+cu118 + - torchmetrics==1.2.0 + - torchvision==0.15.2+cu118 + - tqdm==4.65.0 + - traitlets==5.13.0 + - treetable==0.2.5 + - triton==2.0.0 + - typing-extensions==4.4.0 + - tzdata==2023.3 + - urllib3==1.26.13 + - wandb==0.16.0 + - wcwidth==0.2.9 + - yarl==1.9.2 + - zipp==3.17.0 +prefix: /home/longnv/.conda/envs/occm diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..667fd27 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,198 @@ +import argparse +from sklearn.metrics import confusion_matrix +import pandas as pd +import numpy as np +from evaluate_metrics import compute_eer + +def load_metadata(file_path): + """load the complete metadata file the label list + example: + LA_0043 DF_E_2000026 mp3m4a asvspoof A09 spoof notrim eval traditional_vocoder - - - - + + Args: + file_path (str): file path + """ + labels = [] + with open(file_path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + label = line.split(" ")[5] + labels.append(label) + return labels + +def load_metadata_from_proto(meta_file_path, proto_file_path): + """load the subset of metadata file the label list + based on the protocol file + the label list order is the same as the protocol file + example: + LA_0043 DF_E_2000026 mp3m4a asvspoof A09 spoof notrim eval traditional_vocoder - - - - + + Args: + file_path (str): file path + """ + labels = [] + protos = load_proto_file(proto_file_path) + # initialize labels with the same length as the protocol file + for i in range(len(protos)): + labels.append("") + with open(meta_file_path, "r") as f: + lines = f.readlines() + for line in lines: + line = line.strip() + file_name = line.split(" ")[1] + label = line.split(" ")[5] + if file_name in protos: + index = protos.index(file_name) + labels[index] = label + return labels + +def eval_dict(file_path): + """_summary_ + + Given a metadata file for DF eval + eval-package/keys/DF/CM/trial_metadata.txt + Example + LA_0043 DF_E_2000026 mp3m4a asvspoof A09 spoof notrim eval traditional_vocoder - - - - + + Create a dictionary with key = file name, value = label. Where key is the second column and value is the sixth column. 
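+    Returns:
+        dict: {file_name: label}, e.g. {"DF_E_2000026": "spoof"}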
+
+    """
+    eval_dict = {}
+    with open(file_path, "r") as f:
+        for line in f:
+            line = line.strip()
+            file_name = line.split(" ")[1]
+            label = line.split(" ")[5]
+            eval_dict[file_name] = label
+    return eval_dict
+
+def load_proto_file(file_path):
+    """load the protocol file, which contains file names only,
+    and return a list of file names
+
+    Args:
+        file_path (str): path to the protocol file
+    """
+    with open(file_path, "r") as f:
+        lines = f.readlines()
+    file_list = []
+    for line in lines:
+        line = line.strip()
+        file_list.append(line)
+    return file_list
+
+def load_score_file(file_path):
+    """load the score file and return the score list
+    example:
+    0.02207140438258648, 0
+    0.01588536612689495, 1
+
+    Args:
+        file_path (str): path to the score file
+    """
+    with open(file_path, "r") as f:
+        lines = f.readlines()
+    score_list = []
+    for line in lines:
+        line = line.strip()
+        score = float(line.split(",")[0])
+        score_list.append(score)
+    return score_list
+
+def calculate_EER_old(scores, protocol, metadata):
+    """
+    Steps:
+    - the protocol file contains file names only
+    - the score file contains scores in the same order as the protocol file
+    - metadata is a dictionary with key = file name, value = label
+    - split the scores by label and calculate the EER
+    """
+    spoof_scores = []
+    bonafide_scores = []
+    for file_name in protocol:
+        score = scores[protocol.index(file_name)]
+        label = metadata[file_name]
+        if label == "spoof":
+            spoof_scores.append(score)
+        else:
+            bonafide_scores.append(score)
+
+    spoof_scores = np.array(spoof_scores)
+    bonafide_scores = np.array(bonafide_scores)
+    eer, threshold = compute_eer(spoof_scores, bonafide_scores)
+    # eer, threshold = compute_eer(bonafide_scores, spoof_scores)
+    print(f"EER = {eer*100.0}, threshold = {threshold}")
+
+def calculate_EER(scores, labels):
+    """Calculate the EER directly from parallel score and label lists.
+    Both lists follow the protocol-file order, so no separate protocol
+    file is needed here.
+
+    Args:
+        scores (list): scores
+        labels (list): labels
+    """
+    spoof_scores = []
+    bonafide_scores = []
+    for score, label in zip(scores, labels):
+        if label == "spoof":
+            spoof_scores.append(score)
+        else:
+            bonafide_scores.append(score)
+    spoof_scores = np.array(spoof_scores)
+    bonafide_scores = np.array(bonafide_scores)
+    eer, threshold = compute_eer(spoof_scores, bonafide_scores)
+    print(f"EER = {eer*100.0}, threshold = {threshold}")
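+
+# A minimal sanity check on synthetic data (hypothetical scores and labels,
+# not tied to any protocol file): spoof trials scoring above bona fide ones
+# should give an EER near zero under this script's score convention.
+def _demo_calculate_EER():
+    scores = [0.9, 0.8, 0.1, 0.2]
+    labels = ["spoof", "spoof", "bonafide", "bonafide"]
+    calculate_EER(scores, labels)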
+
+if __name__=="__main__":
+
+    # args
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--score_file", type=str, default="score.txt")
+    parser.add_argument("--protocol_file", type=str, default="protocol.txt")
+    parser.add_argument("--metadata_file", type=str, default="metadata.txt")
+    parser.add_argument("--threshold", type=float, default=0.1)
+    args = parser.parse_args()
+
+    # load the protocol file, the score file, and the metadata file
+    proto = load_proto_file(args.protocol_file)
+    scores = load_score_file(args.score_file)
+    # metadata = eval_dict(args.metadata_file)
+    metadata = load_metadata(args.metadata_file)
+
+    # for each file in the protocol file, compare its score with the
+    # threshold: a score above the threshold is predicted spoof,
+    # bonafide otherwise
+    # labels = metadata
+    labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
+    predictions = []
+    for i, file_name in enumerate(proto):
+        score = scores[i]
+        if score > args.threshold:
+            # predictions.append("bonafide")  # use this for the 2-class case
+            predictions.append("spoof")
+        else:
+            predictions.append("bonafide")
+            # predictions.append("spoof")  # use this for the 2-class case
+
+    # number of bona fide and spoof trials in the labels
+    bona_fide = labels.count("bonafide")
+    spoof = labels.count("spoof")
+    print(f"bona fide = {bona_fide}")
+    print(f"spoof = {spoof}")
+
+    # calculate the confusion matrix; sklearn sorts labels alphabetically,
+    # so rows/columns are ordered [bonafide, spoof]; with spoof as the
+    # positive class, TP = cm[1][1] and TN = cm[0][0]
+    cm = confusion_matrix(labels, predictions)
+    print(cm)
+    # print TP, TN, FP, FN
+    print(f"TP = {cm[1][1]}")
+    print(f"TN = {cm[0][0]}")
+    print(f"FP = {cm[0][1]}")
+    print(f"FN = {cm[1][0]}")
+
+    calculate_EER(scores, labels)
diff --git a/losses/custom_loss.py b/losses/custom_loss.py
index da6aa09..c6d6961 100644
--- a/losses/custom_loss.py
+++ b/losses/custom_loss.py
@@ -2,51 +2,59 @@ import torch.nn.functional as F
 
 def compactness_loss(batch_embeddings):
-    """
-    Compactness loss.
-    Calculates the Mahalanobis distance between each pair of samples and sums them up.
-    Input:
-        embeddings: tensor of shape (8, 128), where 8 is the number of samples and 128 is the embedding dimension
-    Process:
-        only use the first 4 samples (0, 1, 2, 3) to calculate the Mahalanobis distance
-        they are bona1, bona2, bona3 and bona4
-        calculate the Mahalanobis distance between the following pairs
-            bona1 and bona2
-            bona1 and bona3
-            bona1 and bona4
-            bona3 and bona2
-            bona3 and bona4
-
-    Output:
-        loss: scalar tensor representing the Mahalanobis distance loss
-    """
-
-    # Example: given batch of embeddings (4 samples with 128 dimensions)
-    # batch_embeddings = torch.rand(4, 128)
-
-    # Define the pairs you want to calculate the Mahalanobis distance for
-    pairs = [(0, 1), (0, 2), (0, 3), (2, 1), (2, 3)]
+    """Calculate the Euclidean distance between each bona fide embedding and
+    the mean of the other bona fide embeddings in the batch, then average
+    these distances. The batch has a shape of `torch.Size([12, 128])`.
 
-    # Get the embeddings of the pairs
-    batch_embeddings = torch.stack([batch_embeddings[i] for i, j in pairs])
-    print(f"batch_embeddings = {batch_embeddings.shape}")
-    # Compute the sample mean
-    mean_embedding = torch.mean(batch_embeddings, dim=0)
+    Args:
+        batch_embeddings (tensor): embeddings of the bona fide batch.
+            Expected shape [batch_size, embedding_dim].
+    """
+    distances = []
+    # Only 6 bona fides in the batch, so iterate through them.
+    batch_embeddings = batch_embeddings[:6]
+    for i in range(len(batch_embeddings)):
+        bona_fide = batch_embeddings[i]
+        # Exclude the i-th embedding to calculate the mean of the others.
+        others_mean = torch.mean(torch.cat((batch_embeddings[:i], batch_embeddings[i+1:]), dim=0), dim=0)
+
+        # Expand dimensions to match for the pairwise_distance calculation.
+        bona_fide = bona_fide.unsqueeze(0)
+        others_mean = others_mean.unsqueeze(0)
+        # Calculate and store the distance.
+        distance = F.pairwise_distance(bona_fide, others_mean, p=2)
+        distances.append(distance)
+
+    # Convert the list of distances into a tensor and compute the mean.
+    return torch.mean(torch.cat(distances))
 
-    # Compute the sample covariance matrix
-    cov_matrix = torch.mm((batch_embeddings - mean_embedding).t(), batch_embeddings - mean_embedding) / (batch_embeddings.size(0) - 1)
-    # Calculate the Mahalanobis distance for each pair
-    mahalanobis_distances = []
-    for i, j in pairs:
-        diff = batch_embeddings[i] - batch_embeddings[j]
-        mahalanobis_distance = torch.mm(torch.mm(diff.unsqueeze(0), torch.inverse(cov_matrix)), diff.unsqueeze(1))
-        mahalanobis_distances.append(mahalanobis_distance)
+def triplet_loss(batch_embeddings, margin=9.0):
+    """Calculate the triplet loss using Euclidean distance.
+    Expects batch_embeddings to be ordered as [bona1, bona2, spoof1].
+    Uses a default margin of 9.0.
 
-    # Sum up the Mahalanobis distances
-    total_mahalanobis_distance = torch.sum(torch.stack(mahalanobis_distances))
+    Args:
+        batch_embeddings (tensor): embeddings of bona1, bona2, and spoof1.
+            Expected shape [3, embedding_dim].
+        margin (float, optional): margin by which the anchor-negative distance
+            should exceed the anchor-positive distance. Defaults to 9.0.
 
-    return total_mahalanobis_distance
+    Returns:
+        torch.Tensor: triplet loss.
+    """
+    # Calculate pairwise distances
+    bona2bona = F.pairwise_distance(batch_embeddings[0].unsqueeze(0),
+                                    batch_embeddings[1].unsqueeze(0), p=2)
+    bona2spoof = F.pairwise_distance(batch_embeddings[0].unsqueeze(0),
+                                     batch_embeddings[2].unsqueeze(0), p=2)
+
+    # Calculate the triplet loss with margin and clamp it to be non-negative
+    loss = F.relu(bona2bona - bona2spoof + margin)
+    # .mean()  # If multiple triplets, return the mean loss.
+    return loss
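+
+# Usage sketch (hypothetical tensors): for a batch ordered [bona1, bona2, spoof1],
+#     emb = torch.randn(3, 128)
+#     loss = triplet_loss(emb, margin=9.0)
+# the loss reaches zero once d(bona1, spoof1) exceeds d(bona1, bona2) by the margin.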
 
 def euclidean_distance_loss(batch_embeddings):
     # Initialize the loss to 0
@@ -78,15 +86,14 @@ def descriptiveness_loss(batch_embeddings, labels):
         labels: tensor of shape (8,), where each element is either 0 or 1 representing the label of the sample
     Process:
-        calculate the cross entropy loss of 8 samples with their corresponding labels
+        calculate the cross entropy loss of 12 samples with their corresponding labels
     Output:
-        loss: sum the cross entropy loss of all 8 samples
+        loss: average of the cross entropy losses over all 12 samples
     """
-    weight = 2
     # Calculate the cross entropy loss for each pair of samples
-    # loss = torch.sum(F.cross_entropy(batch_embeddings, labels, reduction='none'))
+    loss = torch.sum(F.cross_entropy(batch_embeddings, labels, reduction='none'))
     # calculate the cross entropy loss of the first sample
-    loss_1 = F.cross_entropy(batch_embeddings[0].unsqueeze(0), labels[0].unsqueeze(0), reduction='none')
-    loss_2 = F.cross_entropy(batch_embeddings[4].unsqueeze(0), labels[4].unsqueeze(0), reduction='none')
-    return (weight * loss_1 + loss_2)/3
+    # loss_1 = F.cross_entropy(batch_embeddings[0].unsqueeze(0), labels[0].unsqueeze(0), reduction='none')
+    # loss_2 = F.cross_entropy(batch_embeddings[4].unsqueeze(0), labels[4].unsqueeze(0), reduction='none')
+    return loss / len(batch_embeddings)
diff --git a/models/senet.py b/models/senet.py
index 3760dc9..ece92c2 100644
--- a/models/senet.py
+++ b/models/senet.py
@@ -136,9 +136,10 @@ def forward(self, x, eval=False):
         x = self.avgpool(x).view(x.size()[0], -1)
         # print(x.shape)
-        out = self.embedding(x)
+        com = self.embedding(x)
+        des = self.classifier(x)
         # out = self.classifier(out)
-        return out
+        return com, des
         # only use log_softmax if the loss function declared later is NLLLoss
         # otherwise, just return `out` and use CrossEntropyLoss as the loss function
diff --git a/models/sslassist.py b/models/sslassist.py
new file mode 100644
index
0000000..ff80121 --- /dev/null +++ b/models/sslassist.py @@ -0,0 +1,607 @@ +import random +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +import fairseq + + +___author__ = "Hemlata Tak" +__email__ = "tak@eurecom.fr" + +############################ +## FOR fine-tuned SSL MODEL +############################ + + +class SSLModel(nn.Module): + def __init__(self,device): + super(SSLModel, self).__init__() + + cp_path = '/datac/longnv/SSL_Anti-spoofing/pretrained/xlsr2_300m.pt' # Change the pre-trained XLSR model path. + model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([cp_path]) + self.model = model[0] + self.device=device + self.out_dim = 1024 + return + + def extract_feat(self, input_data): + + # put the model to GPU if it not there + # if next(self.model.parameters()).device != input_data.device \ + # or next(self.model.parameters()).dtype != input_data.dtype: + # self.model.to(input_data.device, dtype=input_data.dtype) + # self.model.train() + + + if True: + # input should be in shape (batch, length) + if input_data.ndim == 3: + input_tmp = input_data[:, :, 0] + else: + input_tmp = input_data + + # [batch, length, dim] + emb = self.model(input_tmp, mask=False, features_only=True)['x'] + return emb + + +#---------AASIST back-end------------------------# +''' Jee-weon Jung, Hee-Soo Heo, Hemlata Tak, Hye-jin Shim, Joon Son Chung, Bong-Jin Lee, Ha-Jin Yu and Nicholas Evans. + AASIST: Audio Anti-Spoofing Using Integrated Spectro-Temporal Graph Attention Networks. + In Proc. ICASSP 2022, pp: 6367--6371.''' + + +class GraphAttentionLayer(nn.Module): + def __init__(self, in_dim, out_dim, **kwargs): + super().__init__() + + # attention map + self.att_proj = nn.Linear(in_dim, out_dim) + self.att_weight = self._init_new_params(out_dim, 1) + + # project + self.proj_with_att = nn.Linear(in_dim, out_dim) + self.proj_without_att = nn.Linear(in_dim, out_dim) + + # batch norm + self.bn = nn.BatchNorm1d(out_dim) + + # dropout for inputs + self.input_drop = nn.Dropout(p=0.2) + + # activate + self.act = nn.SELU(inplace=True) + + # temperature + self.temp = 1. + if "temperature" in kwargs: + self.temp = kwargs["temperature"] + + def forward(self, x): + ''' + x :(#bs, #node, #dim) + ''' + # apply input dropout + x = self.input_drop(x) + + # derive attention map + att_map = self._derive_att_map(x) + + # projection + x = self._project(x, att_map) + + # apply batch norm + x = self._apply_BN(x) + x = self.act(x) + return x + + def _pairwise_mul_nodes(self, x): + ''' + Calculates pairwise multiplication of nodes. 
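+        (element-wise product over all node pairs, outer-product style)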
+ - for attention map + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, #dim) + ''' + + nb_nodes = x.size(1) + x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) + x_mirror = x.transpose(1, 2) + + return x * x_mirror + + def _derive_att_map(self, x): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = self._pairwise_mul_nodes(x) + # size: (#bs, #node, #node, #dim_out) + att_map = torch.tanh(self.att_proj(att_map)) + # size: (#bs, #node, #node, 1) + att_map = torch.matmul(att_map, self.att_weight) + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _project(self, x, att_map): + x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) + x2 = self.proj_without_att(x) + + return x1 + x2 + + def _apply_BN(self, x): + org_size = x.size() + x = x.view(-1, org_size[-1]) + x = self.bn(x) + x = x.view(org_size) + + return x + + def _init_new_params(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + +class HtrgGraphAttentionLayer(nn.Module): + def __init__(self, in_dim, out_dim, **kwargs): + super().__init__() + + self.proj_type1 = nn.Linear(in_dim, in_dim) + self.proj_type2 = nn.Linear(in_dim, in_dim) + + # attention map + self.att_proj = nn.Linear(in_dim, out_dim) + self.att_projM = nn.Linear(in_dim, out_dim) + + self.att_weight11 = self._init_new_params(out_dim, 1) + self.att_weight22 = self._init_new_params(out_dim, 1) + self.att_weight12 = self._init_new_params(out_dim, 1) + self.att_weightM = self._init_new_params(out_dim, 1) + + # project + self.proj_with_att = nn.Linear(in_dim, out_dim) + self.proj_without_att = nn.Linear(in_dim, out_dim) + + self.proj_with_attM = nn.Linear(in_dim, out_dim) + self.proj_without_attM = nn.Linear(in_dim, out_dim) + + # batch norm + self.bn = nn.BatchNorm1d(out_dim) + + # dropout for inputs + self.input_drop = nn.Dropout(p=0.2) + + # activate + self.act = nn.SELU(inplace=True) + + # temperature + self.temp = 1. + if "temperature" in kwargs: + self.temp = kwargs["temperature"] + + def forward(self, x1, x2, master=None): + ''' + x1 :(#bs, #node, #dim) + x2 :(#bs, #node, #dim) + ''' + #print('x1',x1.shape) + #print('x2',x2.shape) + num_type1 = x1.size(1) + num_type2 = x2.size(1) + #print('num_type1',num_type1) + #print('num_type2',num_type2) + x1 = self.proj_type1(x1) + #print('proj_type1',x1.shape) + x2 = self.proj_type2(x2) + #print('proj_type2',x2.shape) + x = torch.cat([x1, x2], dim=1) + #print('Concat x1 and x2',x.shape) + + if master is None: + master = torch.mean(x, dim=1, keepdim=True) + #print('master',master.shape) + # apply input dropout + x = self.input_drop(x) + + # derive attention map + att_map = self._derive_att_map(x, num_type1, num_type2) + #print('master',master.shape) + # directional edge for master node + master = self._update_master(x, master) + #print('master',master.shape) + # projection + x = self._project(x, att_map) + #print('proj x',x.shape) + # apply batch norm + x = self._apply_BN(x) + x = self.act(x) + + x1 = x.narrow(1, 0, num_type1) + #print('x1',x1.shape) + x2 = x.narrow(1, num_type1, num_type2) + #print('x2',x2.shape) + return x1, x2, master + + def _update_master(self, x, master): + + att_map = self._derive_att_map_master(x, master) + master = self._project_master(x, master, att_map) + + return master + + def _pairwise_mul_nodes(self, x): + ''' + Calculates pairwise multiplication of nodes. 
+ - for attention map + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, #dim) + ''' + + nb_nodes = x.size(1) + x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1) + x_mirror = x.transpose(1, 2) + + return x * x_mirror + + def _derive_att_map_master(self, x, master): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = x * master + att_map = torch.tanh(self.att_projM(att_map)) + + att_map = torch.matmul(att_map, self.att_weightM) + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _derive_att_map(self, x, num_type1, num_type2): + ''' + x :(#bs, #node, #dim) + out_shape :(#bs, #node, #node, 1) + ''' + att_map = self._pairwise_mul_nodes(x) + # size: (#bs, #node, #node, #dim_out) + att_map = torch.tanh(self.att_proj(att_map)) + # size: (#bs, #node, #node, 1) + + att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1) + + att_board[:, :num_type1, :num_type1, :] = torch.matmul( + att_map[:, :num_type1, :num_type1, :], self.att_weight11) + att_board[:, num_type1:, num_type1:, :] = torch.matmul( + att_map[:, num_type1:, num_type1:, :], self.att_weight22) + att_board[:, :num_type1, num_type1:, :] = torch.matmul( + att_map[:, :num_type1, num_type1:, :], self.att_weight12) + att_board[:, num_type1:, :num_type1, :] = torch.matmul( + att_map[:, num_type1:, :num_type1, :], self.att_weight12) + + att_map = att_board + + + + # apply temperature + att_map = att_map / self.temp + + att_map = F.softmax(att_map, dim=-2) + + return att_map + + def _project(self, x, att_map): + x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x)) + x2 = self.proj_without_att(x) + + return x1 + x2 + + def _project_master(self, x, master, att_map): + + x1 = self.proj_with_attM(torch.matmul( + att_map.squeeze(-1).unsqueeze(1), x)) + x2 = self.proj_without_attM(master) + + return x1 + x2 + + def _apply_BN(self, x): + org_size = x.size() + x = x.view(-1, org_size[-1]) + x = self.bn(x) + x = x.view(org_size) + + return x + + def _init_new_params(self, *size): + out = nn.Parameter(torch.FloatTensor(*size)) + nn.init.xavier_normal_(out) + return out + + +class GraphPool(nn.Module): + def __init__(self, k: float, in_dim: int, p: Union[float, int]): + super().__init__() + self.k = k + self.sigmoid = nn.Sigmoid() + self.proj = nn.Linear(in_dim, 1) + self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity() + self.in_dim = in_dim + + def forward(self, h): + Z = self.drop(h) + weights = self.proj(Z) + scores = self.sigmoid(weights) + new_h = self.top_k_graph(scores, h, self.k) + + return new_h + + def top_k_graph(self, scores, h, k): + """ + args + ===== + scores: attention-based weights (#bs, #node, 1) + h: graph data (#bs, #node, #dim) + k: ratio of remaining nodes, (float) + returns + ===== + h: graph pool applied data (#bs, #node', #dim) + """ + _, n_nodes, n_feat = h.size() + n_nodes = max(int(n_nodes * k), 1) + _, idx = torch.topk(scores, n_nodes, dim=1) + idx = idx.expand(-1, -1, n_feat) + + h = h * scores + h = torch.gather(h, 1, idx) + + return h + + + + +class Residual_block(nn.Module): + def __init__(self, nb_filts, first=False): + super().__init__() + self.first = first + + if not self.first: + self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0]) + self.conv1 = nn.Conv2d(in_channels=nb_filts[0], + out_channels=nb_filts[1], + kernel_size=(2, 3), + padding=(1, 1), + stride=1) + self.selu = nn.SELU(inplace=True) + + self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1]) + self.conv2 = 
nn.Conv2d(in_channels=nb_filts[1], + out_channels=nb_filts[1], + kernel_size=(2, 3), + padding=(0, 1), + stride=1) + + if nb_filts[0] != nb_filts[1]: + self.downsample = True + self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0], + out_channels=nb_filts[1], + padding=(0, 1), + kernel_size=(1, 3), + stride=1) + + else: + self.downsample = False + + + def forward(self, x): + identity = x + if not self.first: + out = self.bn1(x) + out = self.selu(out) + else: + out = x + + #print('out',out.shape) + out = self.conv1(x) + + #print('aft conv1 out',out.shape) + out = self.bn2(out) + out = self.selu(out) + # print('out',out.shape) + out = self.conv2(out) + #print('conv2 out',out.shape) + + if self.downsample: + identity = self.conv_downsample(identity) + + out += identity + #out = self.mp(out) + return out + + +class AModel(nn.Module): + def __init__(self, args, device): + super().__init__() + self.device = device + + # AASIST parameters + filts = [128, [1, 32], [32, 32], [32, 64], [64, 64]] + gat_dims = [64, 32] + pool_ratios = [0.5, 0.5, 0.5, 0.5] + temperatures = [2.0, 2.0, 100.0, 100.0] + + + #### + # create network wav2vec 2.0 + #### + self.ssl_model = SSLModel(self.device) + self.LL = nn.Linear(self.ssl_model.out_dim, 128) + + self.first_bn = nn.BatchNorm2d(num_features=1) + self.first_bn1 = nn.BatchNorm2d(num_features=64) + self.drop = nn.Dropout(0.5, inplace=True) + self.drop_way = nn.Dropout(0.2, inplace=True) + self.selu = nn.SELU(inplace=True) + + # RawNet2 encoder + self.encoder = nn.Sequential( + nn.Sequential(Residual_block(nb_filts=filts[1], first=True)), + nn.Sequential(Residual_block(nb_filts=filts[2])), + nn.Sequential(Residual_block(nb_filts=filts[3])), + nn.Sequential(Residual_block(nb_filts=filts[4])), + nn.Sequential(Residual_block(nb_filts=filts[4])), + nn.Sequential(Residual_block(nb_filts=filts[4]))) + + self.attention = nn.Sequential( + nn.Conv2d(64, 128, kernel_size=(1,1)), + nn.SELU(inplace=True), + nn.BatchNorm2d(128), + nn.Conv2d(128, 64, kernel_size=(1,1)), + + ) + # position encoding + self.pos_S = nn.Parameter(torch.randn(1, 42, filts[-1][-1])) + + self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) + self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0])) + + # Graph module + self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1], + gat_dims[0], + temperature=temperatures[0]) + self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1], + gat_dims[0], + temperature=temperatures[1]) + # HS-GAL layer + self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer( + gat_dims[0], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer( + gat_dims[1], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer( + gat_dims[0], gat_dims[1], temperature=temperatures[2]) + self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer( + gat_dims[1], gat_dims[1], temperature=temperatures[2]) + + # Graph pooling layers + self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3) + self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3) + self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + + self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3) + + self.out_layer = nn.Linear(5 * gat_dims[1], 2) + + def forward(self, x): + #-------pre-trained Wav2vec model fine tunning ------------------------## + x_ssl_feat = self.ssl_model.extract_feat(x.squeeze(-1)) + x = self.LL(x_ssl_feat) 
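+        # extract_feat returns 1024-dim frame features (self.ssl_model.out_dim);
+        # self.LL projects them down to 128 dims for the encoder below.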
#(bs,frame_number,feat_out_dim) + + # post-processing on front-end features + x = x.transpose(1, 2) #(bs,feat_out_dim,frame_number) + x = x.unsqueeze(dim=1) # add channel + x = F.max_pool2d(x, (3, 3)) + x = self.first_bn(x) + x = self.selu(x) + + # RawNet2-based encoder + x = self.encoder(x) + x = self.first_bn1(x) + x = self.selu(x) + + w = self.attention(x) + + #------------SA for spectral feature-------------# + w1 = F.softmax(w,dim=-1) + m = torch.sum(x * w1, dim=-1) + e_S = m.transpose(1, 2) + self.pos_S + + # graph module layer + gat_S = self.GAT_layer_S(e_S) + out_S = self.pool_S(gat_S) # (#bs, #node, #dim) + + #------------SA for temporal feature-------------# + w2 = F.softmax(w,dim=-2) + m1 = torch.sum(x * w2, dim=-2) + + e_T = m1.transpose(1, 2) + + # graph module layer + gat_T = self.GAT_layer_T(e_T) + out_T = self.pool_T(gat_T) + + # learnable master node + master1 = self.master1.expand(x.size(0), -1, -1) + master2 = self.master2.expand(x.size(0), -1, -1) + + # inference 1 + out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11( + out_T, out_S, master=self.master1) + + out_S1 = self.pool_hS1(out_S1) + out_T1 = self.pool_hT1(out_T1) + + out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12( + out_T1, out_S1, master=master1) + out_T1 = out_T1 + out_T_aug + out_S1 = out_S1 + out_S_aug + master1 = master1 + master_aug + + # inference 2 + out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21( + out_T, out_S, master=self.master2) + out_S2 = self.pool_hS2(out_S2) + out_T2 = self.pool_hT2(out_T2) + + out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22( + out_T2, out_S2, master=master2) + out_T2 = out_T2 + out_T_aug + out_S2 = out_S2 + out_S_aug + master2 = master2 + master_aug + + out_T1 = self.drop_way(out_T1) + out_T2 = self.drop_way(out_T2) + out_S1 = self.drop_way(out_S1) + out_S2 = self.drop_way(out_S2) + master1 = self.drop_way(master1) + master2 = self.drop_way(master2) + + out_T = torch.max(out_T1, out_T2) + out_S = torch.max(out_S1, out_S2) + master = torch.max(master1, master2) + + # Readout operation + T_max, _ = torch.max(torch.abs(out_T), dim=1) + T_avg = torch.mean(out_T, dim=1) + + S_max, _ = torch.max(torch.abs(out_S), dim=1) + S_avg = torch.mean(out_S, dim=1) + + emb = last_hidden = torch.cat( + [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1) + + last_hidden = self.drop(last_hidden) + output = self.out_layer(last_hidden) + + return emb, output + +if __name__ == '__main__': + import librosa + + model = AModel(None,"cuda").to("cuda") + audio_file = "/datac/longnv/audio_samples/ADD2023_T2_T_00000000.wav" + audio_data, _ = librosa.load(audio_file, sr=None) + emb, out = model(torch.Tensor(audio_data).unsqueeze(0).to("cuda")) + print(emb.shape) + print(out.shape) diff --git a/oc_classifier.py b/oc_classifier.py index cc97de6..67c0214 100644 --- a/oc_classifier.py +++ b/oc_classifier.py @@ -14,9 +14,11 @@ from models.lcnn import * from models.senet import * from models.xlsr import * +from models.sslassist import * from losses.custom_loss import compactness_loss, descriptiveness_loss, euclidean_distance_loss - +import warnings +warnings.filterwarnings("ignore", category=DeprecationWarning) # to be used with one-class classifier # input is now a raw audio file @@ -85,6 +87,9 @@ def __getitem__(self, idx): """ audio_file = self.file_list[idx] file_path = os.path.join(self.dataset_dir, audio_file + ".flac") + if not os.path.exists(file_path): + file_path = os.path.join(self.dataset_dir, audio_file + ".wav") + feature, _ = librosa.load(file_path, sr=None) 
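+        # sr=None keeps the file's native sampling rate (16 kHz for ASVspoof audio),
+        # so no resampling is performed per item.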
feature_tensors = torch.tensor(feature, dtype=torch.float32) @@ -115,6 +120,14 @@ def create_reference_embedding(extractor, encoder, dataloader, device): Returns: torch.Tensor: reference embedding """ + # check whether the reference embedding and threshold already exist + if os.path.exists("reference_embedding.pt") and os.path.exists("threshold.pt"): + print("Loading reference embedding and threshold...") + reference_embedding = torch.load("reference_embedding.pt") + threshold = torch.load("threshold.pt") + return reference_embedding, threshold + + print("Creating a reference embedding...") extractor.eval() encoder.eval() total_embeddings = [] @@ -126,7 +139,7 @@ def create_reference_embedding(extractor, encoder, dataloader, device): target = target.to(device) emb = extractor(data) emb = emb.unsqueeze(1) - emb = encoder(emb) + emb = encoder(emb)[0] total_embeddings.append(emb) # reference embedding is the mean of all embeddings @@ -138,10 +151,61 @@ def create_reference_embedding(extractor, encoder, dataloader, device): total_distances.append(distance) threshold = torch.max(torch.stack(total_distances)) + # save the reference embedding and threshold to a file + torch.save(reference_embedding, "reference_embedding.pt") + torch.save(threshold, "threshold.pt") return reference_embedding, threshold -def score_eval_set(extractor, encoder, dataloader, device, reference_embedding, threshold): - """Score the evaluation set and save the scores to a file +def create_reference_embedding2(model, dataloader, device): + """Create reference embeddings for one-class classifier SSL-AASIST + + Args: + model (nn.Module): pretrained models (e.g., XLSR, SE-ResNet34) + dataloader (DataLoader): dataloader for the dataset + + Returns: + torch.Tensor: reference embedding + """ + # check whether the reference embedding and threshold already exist + if os.path.exists("reference_embedding.pt") and os.path.exists("threshold.pt"): + print("Loading reference embedding and threshold...") + reference_embedding = torch.load("reference_embedding.pt") + threshold = torch.load("threshold.pt") + return reference_embedding, threshold + + print("Creating a reference embedding...") + model.eval() + total_embeddings = [] + total_distances = [] + + with torch.no_grad(): + for _, (data, target) in enumerate(dataloader): + data = data.to(device) + target = target.to(device) + emb, out = model(data) # torch.Size([1, 160]) + total_embeddings.append(emb) + + # reference embedding is the mean of all embeddings + reference_embedding = torch.mean(torch.stack(total_embeddings), dim=0) + + # threshold is the maximum Euclidean distance between the reference embedding and all embeddings + for emb in total_embeddings: + distance = F.pairwise_distance(reference_embedding, emb, p=2) + total_distances.append(distance) + with open("distances.txt", "a") as f: + f.write(f"{float(distance)}\n") + threshold = torch.max(torch.stack(total_distances)) + + # save the reference embedding and threshold to a file + torch.save(reference_embedding, "reference_embedding.pt") + torch.save(threshold, "threshold.pt") + return reference_embedding, threshold + + + +def score_eval_set_1c1(extractor, encoder, dataloader, device, reference_embedding, threshold): + """ONE-CLASS APPROACH: + Score the evaluation set and save the scores to a file These scores will be used to calculate the EER. 
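+    A trial is written out with flag 1 when its distance to the reference
+    embedding exceeds the threshold, and with flag 0 otherwise.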
Args: extractor, encoder (nn.Module): pretrained models (e.g., XLSR, SE-ResNet34) @@ -155,44 +219,107 @@ def score_eval_set(extractor, encoder, dataloader, device, reference_embedding, extractor.eval() encoder.eval() total_embeddings = [] - total_distances = [] - with torch.no_grad(): - for _, (data, target) in enumerate(dataloader): - data = data.to(device) - target = target.to(device) - emb = extractor(data) - emb = emb.unsqueeze(1) - emb = encoder(emb) - total_embeddings.append(emb) - # total_labels.append(target) + # calculate the distance between the reference embedding and all embeddings + # write the scores to a file + # each line contains a score and a label + print("Scoring the evaluation set...") + with open("scores.txt", "w") as f: + with torch.no_grad(): + for idx, (data, target) in enumerate(dataloader): + data = data.to(device) + target = target.to(device) + emb = extractor(data) + emb = emb.unsqueeze(1) + emb = encoder(emb)[0] + # total_labels.append(target) + print(f"Processing file counts: {idx} ...") + distance = F.pairwise_distance(reference_embedding, emb, p=2) + if float(distance) > threshold: + f.write(f"{float(distance)}, 1 \n") + else: + f.write(f"{float(distance)}, 0 \n") + +def score_eval_set_1c2(model, dataloader, device, reference_embedding, threshold): + """ ONE-CLASS APPROACH with SSL-AASIST + + """ + model.eval() + total_embeddings = [] # calculate the distance between the reference embedding and all embeddings # write the scores to a file # each line contains a score and a label + print("Scoring the evaluation set...") + with open("scores.txt", "w") as f: + with torch.no_grad(): + for idx, (data, target) in enumerate(dataloader): + data = data.to(device) + target = target.to(device) + emb, out = model(data) + print(f"Processing file counts: {idx} ...") + distance = F.pairwise_distance(reference_embedding, emb, p=2) + if float(distance) > threshold: + f.write(f"{float(distance)}, 1 \n") + else: + f.write(f"{float(distance)}, 0 \n") + + +def score_eval_set_2c1(extractor, encoder, dataloader, device): + """TWO-CLASS APPROACH: + Score the evaluation set and save the scores to a file + These scores will be used to calculate the EER. + + Args: + extractor, encoder (nn.Module): pretrained models (e.g., XLSR, SE-ResNet34) + dataloader (DataLoader): dataloader for the dataset + Returns: + float: scores saved to a file + """ + + extractor.eval() + encoder.eval() with open("scores.txt", "w") as f: - for emb in total_embeddings: - distance = F.pairwise_distance(reference_embedding, emb, p=2) - total_distances.append(distance) - if float(distance) > threshold: - f.write(f"{float(distance)}, 1 \n") - else: - f.write(f"{float(distance)}, 0 \n") - - # calculate the EER - # total_distances = torch.stack(total_distances) - # total_labels = torch.stack(total_labels) - # total_labels = total_labels.squeeze(1) - # total_labels = total_labels.cpu().numpy() - # total_distances = total_distances.squeeze(1) - # total_distances = total_distances.cpu().numpy() - # fpr, tpr, thresholds = roc_curve(total_labels, total_distances, pos_label=1) - # eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.) 
- - # return eer + with torch.no_grad(): + for idx, (data, target) in enumerate(dataloader): + data = data.to(device) + target = target.to(device) + emb = extractor(data) + emb = emb.unsqueeze(1) + emb = encoder(emb)[1][0][0] # descriptiveness component, batch, bona fide + print(f"Processing file counts: {idx} ...") + f.write(f"{float(emb)}\n") + +def score_eval_set_2c2(model, dataloader, device): + """TWO-CLASS APPROACH SSL-AASIST: + Score the evaluation set and save the scores to a file + These scores will be used to calculate the EER. + + Args: + model (nn.Module): pretrained models (e.g., ssl-aasist) + dataloader (Dataloader): dataloader for the dataset + """ + model.eval() + with open("scores.txt", "w") as f: + with torch.no_grad(): + for idx, (data, target) in enumerate(dataloader): + data = data.to(device) + target = target.to(device) + emb, out = model(data) + print(f"out = {out}") + out = out[0][0] + print(f"Processing file counts: {idx} ...") + f.write(f"{float(out)}\n") + if __name__== "__main__": parser = argparse.ArgumentParser(description='One-class classifier') + parser.add_argument('--pretrained-sslaasist', type=str, default="/datac/longnv/occm/aasist_vocoded_1.pt", + help='Path to the pretrained weights') + parser.add_argument('--pretrained-ssl', type=str, default="/datac/longnv/occm/ssl_triplet_1.pt", + help='Path to the pretrained weights') + parser.add_argument('--pretrained-senet', type=str, default="/datac/longnv/occm/senet34_triplet_1.pt", + help='Path to the pretrained weights') parser.add_argument('--protocol_file', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt", help='Path to the protocol file') parser.add_argument('--dataset_dir', type=str, default="/datab/Dataset/ASVspoof/LA/ASVspoof2019_LA_train/flac", @@ -203,29 +330,32 @@ def score_eval_set(extractor, encoder, dataloader, device, reference_embedding, help='Path to the dataset directory') args = parser.parse_args() - # initialize xlsr and lcnn models + # initialize models device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - ssl = SSLModel(device) - senet = se_resnet34().to(device) + aasist = AModel(None, device).to(device) + # ssl = SSLModel(device) + # senet = se_resnet34().to(device) # load pretrained weights - ssl.load_state_dict(torch.load("/datac/longnv/occm/ssl_4.pt")) - senet.load_state_dict(torch.load("/datac/longnv/occm/senet34_4.pt")) - senet = DataParallel(senet) - ssl = DataParallel(ssl) + aasist.load_state_dict(torch.load(args.pretrained_sslaasist)) + # ssl.load_state_dict(torch.load(args.pretrained_ssl)) + # senet.load_state_dict(torch.load(args.pretrained_senet)) + aasist = DataParallel(aasist) + # senet = DataParallel(senet) + # ssl = DataParallel(ssl) print("Pretrained weights loaded") # create a reference embedding & find a threshold - print("Creating a reference embedding...") train_dataset = ASVDataset(args.protocol_file, args.dataset_dir) train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=0) - reference_embedding, threshold = create_reference_embedding(ssl, senet, train_dataloader, device) - - print(f"reference_embedding.shape = {reference_embedding.shape}") - print(f"threshold = {threshold}") + # reference_embedding, threshold = create_reference_embedding(ssl, senet, train_dataloader, device) + reference_embedding, threshold = create_reference_embedding2(aasist, train_dataloader, device) # score the evaluation set - print("Scoring the evaluation set...") eval_dataset = 
ASVDataset(args.eval_protocol_file, args.eval_dataset_dir, eval=True)
     eval_dataloader = DataLoader(eval_dataset, batch_size=1, shuffle=False, num_workers=0)
-    score_eval_set(ssl, senet, eval_dataloader, device, reference_embedding, threshold)
\ No newline at end of file
+    # score_eval_set(ssl, senet, eval_dataloader, device, reference_embedding, threshold)
+    score_eval_set_1c2(aasist, eval_dataloader, device, reference_embedding, threshold)
+    # score_eval_set_2c2(aasist, eval_dataloader, device)
+
+    print(f"threshold = {threshold}")
diff --git a/oc_training.py b/oc_training.py
new file mode 100644
index 0000000..37562c4
--- /dev/null
+++ b/oc_training.py
@@ -0,0 +1,402 @@
+
+import os
+import wandb
+import argparse
+from collections import defaultdict
+import librosa
+import random
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.nn import DataParallel
+from torch.optim.lr_scheduler import StepLR
+from torch.utils.data import DataLoader, random_split
+from sklearn.utils.class_weight import compute_class_weight
+
+from models.lcnn import *
+from models.senet import *
+from models.xlsr import *
+from models.sslassist import *
+
+from losses.custom_loss import compactness_loss, descriptiveness_loss, euclidean_distance_loss
+
+
+import torch.nn.functional as F
+from torchattacks import PGD
+from torch.utils.data import Dataset
+from data_utils_SSL import process_Rawboost_feature
+
+
+class PFDataset(Dataset):
+    def __init__(self, protocol_file, dataset_dir):
+        """
+        Protocol file for LA train
+        /datab/Dataset/ASVspoof/LA/ASVspoof_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt
+        Example
+        LA_0079 LA_T_1138215 - - bonafide
+
+        Protocol files for DF eval
+        eval-package/keys/DF/CM/trial_metadata.txt
+        Example
+        LA_0043 DF_E_2000026 mp3m4a asvspoof A09 spoof notrim eval traditional_vocoder - - - -
+
+        Args:
+            protocol_file (str): path to the protocol file
+            dataset_dir (str): directory holding the raw audio files
+        """
+        self.protocol_file = protocol_file
+        self.file_list = []
+        self.label_list = []
+        self.dataset_dir = dataset_dir
+
+        # file_list holds the second column of the protocol file,
+        # label_list the fifth
+        with open(self.protocol_file, "r") as f:
+            lines = f.readlines()
+            for line in lines:
+                line = line.strip()
+                line = line.split(" ")
+                self.file_list.append(line[1])
+                self.label_list.append(line[4])
+
+        # Caching the indices of each label for quick access
+        self.spoof_indices = [i for i, label in enumerate(self.label_list) if label == 'spoof']
+        self.bonafide_indices = [i for i, label in enumerate(self.label_list) if label == 'bonafide']
+        self._length = len(self.bonafide_indices)
+        # self._denoiser = DeNoise()
+        self._vocoded_dir = "/datab/Dataset/ASVspoof/LA/ASVspoof2019_LA_vocoded"
+        # self.args = self._rawboost_args()
+
+    def _rawboost_args(self):
+        """Initialize the RawBoost parameters via argparse defaults
+        """
+        parser = argparse.ArgumentParser(description='ASVspoof2021 baseline system')
+        parser.add_argument('--algo', type=int, default=3,
+                            help='RawBoost algorithm descriptions.
0: No augmentation 1: LnL_convolutive_noise, 2: ISD_additive_noise, 3: SSI_additive_noise, 4: series algo (1+2+3), \ + 5: series algo (1+2), 6: series algo (1+3), 7: series algo(2+3), 8: parallel algo(1,2) .default=0]') + + # LnL_convolutive_noise parameters + parser.add_argument('--nBands', type=int, default=5, + help='number of notch filters.The higher the number of bands, the more aggresive the distortions is.[default=5]') + parser.add_argument('--minF', type=int, default=20, + help='minimum centre frequency [Hz] of notch filter.[default=20] ') + parser.add_argument('--maxF', type=int, default=8000, + help='maximum centre frequency [Hz] ( outputs[0] in case of AngleLoss - if args.model == "lcnn_net_asoftmax": - _, predicted = torch.max(outputs_lcnn[0].data, 1) - else: - _, predicted = torch.max(outputs_lcnn.data, 1) - total_train += labels.size(0) - correct_train += (predicted == labels).sum().item() - # Print statistics running_loss += loss.item() + running_closs += c_loss.item() + running_dloss += d_loss.item() if i % 100 == 99: - print(f"[{epoch + 1}, {i + 1}] Train Loss: {running_loss / (i+1):.3f}, \ - Train Acc: {(correct_train / total_train) * 100:.2f}") - # write the loss to a file - with open("loss.txt", "a") as f: - f.write(f"epoch = {epoch}-{i}, loss = {float(loss)}, c_loss = {float(c_loss)}, d_loss = {float(d_loss)} \n") + print(f"[{epoch + 1}, {i + 1}] Train Loss: {running_loss / (i+1):.3f}") + with open("loss.txt", "a") as f: + # write loss, running_closs, running_dloss to a file + f.write(f"epoch = {epoch + 1}, i = {i + 1}, loss = {running_loss / (i+1):.3f}, closs = {running_closs / (i+1):.3f}, dloss = {running_dloss / (i+1):.3f} \n") + wandb.log({"Epoch": epoch, "Train Loss": running_loss / (i+1), "Train Compactness Loss": running_closs / (i+1), "Train Descriptiveness Loss": running_dloss / (i+1)}) # save the models after each epoch print("Saving the models...") - torch.save(ssl.module.state_dict(), f"ssl_{epoch}.pt") - torch.save(senet34.module.state_dict(), f"senet34_{epoch}.pt") - torch.save(lcnn.module.state_dict(), f"lcnn_{epoch}.pt") - - - - # # Validation phase - # model.eval() # Set the model to evaluation mode - # correct_val = 0 - # total_val = 0 - # val_loss = 0.0 - - # correct_test = 0 - # total_test = 0 - # test_loss = 0.0 - - # print("Evaluating on test set...") - # with torch.no_grad(): - # for data in test_dataloader: - # inputs, labels = data[0].to(device), data[1].to(device) - # inputs = inputs.to(torch.float32) - # if not args.finetuned: - # inputs = inputs.unsqueeze(1) # Add channel dimension - # # Forward pass - # outputs = model(inputs) - - # # Calculate the loss - # loss = criterion(outputs, labels) - # test_loss += loss.item() - - # # Calculate test accuracy - # # outputs -> outputs[0] in case of AngleLoss - # if args.model == "lcnn_net_asoftmax": - # _, predicted = torch.max(outputs[0].data, 1) - # else: - # _, predicted = torch.max(outputs.data, 1) - # total_test += labels.size(0) - # correct_test += (predicted == labels).sum().item() - - # # Calculate average training loss and accuracy for the epoch - # avg_train_loss = running_loss / len(train_dataloader) - # avg_train_acc = (correct_train / total_train) * 100 - - # print("***********************************************") - # print(f"Train Loss: {avg_train_loss:.3f}, Train Acc: {avg_train_acc:.2f}") - # print(f"Test Loss: {test_loss / len(test_dataloader):.3f}, Test Acc: {(correct_test / total_test) * 100:.2f}") - # test_acc = (correct_test / total_test) * 100 - # # Save the best and the latest 
diff --git a/test_dataloader_v5.py b/test_dataloader_v5.py
deleted file mode 100644
index bf7c9e4..0000000
--- a/test_dataloader_v5.py
+++ /dev/null
@@ -1,210 +0,0 @@
-
-import os
-import argparse
-import pandas as pd
-from collections import defaultdict
-from sklearn.utils.class_weight import compute_class_weight
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.optim.lr_scheduler import StepLR
-from torch.nn import DataParallel
-from torch.utils.data import DataLoader, random_split
-
-
-from evaluate_metrics import compute_eer
-# from preprocess_data_dsp import PFDataset
-from preprocess_data_xlsr import PFDataset
-# from preprocess_data_xlsr_finetuned import PFDataset
-from models.lcnn import *
-from models.cnn import *
-from models.senet import *
-from utils import *
-
-# Evaluate only, calculate EER
-# Additional dimension is due to the label, see `preprocess_data_xlsr.py`
-
-# Arguments
-print("Arguments...")
-parser = argparse.ArgumentParser(description='Train a model on a dataset')
-parser.add_argument('--dataset_dir', type=str, default="./lfcc_train",
-                    help='Path to the dataset directory')
-parser.add_argument('--test_dataset_dir', type=str, default="./lfcc_test",
-                    help='Path to the test dataset directory')
-parser.add_argument('--extract_func', type=str, default="none",
-                    help='Name of the function to extract features from the dataset')
-parser.add_argument('--model', type=str, default="lcnn_net")
-
-# in case of eval
-parser.add_argument('--eval', action='store_true', default=False)
-parser.add_argument('--pretrained_model', type=str, default="model.pt")
-
-# in case of finetuned, dataset_dir is the raw audio file directory instead of the extracted feature directory
-parser.add_argument('--finetuned', action='store_true', default=False)
-parser.add_argument('--train_protocol_file', type=str, default="./database/protocols/PartialSpoof_LA_cm_protocols/PartialSpoof.LA.cm.train.trl.txt")
-parser.add_argument('--test_protocol_file', type=str, default="./database/protocols/PartialSpoof_LA_cm_protocols/PartialSpoof.LA.cm.dev.trl.txt")
-args = parser.parse_args()
-
-
-print("collate_fn...")
-# collate function
-def collate_fn(batch):
-    max_width = max(features.shape[1] for features, _, _ in batch)
-    max_height = max(features.shape[0] for features, _, _ in batch)
-    padded_batch_features = []
-    for features, _, _ in batch:
-        pad_width = max_width - features.shape[1]
-        pad_height = max_height - features.shape[0]
-        padded_features = F.pad(features, (0, pad_width, 0, pad_height), mode='constant', value=0)
-        padded_batch_features.append(padded_features)
-
-    labels = torch.tensor([label for _, label, _ in batch])
-    fnames = [fname for _, _, fname in batch]
-    padded_batch_features = torch.stack(padded_batch_features, dim=0)
-    return padded_batch_features, labels, fnames
-
-def collate_fn_total(batch):
-    """pad the time series 1D"""
-    max_width = max(features.shape[0] for features, _ in batch)
-    padded_batch_features = []
-    for features, _ in batch:
-        pad_width = max_width - features.shape[0]
-        padded_features = F.pad(features, (0, pad_width), mode='constant', value=0)
-        padded_batch_features.append(padded_features)
-
-    labels = torch.tensor([label for _, label in batch])
-
-    padded_batch_features = torch.stack(padded_batch_features, dim=0)
-    return padded_batch_features, labels
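Both deleted collate functions zero-pad every item up to the longest one in the batch: collate_fn pads 2-D feature maps on the right and bottom, collate_fn_total right-pads raw 1-D waveforms (F is presumably torch.nn.functional, pulled in through one of the star imports). A self-contained toy run of the 1-D case:

    import torch
    import torch.nn.functional as F

    # two "waveforms" of different lengths, with integer labels
    batch = [(torch.ones(3), 0), (torch.ones(5), 1)]
    max_len = max(f.shape[0] for f, _ in batch)                      # 5
    padded = [F.pad(f, (0, max_len - f.shape[0])) for f, _ in batch]
    x = torch.stack(padded)                                          # shape (2, 5)
    print(x)  # first row ends in two zeros, second row is untouched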
-
-
-# Load the dataset
-print("*************************************************")
-print(f"Dataset dir = {args.dataset_dir}")
-print(f"Test dataset dir = {args.test_dataset_dir}")
-print(f"extract_func = {args.extract_func}")
-print(f"model = {args.model}")
-print(f"finetuned = {args.finetuned}")
-print(f"train_protocol_file = {args.train_protocol_file}")
-print(f"test_protocol_file = {args.test_protocol_file}")
-print("*************************************************")
-
-# Define the collate function
-if args.finetuned:
-    train_dataset = PFDataset(args.train_protocol_file, dataset_dir=args.dataset_dir)
-    test_dataset = PFDataset(args.test_protocol_file, dataset_dir=args.test_dataset_dir)
-    collate_func = collate_fn_total
-else:
-    # train_dataset = PFDataset(dataset_dir=args.dataset_dir, extract_func=args.extract_func)
-    test_dataset = PFDataset(dataset_dir=args.test_dataset_dir, extract_func=args.extract_func, eval=args.eval)
-    collate_func = collate_fn
-
-# Create dataloaders for training and validation
-batch_size = 128
-
-print("Creating dataloaders...")
-# train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0,  # for wav2vec2
-#                               collate_fn=collate_func)
-test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0,
-                             collate_fn=collate_func)
-
-print("Instantiating model...")
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-if args.model == "lcnn_net":
-    model = lcnn_net(asoftmax=False).to(device)
-elif args.model == "lcnn_net_asoftmax":
-    model = lcnn_net(asoftmax=True).to(device)
-elif args.model == "cnn_net":
-    model = cnn_net().to(device)
-elif args.model == "cnn_net_with_attention":
-    model = cnn_net_with_attention().to(device)
-elif args.model == "se_resnet12":
-    model = se_resnet12().to(device)
-elif args.model == "se_resnet34":
-    model = se_resnet34().to(device)
-elif args.model == "total_cnn_net":
-    model = total_cnn_net(device).to(device)
-elif args.model == "total_resnet34":
-    model = total_resnet34(device).to(device)
-
-# Set the model to evaluation mode
-model.eval()
-# load model from the best model
-print("Loading model weights...")
-model.load_state_dict(torch.load(args.pretrained_model), strict=True)
-model = DataParallel(model)
-
-# Define the loss function and optimizer
-criterion = nn.CrossEntropyLoss()
-if args.model == "lcnn_net_asoftmax":
-    criterion = AngleLoss()
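One ordering detail in the deleted script is worth keeping in mind when reusing its checkpoints: weights are loaded before the DataParallel wrap, so the checkpoint keys carry no "module." prefix, and saving through model.module.state_dict() (as the training hunk earlier in this diff does) keeps them that way. A minimal sketch with a stand-in model:

    import torch
    from torch.nn import DataParallel

    net = torch.nn.Linear(4, 2)                     # stand-in for lcnn_net / se_resnet34
    torch.save(net.state_dict(), "model.pt")        # checkpoint with plain keys: "weight", "bias"
    net.load_state_dict(torch.load("model.pt"))     # load *before* wrapping
    net = DataParallel(net)                         # forward passes now go through the wrapper
    torch.save(net.module.state_dict(), "best.pt")  # .module strips the wrapper prefix again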
-# Validation phase
-
-correct_test = 0
-total_test = 0
-test_loss = 0.0
-for name, param in model.named_parameters():
-    if torch.isnan(param).any():
-        print(f"NaN values found in parameter: {name}")
-print("Evaluating on test set...")
-
-score_file = args.model + "_" + os.path.basename(args.test_dataset_dir).split("_")[1] + "_eval_scores.txt"
-with torch.no_grad():
-    for data in test_dataloader:
-        inputs, labels, fnames = data[0].to(device), data[1].to(device), data[2]
-        inputs = inputs.to(torch.float32)
-        if not args.finetuned:
-            inputs = inputs.unsqueeze(1)  # Add channel dimension
-        # Forward pass
-        outputs = model(inputs)
-
-        # Calculate the loss
-        loss = criterion(outputs, labels)
-        test_loss += loss.item()
-
-        # Calculate test accuracy
-        # outputs -> outputs[0] in case of AngleLoss
-        if args.model == "lcnn_net_asoftmax":
-            _, predicted = torch.max(outputs[0].data, 1)
-        else:
-            _, predicted = torch.max(outputs.data, 1)
-        total_test += labels.size(0)
-        correct_test += (predicted == labels).sum().item()
-
-        with open(score_file, "a") as f:
-            for i in range(len(fnames)):
-                if args.model == "lcnn_net_asoftmax":
-                    f.write(f"{fnames[i]} {outputs[0][i][0]}\n")
-                else:
-                    f.write(f"{fnames[i]} {outputs[i][0]}\n")
-
-
-print("***********************************************")
-print(f"Test Loss: {test_loss / len(test_dataloader):.3f}, Test Acc: {(correct_test / total_test) * 100:.2f}")
-print("***********************************************")
-
-
-
-def calculate_EER(score_file=score_file):
-    """
-    Steps:
-    - load protocol file
-    - load score file
-    - calculate EER
-    """
-    pro_columns = ["sid", "utt", "phy", "attack", "label"]
-    eval_protocol_file = pd.read_csv(args.test_protocol_file, sep=" ", header=None)
-    eval_protocol_file.columns = pro_columns
-
-    score_file = pd.read_csv(score_file, sep=" ", header=None)
-    score_file.columns = ["utt", "score"]
-
-    res = pd.merge(eval_protocol_file, score_file, on="utt")
-    spoof_scores = res[res["label"] == "spoof"]["score"].values
-    bonafide_scores = res[res["label"] == "bonafide"]["score"].values
-
-    eer, threshold = compute_eer(bonafide_scores, spoof_scores)
-    print(f"EER = {eer*100.0}, threshold = {threshold}")
-
-calculate_EER(score_file=score_file)
\ No newline at end of file
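For reference, compute_eer (imported from evaluate_metrics) returns the operating point where the false-acceptance and false-rejection rates cross, under the usual ASVspoof convention that higher scores indicate bonafide speech. A toy numpy approximation of that calculation, with made-up scores:

    import numpy as np

    def simple_eer(bonafide_scores, spoof_scores):
        # sweep every observed score as a candidate threshold
        thresholds = np.sort(np.concatenate([bonafide_scores, spoof_scores]))
        far = np.array([(spoof_scores >= t).mean() for t in thresholds])    # false accepts
        frr = np.array([(bonafide_scores < t).mean() for t in thresholds])  # false rejects
        idx = np.argmin(np.abs(far - frr))                                  # closest crossing
        return (far[idx] + frr[idx]) / 2, thresholds[idx]

    eer, thr = simple_eer(np.array([2.0, 1.5, 1.8]), np.array([0.2, 0.9, 1.6]))
    print(f"EER = {eer * 100:.1f}%, threshold = {thr}")  # EER = 33.3%, threshold = 1.6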