Skip to content

Commit

Permalink
calculate scores from eval set, small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Long Nguyen-Vu committed Nov 12, 2023
1 parent 13dcd99 commit 736143f
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions oc_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,18 +168,17 @@ def score_eval_set(extractor, encoder, dataloader, device, reference_embedding,
# total_labels.append(target)

# calculate the distance between the reference embedding and all embeddings
for emb in total_embeddings:
distance = F.pairwise_distance(reference_embedding, emb, p=2)
total_distances.append(distance)
# write the scores to a file
# each line contains a score and a label
with open("scores.txt", "w") as f:
for distance in total_distances:
for emb in total_embeddings:
distance = F.pairwise_distance(reference_embedding, emb, p=2)
total_distances.append(distance)
if float(distance) > threshold:
f.write(f"{float(distance)}, 1 \n")
else:
f.write(f"{float(distance)}, 0 \n")

# calculate the EER
# total_distances = torch.stack(total_distances)
# total_labels = torch.stack(total_labels)
Expand Down Expand Up @@ -210,8 +209,8 @@ def score_eval_set(extractor, encoder, dataloader, device, reference_embedding,
senet = se_resnet34().to(device)

# load pretrained weights
ssl.load_state_dict(torch.load("/datac/longnv/occm/ssl_1.pt"))
senet.load_state_dict(torch.load("/datac/longnv/occm/senet34_1.pt"))
ssl.load_state_dict(torch.load("/datac/longnv/occm/ssl_4.pt"))
senet.load_state_dict(torch.load("/datac/longnv/occm/senet34_4.pt"))
senet = DataParallel(senet)
ssl = DataParallel(ssl)
print("Pretrained weights loaded")
Expand Down

0 comments on commit 736143f

Please sign in to comment.