@@ -4,14 +4,17 @@
 
 import deepspeed
 import torch
+from parameterized import parameterized
 from torch import nn
 import torch.nn.functional as F
 
+from megatron.enums import AttnMaskType
 from megatron.model.fused_layer_norm import MixedFusedLayerNorm
 from packaging import version
 
 from megatron import initialize_megatron, get_args, get_tokenizer, global_vars
-from megatron.model.fused_softmax import ScaledMaskedSoftmax
+from megatron.model.fused_softmax import ScaledMaskedSoftmax, FusedScaleMaskSoftmax
+from megatron.model.utils import attention_mask_func
 from megatron.testing_utils import TestCasePlus, mockenv_context, flatten_arguments, torch_assert_equal, \
     torch_assert_close, require_torch_bf16
 from megatron.training import setup_model_and_optimizer
@@ -366,7 +369,8 @@ def test_fused_layer_norm(self):
 
                 torch_assert_equal(mfln_output, torch_layer_norm_output)
 
-    def test_fused_masked_softmax(self):
+    @parameterized.expand([(attn_mask_type,) for attn_mask_type in AttnMaskType])
+    def test_fused_masked_softmax(self, attn_mask_type: AttnMaskType):
         command_args = get_default_args(self.test_file_dir_str)
 
         with patch('sys.argv', flatten_arguments(command_args)):
@@ -382,30 +386,54 @@ def test_fused_masked_softmax(self):
                     device="cuda",
                     dtype=args.params_dtype
                 )
-                dummy_attention_mask = torch.randn(
-                    args.micro_batch_size,
-                    1, # `args.num_attention_heads` not implemented in our cuda kernel
-                    args.seq_length,
-                    args.seq_length,
-                    device="cuda",
-                    dtype=args.params_dtype
-                ) < 0
+                if attn_mask_type == AttnMaskType.causal:
+                    dummy_attention_mask = None
+                else:
+                    dummy_attention_mask = torch.randn(
+                        args.micro_batch_size,
+                        1, # `args.num_attention_heads` not implemented in our cuda kernel
+                        args.seq_length,
+                        args.seq_length,
+                        device="cuda",
+                        dtype=args.params_dtype
+                    ) < 0
                 scale = torch.rand(())
 
-                fused_scaled_softmax = ScaledMaskedSoftmax
-
-                fused_output = fused_scaled_softmax.apply(dummy_input, dummy_attention_mask, scale)
+                fused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=True,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
+                unfused_scaled_softmax = FusedScaleMaskSoftmax(
+                    input_in_fp16=args.params_dtype == torch.float16,
+                    input_in_bf16=args.params_dtype == torch.bfloat16,
+                    attn_mask_type=attn_mask_type,
+                    scaled_masked_softmax_fusion=False,
+                    mask_func=attention_mask_func,
+                    softmax_in_fp32=True,
+                    scale=scale,
+                )
 
-                # mimick the same via torch
-                output = scale * dummy_input
-                output = output.masked_fill(dummy_attention_mask, torch.finfo(args.params_dtype).min)
-                output = F.softmax(output, dim=-1)
+                self.assertTrue(fused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                fused_output = fused_scaled_softmax(dummy_input, dummy_attention_mask)
+                self.assertFalse(unfused_scaled_softmax.is_kernel_available(dummy_attention_mask, *dummy_input.size()))
+                unfused_output = unfused_scaled_softmax(dummy_input, dummy_attention_mask)
 
                 # Test that the nonzeros are the same with the mask
                 for i in range(args.num_attention_heads):
-                    torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+                    if dummy_attention_mask is None:
+                        # Make sure it's causal: values in the lower triangle should be non-zero.
+                        non_zero_values = torch.tril(torch.ones_like(fused_output[:, i]))
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(non_zero_values))
+                    else:
+                        torch_assert_equal(torch.nonzero(fused_output[:, i]), torch.nonzero(~dummy_attention_mask[:, 0]))
+
                 # Cuda kernel produces slightly different results
-                torch_assert_close(fused_output, output)
+                torch_assert_close(fused_output, unfused_output)
 
 
     def test_non_causal_decoder_model_with_packed_input_passed_with_attention_mask_is_not_causal_across_segments(self):
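
For reference, what the updated test compares can be sketched in plain PyTorch, along the lines of the hand-rolled check the diff removes: scale the scores, mask, then softmax. The sketch below is illustrative only; the function name reference_scale_mask_softmax and the explicit upper-triangle construction for the causal case are not part of the diff, and it assumes a boolean mask where True marks positions to drop, matching the ~dummy_attention_mask usage in the test.

from typing import Optional

import torch
import torch.nn.functional as F


def reference_scale_mask_softmax(scores: torch.Tensor,
                                 mask: Optional[torch.Tensor],
                                 scale: float,
                                 causal: bool) -> torch.Tensor:
    # Scale the raw attention scores before the softmax, as both the fused and
    # unfused paths in the test do.
    out = scale * scores
    if causal:
        # Causal case (mask is None in the test): hide the strict upper triangle
        # (future positions), so only the lower triangle is non-zero afterwards --
        # the pattern the new assertion checks with torch.tril.
        seq_q, seq_k = out.shape[-2:]
        future = torch.triu(
            torch.ones(seq_q, seq_k, dtype=torch.bool, device=out.device), diagonal=1
        )
        out = out.masked_fill(future, torch.finfo(out.dtype).min)
    elif mask is not None:
        # Boolean mask: True marks positions that must receive no attention.
        out = out.masked_fill(mask, torch.finfo(out.dtype).min)
    return F.softmax(out, dim=-1)


# Tiny smoke test mirroring the (batch, heads, seq, seq) shapes, just smaller.
scores = torch.randn(2, 4, 8, 8)
mask = torch.randn(2, 1, 8, 8) < 0
probs = reference_scale_mask_softmax(scores, mask, scale=0.5, causal=False)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 4, 8))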