Commit 38607ae

Fix causal attention mask (#306)
1 parent: 0d0d84c

5 files changed: 62 additions & 30 deletions

megatron/model/fused_softmax.py
Lines changed: 9 additions & 3 deletions

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+from functools import lru_cache
 
 import torch
 import torch.nn as nn
@@ -201,6 +201,12 @@ def forward_fused_softmax(self, input, mask):
         else:
             return ScaledSoftmax.apply(input, scale)
 
+    @staticmethod
+    @lru_cache(maxsize=1)
+    def get_causal_mask(sequence_length: int):
+        mask = torch.ones(1, 1, sequence_length, sequence_length, dtype=torch.bool, device=torch.cuda.current_device())
+        return torch.triu(mask, diagonal=1)
+
     def forward_torch_softmax(self, input, mask):
         if self.input_in_float16 and self.softmax_in_fp32:
             input = input.float()
@@ -210,8 +216,8 @@ def forward_torch_softmax(self, input, mask):
 
         if self.attn_mask_type == AttnMaskType.causal:
             assert mask is None
-            mask = torch.ones_like(input, dtype=torch.bool)
-            mask = torch.triu(mask, diagonal=1, out=mask)
+            assert input.shape[2] == input.shape[3]
+            mask = self.get_causal_mask(input.shape[2])
 
         mask_output = self.mask_func(input, mask) if mask is not None else input
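What changed here: forward_torch_softmax used to allocate a full ones_like(input) boolean mask on every call and triu it in place; it now asserts that the attention scores are square and reuses a single [1, 1, seq, seq] upper-triangular mask memoized with lru_cache. A minimal standalone sketch of the same pattern (the helper names below are illustrative, not part of the commit):

# Sketch only: a cached causal mask applied in a plain PyTorch softmax path,
# mirroring get_causal_mask / forward_torch_softmax above.
from functools import lru_cache

import torch
import torch.nn.functional as F


@lru_cache(maxsize=1)
def causal_mask(sequence_length: int, device: str = "cpu") -> torch.Tensor:
    # True marks "future" positions (column > row) that must not be attended to.
    ones = torch.ones(1, 1, sequence_length, sequence_length, dtype=torch.bool, device=device)
    return torch.triu(ones, diagonal=1)


def causal_softmax(scores: torch.Tensor) -> torch.Tensor:
    # scores: [batch, heads, seq, seq] attention logits; must be square in the last two dims.
    assert scores.shape[-2] == scores.shape[-1]
    mask = causal_mask(scores.shape[-1], str(scores.device))
    scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
    return F.softmax(scores, dim=-1)


probs = causal_softmax(torch.randn(2, 4, 8, 8))
# Each query position i only puts probability mass on key positions j <= i.
assert torch.allclose(probs.tril().sum(dim=-1), torch.ones(2, 4, 8))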

megatron/model/gpt_model.py
Lines changed: 3 additions & 3 deletions

@@ -232,13 +232,13 @@ def _to_float16(inputs):
                                         tied_weight_attr='word_embeddings_weight'))
 
         if args.fp32_residual_connection:
-            if hasattr(args, 'attn_mask'):
+            if getattr(args, 'pretrain_causal_attention', False):
                 self.specs.append(lambda x: x.transpose(0, 1).contiguous().float())
             else:
                 # EmbeddingPipe returns attention mask as well
                 self.specs.append(lambda x: (x[0].transpose(0, 1).contiguous().float(), *x[1:]))
         else:
-            if hasattr(args, 'attn_mask'):
+            if getattr(args, 'pretrain_causal_attention', False):
                 self.specs.append(lambda x: x.transpose(0, 1).contiguous())
             else:
                 # EmbeddingPipe returns attention mask as well
@@ -256,7 +256,7 @@ def _to_float16(inputs):
 
         # Undo data format change
         def undo(x):
-            if not hasattr(args, 'attn_mask'):
+            if not getattr(args, 'pretrain_causal_attention', False):
                 x = x[0]
             return x.transpose(0, 1).contiguous()
         self.specs.append(undo)
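The pipeline specs in GPTModelPipe now branch on an explicit pretrain_causal_attention flag, read with getattr and a False default, instead of on the presence of the old args.attn_mask attribute. The behavior of the two branches is unchanged: with the flag set, only the hidden states flow between pipeline stages; without it, the attention mask rides along in a tuple. A small sketch of that behavior (illustrative only, not commit code):

# Illustrative sketch: the transpose spec with and without the attention mask
# travelling alongside the activations between pipeline stages.
import torch

def make_transpose_spec(pretrain_causal_attention: bool):
    if pretrain_causal_attention:
        # Only hidden states are passed; the causal mask is implicit and rebuilt later.
        return lambda x: x.transpose(0, 1).contiguous()
    # EmbeddingPipe also returns the attention mask, so preserve the tuple tail.
    return lambda x: (x[0].transpose(0, 1).contiguous(), *x[1:])

hidden = torch.randn(2, 5, 8)                     # [batch, seq, hidden]
mask = torch.zeros(2, 1, 5, 5, dtype=torch.bool)  # [batch, 1, seq, seq]

print(make_transpose_spec(True)(hidden).shape)    # torch.Size([5, 2, 8])
out = make_transpose_spec(False)((hidden, mask))
print(out[0].shape, out[1].shape)                 # torch.Size([5, 2, 8]) torch.Size([2, 1, 5, 5])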

megatron/model/language_model.py
Lines changed: 2 additions & 2 deletions

@@ -274,7 +274,7 @@ def forward(self, inputs, **kwargs):
 
         input_ids = inputs[0]
         position_ids = inputs[1]
-        if hasattr(self._args, 'attn_mask'):
+        if getattr(self._args, 'pretrain_causal_attention', False):
             attention_mask = None
         else:
             attention_mask = inputs[2]
@@ -287,7 +287,7 @@ def forward(self, inputs, **kwargs):
         embeddings = super().forward(input_ids, position_ids, tokentype_ids=tokentype_ids)
 
         # If cmd args has attn_mask, we don't forward it as an activation.
-        if hasattr(self._args, 'attn_mask'):
+        if getattr(self._args, 'pretrain_causal_attention', False):
             return embeddings
         else:
             return embeddings, attention_mask

pretrain_gpt.py
Lines changed: 1 addition & 3 deletions

@@ -54,9 +54,7 @@ def model_provider(pre_process=True, post_process=True):
                              enabled=args.zero_stage == 3,
                              mpu=mpu):
         if args.deepspeed:
-            # Hack @thomasw21 to get fused_softmax.forward_torch_softmax working
-            args.attn_mask = None
-
+            args.pretrain_causal_attention = True
             model = GPTModelPipe(
                 num_tokentypes=0,
                 parallel_output=True,
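This is where the old sentinel was set: the hack relied on the mere presence of args.attn_mask (assigned None) to switch fused_softmax and the pipeline specs into causal mode, which a hasattr check cannot distinguish from a real mask. The explicit boolean flag states the intent. A tiny illustration of the difference (not from the repo):

# Illustrative only: presence-of-attribute sentinel vs. explicit flag.
from types import SimpleNamespace

old_args = SimpleNamespace()
old_args.attn_mask = None                         # old hack: the attribute existing is the signal
print(hasattr(old_args, "attn_mask"))             # True, even though the value is None

new_args = SimpleNamespace()
new_args.pretrain_causal_attention = True         # new: a self-describing boolean
print(getattr(new_args, "pretrain_causal_attention", False))           # True
print(getattr(SimpleNamespace(), "pretrain_causal_attention", False))  # False when never set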

tests/test_model.py
Lines changed: 47 additions & 19 deletions

@@ -4,14 +4,17 @@
 
 import deepspeed
 import torch
+from parameterized import parameterized
 from torch import nn
 import torch.nn.functional as F
 
+from megatron.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from packaging import version
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.model.fused_softmax import ScaledMaskedSoftmax
+from megatron.model.fused_softmax import ScaledMaskedSoftmax, FusedScaleMaskSoftmax
+from megatron.model.utils import attention_mask_func
 from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, \
     torch_assert_close, require_torch_bf16
 from megatron.training import setup_model_and_optimizer
@@ -366,7 +369,8 @@ def test_fused_layer_norm(self):
 
                 torch_assert_equal(mfln_output, torch_layer_norm_output)
 
-    def test_fused_masked_softmax(self):
+    @parameterized.expand([(attn_mask_type,) for attn_mask_type in AttnMaskType])
+    def test_fused_masked_softmax(self, attn_mask_type: AttnMaskType):
         command_args = get_default_args(self.test_file_dir_str)
 
         with patch('sys.argv', flatten_arguments(command_args)):
@@ -382,30 +386,54 @@ def test_fused_masked_softmax(self):
                     device="cuda",
                     dtype=args.params_dtype
                 )
-                dummy_attention_mask = torch.randn(
-                    args.micro_batch_size,
-                    1, # `args.num_attention_heads` not implemented in our cuda kernel
-                    args.seq_length,
-                    args.seq_length,
-                    device="cuda",
-                    dtype=args.params_dtype
-                ) < 0
+                if attn_mask_type == AttnMaskType.causal:
+                    dummy_attention_mask = None
+                else:
+                    dummy_attention_mask = torch.randn(
+                        args.micro_batch_size,
+                        1, # `args.num_attention_heads` not implemented in our cuda kernel
+                        args.seq_length,
+                        args.seq_length,
+                        device="cuda",
+                        dtype=args.params_dtype
+                    ) < 0
                 scale = torch.rand(())
 
-                fused_scaled_softmax = ScaledMaskedSoftmax
-
-                fused_output = fused_scaled_softmax.apply(dummy_input, dummy_attention_mask, scale)
+                fused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=True,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
+                unfused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=False,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
 
-                # mimick the same via torch
-                output = scale * dummy_input
-                output = output.masked_fill(dummy_attention_mask, torch.finfo(args.params_dtype).min)
-                output = F.softmax(output, dim=-1)
+                self.assertTrue(fused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                fused_output = fused_scaled_softmax(dummy_input, dummy_attention_mask)
+                self.assertFalse(unfused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                unfused_output = unfused_scaled_softmax(dummy_input, dummy_attention_mask)
 
                 # Test that the nonzeros are the same with the mask
                 for i in range(args.num_attention_heads):
-                    torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+                    if dummy_attention_mask is None:
+                        # Make sure it's causal, values in the lower triangle should be not zero.
+                        non_zero_values = torch.tril(torch.ones_like(fused_output[:, i]))
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(non_zero_values))
+                    else:
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+
                 # Cuda kernel produces slightly different results
-                torch_assert_close(fused_output, output)
+                torch_assert_close(fused_output, unfused_output)
 
 
     def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_is_not_causal_across_segments(self):
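The rewritten test runs once per AttnMaskType via parameterized.expand, builds a fused and an unfused FusedScaleMaskSoftmax with identical settings, and checks that the CUDA kernel and the PyTorch fallback agree. The causal case's sanity check in miniature (a standalone sketch, not test code): a causal softmax output must have exactly the nonzero pattern of a lower-triangular matrix.

# Standalone sketch of the causality check used in the test above.
import torch

seq = 8
scores = torch.randn(seq, seq)
future = torch.triu(torch.ones(seq, seq, dtype=torch.bool), diagonal=1)
probs = torch.softmax(scores.masked_fill(future, torch.finfo(scores.dtype).min), dim=-1)

# Nonzero entries of the output must coincide exactly with the lower triangle.
expected = torch.tril(torch.ones_like(probs))
assert torch.equal(torch.nonzero(probs), torch.nonzero(expected))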
