Skip to content

Commit a9c3b5e

Browse files
authored
Add files via upload
1 parent d34cd7b commit a9c3b5e

File tree

4 files changed

+245
-0
lines changed

4 files changed

+245
-0
lines changed

Test_awa2.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
import torch
2+
import torch.optim as optim
3+
import torch.nn as nn
4+
import pandas as pd
5+
from core.MSDN import MSDN
6+
from core.AWA2DataLoader import AWA2DataLoader
7+
from core.helper_MSDN_AWA2 import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats
8+
import importlib
9+
import pdb
10+
import numpy as np
11+
12+
NFS_path = './'
13+
idx_GPU = 0
14+
device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu")
15+
dataloader = AWA2DataLoader(NFS_path,device)
16+
torch.backends.cudnn.benchmark = True
17+
18+
def get_lr(optimizer):
19+
lr = []
20+
for param_group in optimizer.param_groups:
21+
lr.append(param_group['lr'])
22+
return lr
23+
24+
seed = 214#214
25+
torch.manual_seed(seed)
26+
torch.cuda.manual_seed_all(seed)
27+
np.random.seed(seed)
28+
29+
batch_size = 50
30+
nepoches = 50
31+
niters = dataloader.ntrain * nepoches//batch_size
32+
dim_f = 2048
33+
dim_v = 300
34+
init_w2v_att = dataloader.w2v_att
35+
att = dataloader.att#dataloader.normalize_att#
36+
att[att<0] = 0
37+
normalize_att = dataloader.normalize_att
38+
#assert (att.min().item() == 0 and att.max().item() == 1)
39+
40+
trainable_w2v = True
41+
lambda_ = 0.12#0.1 ,0.12 for T-I in GZSL, 0.15 for T-I in CZSL, 0.13 for I-T,0.3 for baseline
42+
bias = 0
43+
prob_prune = 0
44+
uniform_att_1 = False
45+
uniform_att_2 = False
46+
47+
seenclass = dataloader.seenclasses
48+
unseenclass = dataloader.unseenclasses
49+
desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))
50+
report_interval = niters//nepoches#10000//batch_size#
51+
52+
model_gzsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
53+
seenclass,unseenclass,
54+
lambda_,
55+
trainable_w2v,normalize_V=True,normalize_F=True,is_conservative=True,
56+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
57+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
58+
is_bias=True)
59+
model_gzsl.to(device)
60+
model_gzsl.load_state_dict(torch.load('saved_model/AWA2_MSDN_GZSL.pth'))
61+
62+
model_czsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
63+
seenclass,unseenclass,
64+
lambda_,
65+
trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
66+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
67+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
68+
is_bias=True)
69+
model_czsl.to(device)
70+
model_czsl.load_state_dict(torch.load('saved_model/AWA2_MSDN_CZSL.pth'))
71+
72+
73+
74+
print('-'*30)
75+
acc_seen, acc_novel, H, _ = eval_zs_gzsl(dataloader,model_gzsl,device,bias_seen=-bias,bias_unseen=bias)
76+
_, _, _, acc_zs = eval_zs_gzsl(dataloader,model_czsl,device,bias_seen=-bias,bias_unseen=bias)
77+
78+
print('acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(acc_novel,acc_seen,H, acc_zs))# %%

Test_cub.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import torch.optim as optim
3+
import torch.nn as nn
4+
import pandas as pd
5+
from core.MSDN import MSDN
6+
from core.CUBDataLoader import CUBDataLoader
7+
from core.helper_MSDN_CUB import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats
8+
# from global_setting import NFS_path
9+
import importlib
10+
import pdb
11+
import numpy as np
12+
from PIL import Image
13+
import matplotlib.pyplot as plt
14+
import skimage
15+
from sklearn.manifold import TSNE
16+
from torchvision import transforms
17+
import torch.nn.functional as F
18+
19+
20+
21+
NFS_path = './'
22+
23+
idx_GPU = 0
24+
device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu")
25+
dataloader = CUBDataLoader(NFS_path,device,is_unsupervised_attr=False,is_balance=False)
26+
dataloader.augment_img_path()
27+
torch.backends.cudnn.benchmark = True
28+
29+
def get_lr(optimizer):
30+
lr = []
31+
for param_group in optimizer.param_groups:
32+
lr.append(param_group['lr'])
33+
return lr
34+
35+
seed = 214#215#
36+
torch.manual_seed(seed)
37+
torch.cuda.manual_seed_all(seed)
38+
np.random.seed(seed)
39+
40+
batch_size = 50
41+
nepoches = 70#22
42+
niters = dataloader.ntrain * nepoches//batch_size
43+
dim_f = 2048
44+
dim_v = 300
45+
init_w2v_att = dataloader.w2v_att
46+
att = dataloader.att
47+
normalize_att = dataloader.normalize_att
48+
49+
trainable_w2v = True
50+
lambda_ = 0.1#0.1
51+
bias = 0
52+
prob_prune = 0
53+
uniform_att_1 = False
54+
uniform_att_2 = False
55+
56+
seenclass = dataloader.seenclasses
57+
unseenclass = dataloader.unseenclasses
58+
desired_mass = 1
59+
report_interval = niters//nepoches
60+
61+
62+
model_gzsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
63+
seenclass,unseenclass,
64+
lambda_,
65+
trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
66+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
67+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
68+
is_bias=True)
69+
model_gzsl.to(device)
70+
model_gzsl.load_state_dict(torch.load('saved_model/CUB_MSDN_GZSL.pth'))
71+
72+
model_czsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
73+
seenclass,unseenclass,
74+
lambda_,
75+
trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
76+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
77+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
78+
is_bias=True)
79+
model_czsl.to(device)
80+
model_czsl.load_state_dict(torch.load('saved_model/CUB_MSDN_CZSL.pth'))
81+
82+
83+
84+
85+
print('-'*30)
86+
acc_seen, acc_novel, H, _ = eval_zs_gzsl(dataloader,model_gzsl,device,bias_seen=-bias,bias_unseen=bias)
87+
_, _, _, acc_zs = eval_zs_gzsl(dataloader,model_czsl,device,bias_seen=-bias,bias_unseen=bias)
88+
89+
print('acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(acc_novel,acc_seen,H, acc_zs))# %%

Test_sun.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import torch
2+
import torch.optim as optim
3+
import torch.nn as nn
4+
import pandas as pd
5+
from core.MSDN import MSDN
6+
from core.SUNDataLoader import SUNDataLoader
7+
from core.helper_MSDN_SUN import eval_zs_gzsl,visualize_attention#,get_attribute_attention_stats
8+
from global_setting import NFS_path
9+
import importlib
10+
import pdb
11+
import numpy as np
12+
13+
idx_GPU = 0
14+
device = torch.device("cuda:{}".format(idx_GPU) if torch.cuda.is_available() else "cpu")
15+
dataloader = SUNDataLoader(NFS_path,device,is_scale=False,is_balance = True)
16+
17+
torch.backends.cudnn.benchmark = True
18+
19+
seed = 214
20+
torch.manual_seed(seed)
21+
torch.cuda.manual_seed_all(seed)
22+
23+
print('Randomize seed {}'.format(seed))
24+
#%%
25+
batch_size = 50
26+
nepoches = 80
27+
niters = dataloader.ntrain * nepoches//batch_size
28+
dim_f = 2048
29+
dim_v = 300
30+
init_w2v_att = dataloader.w2v_att
31+
att = dataloader.att#dataloader.normalize_att#
32+
normalize_att = dataloader.normalize_att
33+
#assert (att.min().item() == 0 and att.max().item() == 1)
34+
35+
trainable_w2v = True
36+
lambda_ = 0.001
37+
bias = 0.
38+
prob_prune = 0
39+
uniform_att_1 = False
40+
uniform_att_2 = True
41+
42+
seenclass = dataloader.seenclasses
43+
unseenclass = dataloader.unseenclasses
44+
desired_mass = 1#unseenclass.size(0)/(seenclass.size(0)+unseenclass.size(0))
45+
report_interval = niters//nepoches
46+
#%%
47+
model_gzsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
48+
seenclass,unseenclass,
49+
lambda_,
50+
trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
51+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
52+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
53+
is_bias=True,non_linear_act=False)
54+
model_gzsl.to(device)
55+
model_gzsl.load_state_dict(torch.load('saved_model/SUN_MSDN_GZSL.pth'))
56+
57+
model_czsl = MSDN(dim_f,dim_v,init_w2v_att,att,normalize_att,
58+
seenclass,unseenclass,
59+
lambda_,
60+
trainable_w2v,normalize_V=False,normalize_F=True,is_conservative=True,
61+
uniform_att_1=uniform_att_1,uniform_att_2=uniform_att_2,
62+
prob_prune=prob_prune,desired_mass=desired_mass, is_conv=False,
63+
is_bias=True)
64+
model_czsl.to(device)
65+
model_czsl.load_state_dict(torch.load('saved_model/SUN_MSDN_CZSL.pth'))
66+
67+
68+
69+
70+
print('-'*30)
71+
acc_seen, acc_novel, H, _ = eval_zs_gzsl(dataloader,model_gzsl,device,bias_seen=-bias,bias_unseen=bias)
72+
_, _, _, acc_zs = eval_zs_gzsl(dataloader,model_czsl,device,bias_seen=-bias,bias_unseen=bias)
73+
74+
print('acc_unseen=%.3f, acc_seen=%.3f, H=%.3f, acc_zs=%.3f'%(acc_novel,acc_seen,H, acc_zs))# %%

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
numpy==1.20.2
2+
torchvision==0.9.0
3+
torch==1.8.0
4+
Pillow==8.3.2

0 commit comments

Comments
 (0)