-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsim_deviation.py
126 lines (100 loc) · 4.5 KB
/
sim_deviation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import datetime
import logging
import os
import pickle
from random import sample

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
# All runs append to one fixed log file. (The previous strftime() call had no
# format directives, so it always produced exactly this path.)
log_filename = "./log/deviation_sim.log"
# logging.basicConfig opens the file right away, so the directory must exist
# first; otherwise importing this module fails with FileNotFoundError.
os.makedirs(os.path.dirname(log_filename), exist_ok=True)
logging.basicConfig(
    format='%(asctime)s %(levelname)-8s %(message)s',
    datefmt='%m-%d %H:%M:%S',
    filename=log_filename,
    filemode='a',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
def SimValue(filename, label_tag, close_tag, K):
    """Report the mean cosine similarity between each pattern and K neighbours.

    Loads ``./corpus/{filename}.pkl``, and for every pattern entry picks up to
    ``K`` other entries, averages the cosine similarity to them, and finally
    averages those per-entry values over all entries that had at least one
    neighbour. The result is printed and written to the module logger.

    Args:
        filename: Base name (no extension) of the pickle under ``./corpus/``.
        label_tag: When truthy, only entries whose third field equals the
            current entry's third field are accepted as neighbours.
        close_tag: ``"topk"`` orders candidates by ascending Euclidean
            distance; ``"random"`` orders them by a random permutation.
        K: Maximum number of neighbours per entry.

    Raises:
        NotImplementedError: If ``close_tag`` is neither ``"topk"`` nor
            ``"random"`` (raised once the first entry is processed).

    Returns:
        None. Entries without any eligible neighbour are skipped and counted
        in the "No Sim" figure; if every entry is skipped the reported
        similarity is NaN.
    """
    pkl_filename = f"./corpus/{filename}.pkl"
    # SECURITY: pickle.load can execute arbitrary code; only feed this
    # function corpus files that come from a trusted source.
    with open(pkl_filename, 'rb') as fp:
        result_tuple = pickle.load(fp)
    # The pickle holds eight fields (sameList, muSameList, oriTrueDiffList,
    # muTrueDiffList, patTrueDiffList, patternDict, seed_pattern_idx_list,
    # mutationList); only patternDict is needed here. Unpacking all eight
    # keeps the implicit length check on the stored tuple.
    _, _, _, _, _, patternDict, _, _ = result_tuple

    # Each entry is a 4-tuple whose first field is the pattern tensor and
    # whose third field is the label compared below.
    pattern_data_list = [pattern_data for pattern_data, _, _, _ in patternDict]
    num_patterns = len(pattern_data_list)
    pattern_data_list_flat = torch.stack(pattern_data_list).view(num_patterns, -1)

    all_similarity = 0.0
    empty_cnt = 0
    for noise_idx, (data, _, label, _) in enumerate(tqdm(patternDict)):
        data_flat = data.view(1, -1)

        # Candidate ordering: nearest-first or a random permutation.
        if close_tag == "topk":
            distances = F.pairwise_distance(data_flat, pattern_data_list_flat, p=2)
            candidate_indices = torch.argsort(distances, descending=False).tolist()
        elif close_tag == "random":
            candidate_indices = sample(range(num_patterns), num_patterns)
        else:
            raise NotImplementedError(f"unsupported close_tag: {close_tag!r}")

        # Take the first K candidates, never the entry itself, optionally
        # restricted to entries carrying the same label.
        most_similar_indices = []
        for idx in candidate_indices:
            if len(most_similar_indices) == K:
                break
            if idx == noise_idx:
                continue
            if not label_tag or patternDict[idx][2] == label:
                most_similar_indices.append(idx)

        # Check for "no neighbour" BEFORE averaging, so such entries are
        # counted instead of triggering a division by zero.
        if not most_similar_indices:
            empty_cnt += 1
            continue

        similarity_sum = sum(
            F.cosine_similarity(data_flat, pattern_data_list_flat[idx], dim=1).item()
            for idx in most_similar_indices
        )
        all_similarity += similarity_sum / len(most_similar_indices)

    scored_cnt = len(patternDict) - empty_cnt
    # With no scored entry there is nothing to average; report NaN rather
    # than dividing by zero.
    all_similarity = all_similarity / scored_cnt if scored_cnt else float("nan")

    message = f"Similarity Analysis: {filename}, Label: {label_tag}, Close: {close_tag}, K: {K}, Sim: {all_similarity:.4f}, No Sim: {empty_cnt}"
    print(message)
    logger.info(message)
if __name__ == "__main__":
    # Table 6
    # Sweep every dataset / K / selection strategy / label-filter combination.
    # The loop nesting reproduces the order of the explicit call list:
    # dataset -> K -> close_tag -> label_tag (True before False).
    for corpus_name in ("ts_mnist", "ts_digits", "ts_bank"):
        for neighbour_count in (1, 3, 5):
            for strategy in ("topk", "random"):
                for same_label_only in (True, False):
                    SimValue(corpus_name, same_label_only, strategy, neighbour_count)