diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py index 9ea1b4afe318..c2d87fe3cdeb 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gemma2/gemma2_modules.py @@ -67,7 +67,7 @@ def __init__( attn_mask_type: AttnMaskType, attention_type: str, attention_dropout: float = None, - cp_comm_type: str = None, + **kwargs, ): super().__init__(config=config)