-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSTGA.py
60 lines (40 loc) · 1.58 KB
/
STGA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
import torch.nn as nn
try:
from TEM import TEM
from SGEM import SGEM
except:
from model.net.TEM import TEM
from model.net.SGEM import SGEM
class STGA(nn.Module):
    """Residual block that fuses a TEM branch and an SGEM branch.

    Data flow (see ``forward``): a per-frame 1x1 convolution, then TEM and
    SGEM applied to the same 5-D tensor, their element-wise sum, a second
    per-frame 1x1 convolution, and finally a residual add of the block input.

    Args:
        n_channels: Channel count ``c``; both 1x1 convolutions map
            ``n_channels -> n_channels`` and the input must have this many
            channels.
        h, w: Forwarded to ``SGEM``; presumably the expected spatial size of
            the input frames (TODO confirm against SGEM).
        in_frames: Forwarded to ``TEM``; presumably the expected number of
            frames ``t`` per clip (TODO confirm against TEM).
    """

    def __init__(self, n_channels, h, w, in_frames):
        super().__init__()
        self.n_channels = n_channels
        self.in_frames = in_frames
        # Attribute names below become state_dict keys; keep them unchanged so
        # existing checkpoints still load.
        self.TEM = TEM(self.n_channels, self.in_frames)
        self.SGEM = SGEM(self.n_channels, h, w)
        self.conv2d_1 = nn.Conv2d(
            self.n_channels, self.n_channels,
            kernel_size=1, stride=1, padding=0, groups=1, bias=False,
        )
        self.conv2d_2 = nn.Conv2d(
            self.n_channels, self.n_channels,
            kernel_size=1, stride=1, padding=0, groups=1, bias=False,
        )

    def forward(self, x):
        """Apply the block to a batch of clips.

        Args:
            x: Tensor of shape ``(n, t, c, h, w)`` with ``c == n_channels``.
                It does not need to be contiguous.

        Returns:
            Tensor of shape ``(n, t, c, h, w)`` equal to
            ``x + conv2d_2(TEM(y) + SGEM(y))``, where ``y = conv2d_1(x)``
            applied frame by frame.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"STGA expects a 5-D input (n, t, c, h, w), got shape {tuple(x.shape)}"
            )
        n, t, c, h, w = x.shape

        # Fold time into the batch axis so the 2-D 1x1 conv acts on every
        # frame independently. `reshape` (unlike `view`) also accepts
        # non-contiguous tensors, e.g. ones produced by transpose/permute.
        y = self.conv2d_1(x.reshape(n * t, c, h, w)).reshape(n, t, c, h, w)

        # Both branches see the same tensor. Their outputs are assumed to be
        # (n, t, c, h, w) so they can be summed and folded again
        # (TODO confirm against TEM/SGEM).
        fused = self.TEM(y) + self.SGEM(y)
        fused = self.conv2d_2(fused.reshape(n * t, c, h, w)).reshape(n, t, c, h, w)

        # Residual connection back to the block input.
        return fused + x
if __name__ == '__main__':
    # Smoke test on a dummy batch of clips shaped (n=3, t=3, c=16, h=224, w=224).
    # STGA needs n_channels, the spatial size (h, w) and the frame count, so
    # they are passed to match the dummy tensor.
    n_channels, height, width, frames = 16, 224, 224, 3
    a = STGA(n_channels, height, width, frames)
    data = torch.zeros(3, frames, n_channels, height, width)
    # Inference only: no autograd graph is needed for a shape check.
    with torch.no_grad():
        out = a(data)
    print(out.shape)