Skip to content

Commit aba4a24

Browse files
committed
Add files
1 parent 4096d9a commit aba4a24

File tree

15 files changed

+586
-0
lines changed

15 files changed

+586
-0
lines changed

MSDN.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
2+
import tensorflow as tf
3+
import torch
4+
import torch.nn as nn
5+
import torch.nn.functional as F
6+
import numpy as np
7+
import torchvision
8+
9+
10+
class MSDN(nn.Module):
    """Mutually Semantic Distillation Network for (generalized) zero-shot learning.

    Runs two attention streams over ResNet-101 region features and combines
    their class scores:
      * a Text->Image stream (attribute-guided spatial attention), and
      * an Image->Text stream (region-guided attribute attention).
    The final prediction is the w1/w2-weighted sum of the two streams' scores.
    """
    #####
    # einstein sum notation
    # b: Batch size \ f: dim feature=2048 \ v: dim w2v=300 \ r: number of region=196 \ k: number of classes
    # i: number of attribute=312
    #####
    def __init__(self, config, normalize_V = False, normalize_F = False, is_conservative = False,
                 prob_prune=0,uniform_att_1 = False,uniform_att_2 = False, is_conv = False,
                 is_bias = False,bias = 1,non_linear_act=False,
                 loss_type = 'CE',non_linear_emb = False,
                 is_sigmoid = False):
        """Build the MSDN model.

        Args:
            config: namespace providing dim_f, dim_v, num_class, num_attribute,
                w1 and w2 (stream-mixing weights).
            normalize_V: L2-normalize the attribute embedding matrix V.
            normalize_F: L2-normalize region features along the channel dim.
            is_conservative: stored flag; not read inside this class
                (presumably consumed by the training/eval code — verify).
            prob_prune: stored as a frozen parameter; not read in forward().
            uniform_att_1: use uniform (1/R) spatial attention instead of the
                learned dense attention when computing S_p.
            uniform_att_2: weight class scores by the uniform attribute
                attention A_b_p instead of the learned A_p.
            is_conv: enable an extra conv branch in forward().
                NOTE(review): forward() references self.conv1 / self.conv1_bn,
                which are never defined here — is_conv=True would raise
                AttributeError. Confirm this path is dead or defined elsewhere.
            is_bias: add a per-class margin (mask_bias * bias) to the scores.
            bias: NOTE(review): this argument is ignored; self.bias is always
                torch.tensor(1). Confirm intended.
            non_linear_act: apply ReLU to the attribute scores S_p.
            loss_type: stored loss identifier (default 'CE'); not read here.
            non_linear_emb: score classes through a small MLP instead of a
                plain sum over attributes.
            is_sigmoid: squash region-attribute scores S with a sigmoid.
        """
        super(MSDN, self).__init__()
        self.config = config
        self.dim_f = config.dim_f          # region feature dim (f)
        self.dim_v = config.dim_v          # word-vector dim (v)
        self.nclass = config.num_class     # number of classes (k)
        self.dim_att = config.num_attribute  # number of attributes (i)
        self.hidden = self.dim_att//2
        self.non_linear_act = non_linear_act
        self.loss_type = loss_type
        self.w1 = config.w1                # weight of the Text->Image stream
        self.w2 = config.w2                # weight of the Image->Text stream

        # att: class-attribute matrix [k, i], frozen — left uninitialized
        # here, so it is presumably filled from a checkpoint. TODO confirm.
        self.att = nn.Parameter(torch.empty(
            self.nclass, self.dim_att), requires_grad=False)
        # V: attribute word embeddings [i, v], trainable — also uninitialized
        # here; presumably loaded from a checkpoint as well.
        self.V = nn.Parameter(torch.empty(
            self.dim_att, self.dim_v), requires_grad=True)

        # Text->Image stream projections (all [v, f]):
        # W_1 scores attributes on regions, W_2 produces spatial attention,
        # W_3 produces attention over attributes.
        self.W_1 = nn.Parameter(nn.init.normal_(
            torch.empty(self.dim_v, self.dim_f)), requires_grad=True)
        self.W_2 = nn.Parameter(nn.init.zeros_(
            torch.empty(self.dim_v, self.dim_f)), requires_grad=True)
        self.W_3 = nn.Parameter(nn.init.zeros_(
            torch.empty(self.dim_v, self.dim_f)), requires_grad=True)

        # Image->Text stream projections. Note the mixed orientations:
        # W_1_1 and W_3_1 are [f, v] while W_2_1 is [v, f] (matches the
        # einsum subscripts used in forward()).
        self.W_1_1 = nn.Parameter(nn.init.zeros_(
            torch.empty(self.dim_f, self.dim_v)), requires_grad=True)
        self.W_2_1 = nn.Parameter(nn.init.zeros_(
            torch.empty(self.dim_v, self.dim_f)), requires_grad=True)
        self.W_3_1 = nn.Parameter(nn.init.zeros_(
            torch.empty(self.dim_f, self.dim_v)), requires_grad=True)

        self.normalize_V = normalize_V
        self.normalize_F = normalize_F
        self.is_conservative = is_conservative
        self.is_conv = is_conv
        self.is_bias = is_bias

        if is_bias:
            # Frozen scalar margin; mask_bias is expected to hold +/-1 per
            # class (seen vs unseen) — left empty here, presumably loaded
            # from a checkpoint. TODO confirm.
            self.bias = nn.Parameter(torch.tensor(1), requires_grad=False)
            self.mask_bias = nn.Parameter(torch.empty(
                1, self.nclass), requires_grad=False)

        self.prob_prune = nn.Parameter(torch.tensor(prob_prune),requires_grad = False)

        self.uniform_att_1 = uniform_att_1
        self.uniform_att_2 = uniform_att_2

        self.non_linear_emb = non_linear_emb
        if self.non_linear_emb:
            # Small MLP mapping per-class attribute-score vectors [i] -> 1.
            self.emb_func = torch.nn.Sequential(
                torch.nn.Linear(self.dim_att, self.dim_att//2),
                torch.nn.ReLU(),
                torch.nn.Linear(self.dim_att//2, 1),)
        self.is_sigmoid = is_sigmoid

        # backbone: ResNet-101 truncated before avgpool/fc, so it emits a
        # spatial feature map [B, 2048, W, H].
        resnet101 = torchvision.models.resnet101(pretrained=True)
        self.resnet101 = nn.Sequential(*list(resnet101.children())[:-2])


    def compute_V(self):
        """Return the attribute embedding matrix, L2-normalized if configured."""
        if self.normalize_V:
            V_n = F.normalize(self.V)
        else:
            V_n = self.V
        return V_n

    def get_global_feature(self, x):
        """Global-average-pool a feature map [N, C, W, H] down to [N, C]."""

        N, C, W, H = x.shape
        global_feat = F.avg_pool2d(x, kernel_size=(W, H))
        global_feat = global_feat.view(N, C)

        return global_feat


    def forward(self, imgs):
        """Score each image against all classes via both attention streams.

        Args:
            imgs: image batch accepted by the ResNet-101 backbone
                (assumed [B, 3, H, W] — TODO confirm preprocessing).

        Returns:
            dict with key 'embed': [B, k] class scores,
            w1 * (Text->Image scores) + w2 * (Image->Text scores).
        """

        Fs = self.resnet101(imgs)

        if self.is_conv:
            # NOTE(review): self.conv1 / self.conv1_bn are never defined in
            # __init__; this branch would raise AttributeError if taken.
            Fs = self.conv1(Fs)
            Fs = self.conv1_bn(Fs)
            Fs = F.relu(Fs)

        shape = Fs.shape

        # Global pooled feature, returned for visualization/supervision.
        visualf_ori = self.get_global_feature(Fs)


        Fs = Fs.reshape(shape[0],shape[1],shape[2]*shape[3]) # batch x 2048 x 49

        R = Fs.size(2) # 49
        B = Fs.size(0) # batch
        V_n = self.compute_V() # 312x300

        if self.normalize_F and not self.is_conv: # true
            Fs = F.normalize(Fs,dim = 1)


        ##########################Text-Image################################

        ## Compute attribute score on each image region
        S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1,Fs) # batchx312x49

        if self.is_sigmoid:
            S=torch.sigmoid(S)

        ## Ablation setting: uniform spatial attention (1/R per region) and
        ## uniform attribute attention (all ones).
        A_b = Fs.new_full((B,self.dim_att,R),1/R)
        A_b_p = self.att.new_full((B,self.dim_att),fill_value = 1)
        S_b_p = torch.einsum('bir,bir->bi',A_b,S)
        S_b_pp = torch.einsum('ki,bi,bi->bk',self.att,A_b_p,S_b_p)
        ##

        ## compute Dense Attention: softmax over regions per attribute
        A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2,Fs)
        A = F.softmax(A,dim = -1)

        # Attribute-specific visual features: attention-weighted region sum.
        F_p = torch.einsum('bir,bfr->bif',A,Fs)
        if self.uniform_att_1: # false
            S_p = torch.einsum('bir,bir->bi',A_b,S)
        else:
            S_p = torch.einsum('bir,bir->bi',A,S)

        if self.non_linear_act: # false
            S_p = F.relu(S_p)
        ##

        ## compute Attention over Attribute
        A_p = torch.einsum('iv,vf,bif->bi',V_n,self.W_3,F_p) #eq. 6
        A_p = torch.sigmoid(A_p)
        ##

        if self.uniform_att_2: # true
            S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_b_p,S_p)
        else:
            # S_pp = torch.einsum('ki,bi,bi->bik',self.att,A_p,S_p)
            S_pp = torch.einsum('ki,bi->bik',self.att,S_p)

        # Per-attribute scores (A_b_p is all ones, so this equals S_p here).
        S_attr = torch.einsum('bi,bi->bi',A_b_p,S_p)

        if self.non_linear_emb:
            S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik]
            S_pp = self.emb_func(S_pp) #[bk1] <== [bki]
            S_pp = S_pp[:,:,0] #[bk] <== [bk1]
        else:
            S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik]

        # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes
        if self.is_bias:
            self.vec_bias = self.mask_bias*self.bias
            S_pp = S_pp + self.vec_bias

        ## spatial attention supervision
        Pred_att = torch.einsum('iv,vf,bif->bi',V_n,self.W_1,F_p)
        package1 = {'S_pp':S_pp,'Pred_att':Pred_att,'S_p':S_p,'S_b_pp':S_b_pp,'A_p':A_p,'A':A,'S_attr':S_attr,'visualf_ori':visualf_ori,'visualf_a_v':F_p}

        ##########################Image-Text################################

        ## Compute attribute score on each image region

        S = torch.einsum('bfr,fv,iv->bri',Fs,self.W_1_1,V_n) # batchx49x312
        # S = torch.einsum('iv,vf,bfr->bir',V_n,self.W_1_1,Fs)
        if self.is_sigmoid:
            S=torch.sigmoid(S)



        ## compute Dense Attention — note the softmax here is over the
        ## attribute axis (dim 1), unlike the region-axis softmax above.
        A = torch.einsum('iv,vf,bfr->bir',V_n,self.W_2_1,Fs)
        A = F.softmax(A,dim = 1)

        # Region-specific semantic features: attention-weighted attribute sum.
        v_a = torch.einsum('bir,iv->brv',A,V_n)

        S_p = torch.einsum('bir,bri->bi',A,S)

        if self.non_linear_act: # false
            S_p = F.relu(S_p)



        S_pp = torch.einsum('ki,bi->bik',self.att,S_p)

        S_attr = 0#torch.einsum('bi,bi->bi',A_b_p,S_p)

        if self.non_linear_emb:
            S_pp = torch.transpose(S_pp,2,1) #[bki] <== [bik]
            S_pp = self.emb_func(S_pp) #[bk1] <== [bki]
            S_pp = S_pp[:,:,0] #[bk] <== [bk1]
        else:
            S_pp = torch.sum(S_pp,axis=1) #[bk] <== [bik]

        # augment prediction scores by adding a margin of 1 to unseen classes and -1 to seen classes
        if self.is_bias:
            self.vec_bias = self.mask_bias*self.bias
            S_pp = S_pp + self.vec_bias

        ## spatial attention supervision
        package2 = {'S_pp':S_pp,'visualf_v_a':v_a, 'S_p':S_p, 'A':A}

        # Mix the two streams' class scores with the configured weights.
        package = {'embed': self.w1 * package1['S_pp']+self.w2 * package2['S_pp']}

        return package

Test_AWA2.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
from MSDN import MSDN
from dataset import UNIDataloader
import argparse
import json
from utils import evaluation


def _build_model(cfg):
    """Instantiate an MSDN with the AWA2 evaluation hyper-parameters."""
    return MSDN(cfg, normalize_V=True, normalize_F=True, is_conservative=True,
                uniform_att_1=False, uniform_att_2=False,
                is_conv=False, is_bias=True).to(cfg.device)


def _load_weights(model, checkpoint_path):
    """Load a checkpoint into `model`, keeping only keys the model defines."""
    state = model.state_dict()
    saved = torch.load(checkpoint_path)
    state.update({k: v for k, v in saved.items() if k in state})
    model.load_state_dict(state)
    return model


# Read CLI args, then replace the namespace contents with the JSON config.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/test_AWA2.json')
config = parser.parse_args()
with open(config.config, 'r') as f:
    config.__dict__ = json.load(f)

dataloader = UNIDataloader(config)

# One model per protocol: generalized (GZSL) and conventional (CZSL) ZSL.
model_gzsl = _load_weights(_build_model(config), 'saved_model/AWA2_MSDN_GZSL.pth')
model_czsl = _load_weights(_build_model(config), 'saved_model/AWA2_MSDN_CZSL.pth')

evaluation(config.batch_size, config.device,
           dataloader, model_gzsl, model_czsl)

Test_CUB.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import torch
from MSDN import MSDN
from dataset import UNIDataloader
import argparse
import json
from utils import evaluation


def _build_model(cfg):
    """Instantiate an MSDN with the CUB evaluation hyper-parameters."""
    return MSDN(cfg, normalize_V=False, normalize_F=True, is_conservative=True,
                uniform_att_1=False, uniform_att_2=False,
                is_conv=False, is_bias=True).to(cfg.device)


def _load_weights(model, checkpoint_path):
    """Load a checkpoint into `model`, keeping only keys the model defines."""
    state = model.state_dict()
    saved = torch.load(checkpoint_path)
    state.update({k: v for k, v in saved.items() if k in state})
    model.load_state_dict(state)
    return model


# Read CLI args, then replace the namespace contents with the JSON config.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/test_CUB.json')
config = parser.parse_args()
with open(config.config, 'r') as f:
    config.__dict__ = json.load(f)

dataloader = UNIDataloader(config)

# One model per protocol: generalized (GZSL) and conventional (CZSL) ZSL.
model_gzsl = _load_weights(_build_model(config), 'saved_model/CUB_MSDN_GZSL.pth')
model_czsl = _load_weights(_build_model(config), 'saved_model/CUB_MSDN_CZSL.pth')

evaluation(config.batch_size, config.device,
           dataloader, model_gzsl, model_czsl)

Test_SUN.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
import torch
from MSDN import MSDN
from dataset import UNIDataloader
import argparse
import json
from utils import evaluation

# Read CLI args, then replace the namespace contents with the JSON config.
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='config/test_SUN.json')
config = parser.parse_args()
with open(config.config, 'r') as f:
    config.__dict__ = json.load(f)

dataloader = UNIDataloader(config)

# GZSL model: load checkpoint, keeping only keys the model defines.
model_gzsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True,
                  uniform_att_1=False, uniform_att_2=True,
                  is_conv=False, is_bias=True, non_linear_act=False).to(config.device)
model_dict = model_gzsl.state_dict()
saved_dict = torch.load('saved_model/SUN_MSDN_GZSL.pth')
saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict}
model_dict.update(saved_dict)
model_gzsl.load_state_dict(model_dict)

# CZSL model: same architecture, different checkpoint.
model_czsl = MSDN(config, normalize_V=False, normalize_F=True, is_conservative=True,
                  uniform_att_1=False, uniform_att_2=True,
                  is_conv=False, is_bias=True, non_linear_act=False).to(config.device)
model_dict = model_czsl.state_dict()
saved_dict = torch.load('saved_model/SUN_MSDN_CZSL.pth')
saved_dict = {k: v for k, v in saved_dict.items() if k in model_dict}
model_dict.update(saved_dict)
model_czsl.load_state_dict(model_dict)

# BUG FIX: the original passed model_gzsl twice, so the CZSL model that was
# just built and loaded above was never evaluated. Pass model_czsl as the
# CZSL argument, matching the AWA2/CUB test scripts.
evaluation(config.batch_size, config.device,
           dataloader, model_gzsl, model_czsl)

config/test_AWA2.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"dataset": "AWA2",
3+
"dataset_path": "./data/AWA2",
4+
"pkl_path": "./data/AWA2.pkl",
5+
"device": "cuda:0",
6+
"num_workers": 16,
7+
"batch_size": 50,
8+
"num_attribute": 85,
9+
"num_class": 50,
10+
"resnet_region": 196,
11+
"dim_f": 2048,
12+
"dim_v": 300,
13+
"img_size": 448,
14+
"w1": 1.0,
15+
"w2": 0.0
16+
}

config/test_CUB.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"dataset": "CUB",
3+
"dataset_path": "./data/CUB",
4+
"pkl_path": "./data/CUB.pkl",
5+
"device": "cuda:0",
6+
"num_workers": 16,
7+
"batch_size": 50,
8+
"num_attribute": 312,
9+
"num_class": 200,
10+
"resnet_region": 196,
11+
"dim_f": 2048,
12+
"dim_v": 300,
13+
"img_size": 448,
14+
"w1": 0.9,
15+
"w2": 0.1
16+
}

config/test_SUN.json

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"dataset": "SUN",
3+
"dataset_path": "./data/SUN",
4+
"pkl_path": "./data/SUN.pkl",
5+
"device": "cuda:0",
6+
"num_workers": 16,
7+
"batch_size": 50,
8+
"num_attribute": 102,
9+
"num_class": 717,
10+
"resnet_region": 196,
11+
"dim_f": 2048,
12+
"dim_v": 300,
13+
"img_size": 448,
14+
"w1": 0.7,
15+
"w2": 0.3
16+
}

data/AWA2

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../data/AWA2

data/AWA2.pkl

2.67 MB
Binary file not shown.

data/CUB

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
../../../data/CUB

0 commit comments

Comments
 (0)