-
Notifications
You must be signed in to change notification settings - Fork 0
/
Result visualization.py
96 lines (86 loc) · 3.11 KB
/
Result visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
import subprocess
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns # I love this package!
sns.set_style('white')
import torch
# load check point
# loss check
model_path = 'resnet152_lrchange.pth.tar'
# map_location='cpu' lets the checkpoint load even when it was saved on a
# GPU machine and this script runs on a CPU-only one (otherwise torch.load
# raises when trying to restore CUDA tensors).
checkpoint = torch.load(model_path, map_location='cpu')
loss_history_train = checkpoint['loss_history_train']
loss_history_val = checkpoint['loss_history_val']
# Each history entry is a list of per-batch losses for one epoch; collapse
# to one mean loss per epoch for plotting.
loss_train = [np.mean(l) for l in loss_history_train]
loss_val = [np.mean(l) for l in loss_history_val]
plt.plot(loss_train, label='Train Loss')
plt.plot(loss_val, label='Val Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Loss Trend')
plt.legend()
plt.show()
# model performance
'''
model_path = 'model_best.pth.tar'
# calculate outputs for the test data with our best model
output_csv_path = 'pred.csv'
command = ('python Pred.py '
'--img_dir test '
'--output_csvpath {csv_path} '
'--model {model} --batch_size 4 --cuda'
.format(csv_path=output_csv_path, model=model_path))
'''
# !{command}
# load prediction scores produced by Pred.py (one column per attribute)
df_pred = pd.read_csv("resnet152_lrchange.csv")
# load ground-truth annotations (tab-separated; '-' marks missing labels)
test_label_path = 'annot_test.txt'
df_target = pd.read_csv(test_label_path, delimiter='\t')
#binary variables
from sklearn.metrics import accuracy_score, roc_auc_score, roc_curve
def plot_roc(attr, target, pred):
    """Plot the ROC curve for one binary attribute.

    Parameters
    ----------
    attr : str
        Attribute name; used (title-cased) in the legend and title.
    target : array-like of {0, 1}
        Ground-truth binary labels.
    pred : array-like of float
        Predicted scores/probabilities in [0, 1].

    Returns
    -------
    matplotlib.figure.Figure
        The created figure, so callers can save it afterwards.
    """
    fig, ax = plt.subplots()
    auc = roc_auc_score(target, pred)
    # Threshold the scores at 0.5 to get hard labels for the accuracy metric.
    acc = accuracy_score(target, (pred >= 0.5).astype(int))
    fpr, tpr, _ = roc_curve(target, pred)
    # Draw on the axes we just created rather than the implicit "current"
    # axes (plt.*), so the function stays correct when other figures exist.
    ax.plot(fpr, tpr, lw=2, label=attr.title())
    ax.legend(loc=4, fontsize=15)
    ax.set_title('ROC Curve for {attr} (Accuracy = {acc:.3f}, AUC = {auc:.3f})'
                 .format(attr=attr.title(), acc=acc, auc=auc),
                 fontsize=15)
    ax.set_xlabel('False Positive Rate', fontsize=15)
    ax.set_ylabel('True Positive Rate', fontsize=15)
    plt.show()
    return fig
# ROC curve for the binary "protest" label
attr = "protest"
fig = plot_roc(attr, df_target[attr], df_pred[attr])
#fig.savefig(os.path.join('files', attr+'.png'))
# ROC curves for the remaining visual attributes (columns after the first
# three metadata columns). Rows annotated '-' carry no label for that
# attribute, so they are dropped before scoring.
for attr in df_pred.columns[3:]:
    target = df_target[attr]
    labeled = target != '-'
    pred = df_pred[attr][labeled]
    target = target[labeled].astype(int)
    fig = plot_roc(attr, target, pred)
    #fig.savefig(os.path.join('files', attr+'.png'))
import scipy.stats as stats
# Violence is annotated as a continuous value; evaluate it only on images
# labeled as protest, via a scatter plot plus the Pearson correlation.
attr = 'violence'
# compute the protest mask once and reuse it for both frames
is_protest = df_target['protest'] == 1
pred = df_pred[is_protest][attr].tolist()
target = df_target[is_protest][attr].astype(float).tolist()
fig, ax = plt.subplots()
# Use the explicit axes API so the plot lands on this figure even when
# other figures are open.
ax.scatter(target, pred, label=attr.title())
ax.set_xlim([-.05, 1.05])
ax.set_ylim([-.05, 1.05])
ax.set_xlabel('Annotation', fontsize=15)
# fixed label typo: 'Predicton' -> 'Prediction'
ax.set_ylabel('Prediction', fontsize=15)
corr, pval = stats.pearsonr(target, pred)
ax.set_title('Scatter Plot for {attr} (Correlation = {corr:.3f})'
             .format(attr=attr.title(), corr=corr), fontsize=15)
plt.show()
#fig.savefig(os.path.join('files', attr+'.png'))