-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrecord.py
65 lines (58 loc) · 2.53 KB
/
record.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import torch
from operator import truediv
def evaluate_accuracy(data_iter, net, loss, device):
    """Evaluate ``net`` on every batch of ``data_iter`` with gradients disabled.

    Args:
        data_iter: Iterable yielding ``(X, y)`` batches. Both tensors are
            moved to ``device`` before use.
        net: Model invoked as ``net(X)``. Predicted labels are taken as
            ``argmax(dim=1)`` of its output.
        loss: Callable invoked as ``loss(y_hat, y.long())``.
        device: Target passed to ``Tensor.to`` for inputs and labels.

    Returns:
        A two-element list ``[accuracy, loss_sum]``: the fraction of samples
        whose predicted label equals ``y``, and the loss values summed over
        all batches (the sum is *not* divided by the number of batches).

    Raises:
        ZeroDivisionError: If ``data_iter`` yields no samples.
    """
    acc_sum, n = 0.0, 0
    # Initialised once, before the loop, so the loss accumulates across all
    # batches instead of being reset at the start of every batch.
    test_l_sum = 0
    net.eval()
    try:
        with torch.no_grad():
            for X, y in data_iter:
                X = X.to(device)
                y = y.to(device)
                y_hat = net(X)
                test_l_sum += loss(y_hat, y.long())
                acc_sum += (y_hat.argmax(dim=1) == y).float().sum().cpu().item()
                n += y.shape[0]
    finally:
        # Hand the model back in training mode, even if evaluation failed.
        net.train()
    return [acc_sum / n, test_l_sum]
def aa_and_each_accuracy(confusion_matrix):
    """Return per-row accuracies of a confusion matrix and their mean.

    Each row's accuracy is its diagonal entry divided by the row sum
    (``axis=1``). NaN results (e.g. from a 0/0 division for an all-zero row)
    are replaced with 0 via ``np.nan_to_num``.

    NOTE(review): rows presumably correspond to true classes — confirm
    against the caller that builds the matrix.

    Returns:
        A tuple ``(each_acc, average_acc)``: the array of per-row accuracies
        and their arithmetic mean.
    """
    diagonal = np.diag(confusion_matrix)
    row_totals = np.sum(confusion_matrix, axis=1)
    per_row_acc = np.nan_to_num(diagonal / row_totals)
    return per_row_acc, np.mean(per_row_acc)
def record_output(oa_ae, aa_ae, kappa_ae, element_acc_ae, path):
    """Write a plain-text summary of per-iteration metrics to ``path``.

    The report lists the raw OA / AA / KAPPA sequences, their mean and
    standard deviation, the axis-0 mean and standard deviation of
    ``element_acc_ae``, and finally all mean values (element-wise ones
    followed by OA, AA, KAPPA) with and without their standard deviations.

    Args:
        oa_ae: Sequence of OA values, one per iteration.
        aa_ae: Sequence of AA values, one per iteration.
        kappa_ae: Sequence of KAPPA values, one per iteration.
        element_acc_ae: Array-like reduced along ``axis=0``; the reduced
            result must be iterable. Presumably one row of element-wise
            accuracies per iteration — confirm against the caller.
        path: Destination file; an existing file is overwritten.
    """
    oa_mean, oa_std = np.mean(oa_ae), np.std(oa_ae)
    aa_mean, aa_std = np.mean(aa_ae), np.std(aa_ae)
    kappa_mean, kappa_std = np.mean(kappa_ae), np.std(kappa_ae)
    element_mean = np.mean(element_acc_ae, axis=0)
    element_std = np.std(element_acc_ae, axis=0)

    all_means = list(element_mean) + [oa_mean, aa_mean, kappa_mean]
    all_stds = list(element_std) + [oa_std, aa_std, kappa_std]
    # Every pair, including the last, is followed by ", " — kept so the
    # report format stays identical for anything that parses it.
    paired = ''.join(
        f'{mean!s} ± {std!s}, ' for mean, std in zip(all_means, all_stds)
    )

    # ``!s`` forces str() so values render exactly as with str(...) + '...'.
    report = ''.join([
        f'OAs for each iteration are:{oa_ae!s}\n',
        f'AAs for each iteration are:{aa_ae!s}\n',
        f'KAPPAs for each iteration are:{kappa_ae!s}\n\n',
        f'mean_OA ± std_OA is: {oa_mean!s} ± {oa_std!s}\n',
        f'mean_AA ± std_AA is: {aa_mean!s} ± {aa_std!s}\n',
        f'mean_KAPPA ± std_KAPPA is: {kappa_mean!s} ± {kappa_std!s}\n\n',
        f'Mean of all elements in confusion matrix: {element_mean!s}\n',
        'Standard deviation of all elements in confusion matrix: '
        f'{element_std!s}\n\n',
        f'All values without std: {all_means!s}\n\n',
        f'All values with std: {paired}\n',
    ])

    # The report is fully built before the file is opened, so a failure in
    # the statistics above cannot leave a truncated file behind. UTF-8 is
    # explicit because the text contains the non-ASCII character "±".
    with open(path, 'w', encoding='utf-8') as f:
        f.write(report)