Commit a72c96b

fix resampler projection
1 parent 292afa1

File tree

1 file changed: +13 -12 lines changed


open_flamingo/src/helpers.py

Lines changed: 13 additions & 12 deletions
@@ -105,7 +105,7 @@ def __init__(
         """
         Perceiver module which takes in image features and outputs image tokens.
         Args:
-            dim (int): final dimension of the incoming image features
+            dim (int): dimension of the incoming image features
             dim_inner (int, optional): final dimension to project the incoming image features to;
                 also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
             depth (int, optional): number of layers. Defaults to 6.
@@ -124,17 +124,17 @@ def __init__(
         else:
             projection = None
             dim_inner = dim
-        super().__init__(dim_media=dim_inner, num_tokens_per_media=num_latents)
+        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
         self.projection = projection
-        self.latents = nn.Parameter(torch.randn(num_latents, dim_inner))
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
         # positional embeddings
         self.frame_embs = (
-            nn.Parameter(torch.randn(max_num_frames, dim_inner))
+            nn.Parameter(torch.randn(max_num_frames, dim))
             if exists(max_num_frames)
             else None
         )
         self.media_time_embs = (
-            nn.Parameter(torch.randn(max_num_media, 1, dim_inner))
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
             if exists(max_num_media)
             else None
         )
@@ -145,14 +145,14 @@ def __init__(
                 nn.ModuleList(
                     [
                         PerceiverAttention(
-                            dim=dim_inner, dim_head=dim_head, heads=heads
+                            dim=dim, dim_head=dim_head, heads=heads
                         ),
-                        FeedForward(dim=dim_inner, mult=ff_mult),
+                        FeedForward(dim=dim, mult=ff_mult),
                     ]
                 )
             )

-        self.norm = nn.LayerNorm(dim_inner)
+        self.norm = nn.LayerNorm(dim)

     def forward(self, x):
         """
@@ -164,9 +164,6 @@ def forward(self, x):
         """
         b, T, F, v = x.shape[:4]

-        if exists(self.projection):
-            x = self.projection(x)
-
         # frame and media time embeddings
         if exists(self.frame_embs):
             frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
@@ -182,7 +179,11 @@ def forward(self, x):
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
-        return self.norm(latents)
+
+        if exists(self.projection):
+            return self.projection(self.norm(latents))
+        else:
+            return self.norm(latents)

 class LinearPatchProjection(VisionTokenizer):
     """Linear projection from patch features to image tokens."""
