@@ -105,7 +105,7 @@ def __init__(
         """
         Perceiver module which takes in image features and outputs image tokens.
         Args:
-            dim (int): final dimension of the incoming image features
+            dim (int): dimension of the incoming image features
             dim_inner (int, optional): final dimension to project the incoming image features to;
                 also the final dimension of the outputted features. If None, no projection is used, and dim_inner = dim.
             depth (int, optional): number of layers. Defaults to 6.
@@ -124,17 +124,17 @@ def __init__(
         else:
             projection = None
             dim_inner = dim
-        super().__init__(dim_media=dim_inner, num_tokens_per_media=num_latents)
+        super().__init__(dim_media=dim, num_tokens_per_media=num_latents)
         self.projection = projection
-        self.latents = nn.Parameter(torch.randn(num_latents, dim_inner))
+        self.latents = nn.Parameter(torch.randn(num_latents, dim))
         # positional embeddings
         self.frame_embs = (
-            nn.Parameter(torch.randn(max_num_frames, dim_inner))
+            nn.Parameter(torch.randn(max_num_frames, dim))
             if exists(max_num_frames)
             else None
         )
         self.media_time_embs = (
-            nn.Parameter(torch.randn(max_num_media, 1, dim_inner))
+            nn.Parameter(torch.randn(max_num_media, 1, dim))
             if exists(max_num_media)
             else None
         )
@@ -145,14 +145,14 @@ def __init__(
                 nn.ModuleList(
                     [
                         PerceiverAttention(
-                            dim=dim_inner, dim_head=dim_head, heads=heads
+                            dim=dim, dim_head=dim_head, heads=heads
                         ),
-                        FeedForward(dim=dim_inner, mult=ff_mult),
+                        FeedForward(dim=dim, mult=ff_mult),
                     ]
                 )
             )

-        self.norm = nn.LayerNorm(dim_inner)
+        self.norm = nn.LayerNorm(dim)

     def forward(self, x):
         """
@@ -164,9 +164,6 @@ def forward(self, x):
         """
         b, T, F, v = x.shape[:4]

-        if exists(self.projection):
-            x = self.projection(x)
-
         # frame and media time embeddings
         if exists(self.frame_embs):
             frame_embs = repeat(self.frame_embs[:F], "F d -> b T F v d", b=b, T=T, v=v)
@@ -182,7 +179,11 @@ def forward(self, x):
         for attn, ff in self.layers:
             latents = attn(x, latents) + latents
             latents = ff(latents) + latents
-        return self.norm(latents)
+
+        if exists(self.projection):
+            return self.projection(self.norm(latents))
+        else:
+            return self.norm(latents)

 class LinearPatchProjection(VisionTokenizer):
     """Linear projection from patch features to image tokens."""