Skip to content

Commit

Permalink
freeze pretrained kor w2v
Browse files Browse the repository at this point in the history
  • Loading branch information
nguyenvulong committed Dec 26, 2023
1 parent f1b45f6 commit 23279a0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
12 changes: 6 additions & 6 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def load_metadata(file_path):
lines = f.readlines()
for line in lines:
line = line.strip()
label = line.split(" ")[5]
label = line.split(" ")[2]
labels.append(label)
return labels

Expand All @@ -40,8 +40,8 @@ def load_metadata_from_proto(meta_file_path, proto_file_path):
lines = f.readlines()
for line in lines:
line = line.strip()
file_name = line.split(" ")[1]
label = line.split(" ")[5]
file_name = line.split(" ")[0]
label = line.split(" ")[2]
if file_name in protos:
index = protos.index(file_name)
labels[index] = label
Expand Down Expand Up @@ -167,8 +167,8 @@ def calculate_EER(scores, labels):
# and bonafide otherwise

# create two lists: one for the labels and one for the predictions
# labels = metadata
labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
labels = metadata
# labels = load_metadata_from_proto(args.metadata_file, args.protocol_file)
predictions = []
for i, file_name in enumerate(proto):
score = scores[i]
Expand All @@ -194,5 +194,5 @@ def calculate_EER(scores, labels):
print(f"TN = {cm[1][1]}")
print(f"FP = {cm[0][1]}")
print(f"FN = {cm[1][0]}")

calculate_EER(scores, labels)
8 changes: 5 additions & 3 deletions oc_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,11 @@ def __init__(self, protocol_file, dataset_dir, eval=False):
for line in lines:
line = line.strip()
line = line.split(" ")
if line[4] == "bonafide":
self.file_list.append(line[1])
self.label_list.append(line[4]) # bonafide only
self.file_list.append(line[0])
self.label_list.append("bonafide")
# if line[4] == "bonafide":
# self.file_list.append(line[1])
# self.label_list.append(line[4]) # bonafide only

self._length = len(self.file_list)

Expand Down
8 changes: 7 additions & 1 deletion oc_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def collate_fn(self, batch):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

aasist = AModel(None, device).to(device)
aasist.train()

aasist.ssl_model.eval()
for param in aasist.ssl_model.parameters():
param.requires_grad = False
# ssl = SSLModel(device)
# senet34 = se_resnet34().to(device)
# lcnn = lcnn_net(asoftmax=False).to(device)
Expand Down Expand Up @@ -349,7 +354,8 @@ def collate_fn(self, batch):
print(f"Epoch {epoch + 1}\n-------------------------------")

# Training phase
aasist.train()

#aasist.train()
# ssl.eval()
# senet34.train()
# lcnn.train()
Expand Down

0 comments on commit 23279a0

Please sign in to comment.