11
11
from megatron .model import LayerNorm
12
12
from megatron .model .utils import (
13
13
openai_gelu ,
14
- get_linear_layer ,
15
- init_method_normal ,
16
- scaled_init_method_normal
14
+ get_linear_layer
17
15
)
18
16
from .module import MegatronModule
19
17
@@ -43,17 +41,12 @@ class T5LMHead(MegatronModule):
43
41
44
42
Arguments:
45
43
mpu_vocab_size: model parallel size of vocabulary.
46
- hidden_size: hidden size
47
- init_method: init method for weight initialization
48
- layernorm_epsilon: tolerance for layer norm divisions
49
44
parallel_output: whether output logits are distributed or not.
50
45
"""
51
46
52
47
def __init__ (self , mpu_vocab_size , parallel_output ):
53
48
super (T5LMHead , self ).__init__ ()
54
49
55
- args = get_args ()
56
-
57
50
self .bias = torch .nn .Parameter (torch .zeros (mpu_vocab_size ))
58
51
self .bias .model_parallel = True
59
52
self .bias .partition_dim = 0
@@ -72,37 +65,34 @@ class T5Model(MegatronModule):
72
65
"""T5 Language model."""
73
66
74
67
def __init__ (self ,
68
+ config ,
75
69
num_tokentypes = 0 ,
76
70
parallel_output = True ,
77
71
pre_process = True ,
78
72
post_process = True ,
79
73
add_encoder = True ,
80
74
add_decoder = True ):
81
- super (T5Model , self ).__init__ ()
75
+ super ().__init__ (config = config )
82
76
args = get_args ()
83
77
84
78
self .fp16_lm_cross_entropy = args .fp16_lm_cross_entropy
85
79
self .parallel_output = parallel_output
86
- init_method = init_method_normal (args .init_method_std )
87
- scaled_init_method = scaled_init_method_normal (args .init_method_std ,
88
- args .num_layers )
89
80
self .pre_process = pre_process
90
81
self .post_process = post_process
91
82
self .add_encoder = add_encoder
92
83
self .add_decoder = add_decoder
93
84
94
85
self .language_model , self ._language_model_key = get_language_model (
86
+ config = config ,
95
87
num_tokentypes = num_tokentypes ,
96
88
add_pooler = False ,
97
89
add_encoder = add_encoder ,
98
90
add_decoder = add_decoder ,
99
91
encoder_attn_mask_type = AttnMaskType .padding ,
100
- init_method = init_method ,
101
- scaled_init_method = scaled_init_method ,
102
92
pre_process = self .pre_process ,
103
93
post_process = self .post_process )
104
94
105
- self .initialize_word_embeddings (init_method_normal )
95
+ self .initialize_word_embeddings ()
106
96
107
97
if self .post_process and self .add_decoder :
108
98
self .lm_head = T5LMHead (
0 commit comments