Skip to content

Commit 292afa1

Browse files
fix flamingo init
1 parent 52ca075 commit 292afa1

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

open_flamingo/src/vlm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@ def __init__(
393393
gradient_checkpointing=gradient_checkpointing,
394394
)
395395
self.lang_model.set_decoder_layers_attr_name(decoder_layers_attr_name)
396+
self.decoder_layers_attr_name = decoder_layers_attr_name
396397
self.lang_model.init_cross_attention_layers(
397398
lang_hidden_size=self.lang_hidden_dim,
398399
vis_hidden_size=self.vis_embedding_dim,
@@ -491,7 +492,7 @@ def lambda_fn(module: nn.Module):
491492
return True
492493
if isinstance(module, GatedCrossAttentionBlock):
493494
return True
494-
if isinstance(module, original_decoder_block_class):
495+
if isinstance(module, decoder_block_class):
495496
return True
496497

497498
return lambda_fn

0 commit comments

Comments
 (0)