
Commit 0ca25e0

Convert t5 to use config object.
1 parent 02fffd2 · commit 0ca25e0
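The change follows the config-object pattern used across the Megatron model classes: instead of threading individual hyperparameters and initializer callables through each constructor, the caller builds one config object and passes it down. A minimal sketch of the idea follows; the dataclass and class names are illustrative stand-ins, not the real Megatron types, and only the fields hidden_size, layernorm_epsilon, init_method and output_layer_init_method are taken from this diff.

    from dataclasses import dataclass
    from typing import Callable

    @dataclass
    class TransformerConfig:                # stand-in for Megatron's real config object
        hidden_size: int
        layernorm_epsilon: float
        init_method: Callable               # replaces a separate init_method argument
        output_layer_init_method: Callable  # replaces a separate scaled_init_method argument

    class ExampleBlock:
        # Before: __init__(self, init_method, output_layer_init_method, layer_number, ...)
        # After:  every hyperparameter travels inside a single config argument.
        def __init__(self, config: TransformerConfig, layer_number: int):
            self.config = config
            self.layer_number = layer_number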

3 files changed: 11 additions & 19 deletions


megatron/model/t5_model.py

Lines changed: 5 additions & 15 deletions
@@ -11,9 +11,7 @@
 from megatron.model import LayerNorm
 from megatron.model.utils import (
     openai_gelu,
-    get_linear_layer,
-    init_method_normal,
-    scaled_init_method_normal
+    get_linear_layer
 )
 from .module import MegatronModule

@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):

     Arguments:
         mpu_vocab_size: model parallel size of vocabulary.
-        hidden_size: hidden size
-        init_method: init method for weight initialization
-        layernorm_epsilon: tolerance for layer norm divisions
         parallel_output: wether output logits being distributed or not.
     """

     def __init__(self, mpu_vocab_size, parallel_output):
         super(T5LMHead, self).__init__()

-        args = get_args()
-
         self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
         self.bias.model_parallel = True
         self.bias.partition_dim = 0
@@ -72,37 +65,34 @@ class T5Model(MegatronModule):
     """T5 Language model."""

     def __init__(self,
+                 config,
                  num_tokentypes=0,
                  parallel_output=True,
                  pre_process=True,
                  post_process=True,
                  add_encoder=True,
                  add_decoder=True):
-        super(T5Model, self).__init__()
+        super().__init__(config=config)
         args = get_args()

         self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.parallel_output = parallel_output
-        init_method = init_method_normal(args.init_method_std)
-        scaled_init_method = scaled_init_method_normal(args.init_method_std,
-                                                       args.num_layers)
         self.pre_process = pre_process
         self.post_process = post_process
         self.add_encoder = add_encoder
         self.add_decoder = add_decoder

         self.language_model, self._language_model_key = get_language_model(
+            config=config,
             num_tokentypes=num_tokentypes,
             add_pooler=False,
             add_encoder=add_encoder,
             add_decoder=add_decoder,
             encoder_attn_mask_type=AttnMaskType.padding,
-            init_method=init_method,
-            scaled_init_method=scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process)

-        self.initialize_word_embeddings(init_method_normal)
+        self.initialize_word_embeddings()

         if self.post_process and self.add_decoder:
             self.lm_head = T5LMHead(
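The dropped calls to init_method_normal and scaled_init_method_normal do not simply disappear: the transformer.py hunk below reads them back as config.init_method and config.output_layer_init_method, so the config object presumably builds them once from the same arguments. A hedged sketch of that derivation is shown here; the helper name build_config_from_args is hypothetical, and the internals of the real core_transformer_config_from_args are an assumption.

    from megatron.model.utils import init_method_normal, scaled_init_method_normal

    def build_config_from_args(args):
        # The initializer callables are created once, from the same two args the
        # old T5Model constructor used, and stored on the config. A plain dict is
        # returned here for brevity; the real helper returns a config object.
        return {
            'init_method': init_method_normal(args.init_method_std),
            'output_layer_init_method': scaled_init_method_normal(args.init_method_std,
                                                                   args.num_layers),
        }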

megatron/model/transformer.py

Lines changed: 2 additions & 3 deletions
@@ -747,15 +747,14 @@ def __init__(self, config,
                                LayerType.retro_decoder_with_retriever,
                                LayerType.retro_encoder):
             self.inter_attention = ParallelAttention(
-                config.init_method,
-                config.output_layer_init_method,
+                config,
                 layer_number,
                 attention_type=AttnType.cross_attn)
             # Layernorm on the attention output.
             self.post_inter_attention_layernorm = LayerNorm(
                 config.hidden_size,
                 eps=config.layernorm_epsilon,
-                no_persist_layer_norm=config.no_persist_layer_norm,
+                no_persist_layer_norm=not config.persist_layer_norm,
                 sequence_parallel=config.sequence_parallel,
                 apply_layernorm_1p=args.apply_layernorm_1p)

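Note the polarity flip in the last changed line: the argparse namespace exposes the disable flag no_persist_layer_norm, while the config apparently stores the positive persist_layer_norm, so the layer negates it when rebuilding the LayerNorm keyword. A small sketch of the translation this implies; how the config builder actually sets the field is an assumption.

    def persist_layer_norm_from_args(args):
        # Command-line flag --no-persist-layer-norm  ->  config.persist_layer_norm
        return not args.no_persist_layer_norm

    # Inside the layer the original keyword is recovered by negating it again:
    #     no_persist_layer_norm=not config.persist_layer_norm
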
pretrain_t5.py

Lines changed: 4 additions & 1 deletion
@@ -17,6 +17,7 @@
 from megatron.model import T5Model
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
+from megatron.arguments import core_transformer_config_from_args


 """
@@ -60,7 +61,9 @@ def model_provider(pre_process=True, post_process=True,
     """Build the model."""

     print_rank_0('building T5 model ...')
-    model = T5Model(num_tokentypes=0,
+    config = core_transformer_config_from_args(get_args())
+    model = T5Model(config=config,
+                    num_tokentypes=0,
                     parallel_output=True,
                     pre_process=pre_process,
                     post_process=post_process,
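For downstream scripts that build T5Model directly, the migration mirrors this hunk: construct the config from the parsed arguments and pass it as the first constructor argument. A minimal provider sketch under that assumption; the argument values other than config are illustrative.

    from megatron import get_args, print_rank_0
    from megatron.arguments import core_transformer_config_from_args
    from megatron.model import T5Model

    def model_provider(pre_process=True, post_process=True):
        """Build a T5 model with the new config-first constructor."""
        print_rank_0('building T5 model ...')
        config = core_transformer_config_from_args(get_args())
        # Omitting config now raises a TypeError, since it has no default value.
        return T5Model(config=config,
                       num_tokentypes=0,
                       parallel_output=True,
                       pre_process=pre_process,
                       post_process=post_process)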
