diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py
index e32f64cf1c..2c0a86455a 100644
--- a/tests/pytorch/distributed/run_fsdp2_model.py
+++ b/tests/pytorch/distributed/run_fsdp2_model.py
@@ -7,33 +7,21 @@
 import os
 import sys
 import argparse
-
-import transformer_engine.pytorch as te
-from transformer_engine.common.recipe import Format, DelayedScaling
+from contextlib import nullcontext
 
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
-from torch import nn, optim
+from torch import optim
 from torch.distributed import DeviceMesh
-from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed.device_mesh import init_device_mesh
-from contextlib import nullcontext
-
-
-class SimpleNet(nn.Module):
-    def __init__(self, input_size, hidden_size, output_size):
-        super(SimpleNet, self).__init__()
-        self.fc1 = te.Linear(input_size, hidden_size)
-        self.fc2 = te.Linear(hidden_size, output_size)
-
-    def forward(self, x):
-        x = F.relu(self.fc1(x))
-        x = self.fc2(x)
-        return x
+
+import transformer_engine.pytorch as te
+from transformer_engine.common.recipe import Format, DelayedScaling
 
 
-def save_custom_attrs(module):
+def _save_custom_attrs(module):
     custom_attrs = {}
     for name, param in module.named_parameters():
         attrs = vars(param)
@@ -41,25 +29,47 @@ def save_custom_attrs(module):
     return custom_attrs
 
 
-def restore_custom_attrs(module, custom_attrs):
+def _restore_custom_attrs(module, custom_attrs):
     for name, param in module.named_parameters():
         if name in custom_attrs:
             for attr_name, attr_value in custom_attrs[name].items():
                 setattr(param, attr_name, attr_value)
 
 
+def _te_layer_type(layer_name):
+    te_layer_types = [
+        te.Linear,
+        te.LayerNormLinear,
+        te.LayerNormMLP,
+        te.MultiheadAttention,
+        te.TransformerLayer,
+    ]
+    te_layer_names = [layer.__name__ for layer in te_layer_types]
+    te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
+    if layer_name.lower() not in te_layer_map.keys():
+        raise argparse.ArgumentTypeError(
+            f'"{layer_name}" is not a valid Transformer Engine layer, '
+            f"please choose layer from {te_layer_names}."
+        )
+    return te_layer_map[layer_name.lower()]
+
+
 def _parse_args(argv=None, namespace=None):
     parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
-    parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
-    parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
-    parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
-    parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
     parser.add_argument(
-        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
+        "--layer-type",
+        type=_te_layer_type,
+        default=te.TransformerLayer,
+        help="Transformer Engine layer type",
     )
+    parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
+    parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
+    parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
+    parser.add_argument("--seq-length", type=int, default=1024, help="Sequence length of input")
     parser.add_argument(
-        "--iter", type=int, default=10, help="Number of iterations for forward pass"
+        "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
     )
+    parser.add_argument("--iter", type=int, default=3, help="Number of iterations for forward pass")
     parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
     # Adding hsdp_dim as a list argument, comma-separated
     parser.add_argument(
@@ -74,7 +84,26 @@ def _parse_args(argv=None, namespace=None):
     return args
 
 
-sub_modules_to_wrap = [te.Linear]
+def _init_te_model(config):
+    hidden_size = config.num_heads * config.head_dim
+    args = [hidden_size, hidden_size]
+    inp_shape = [config.seq_length, config.batch_size, hidden_size]
+    out_shape = [config.seq_length, config.batch_size, hidden_size]
+    kwargs = {
+        "params_dtype": torch.bfloat16,
+    }
+    if config.layer_type == te.LayerNormLinear:
+        args[1] *= 3  # QKV projection
+        out_shape[-1] *= 3
+    elif config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
+        args[1] *= 4  # FFN hidden size
+        args.append(config.num_heads)
+        kwargs["fuse_qkv_params"] = True
+        if config.layer_type is te.MultiheadAttention:
+            kwargs["input_layernorm"] = True
+
+    model = config.layer_type(*args, **kwargs)
+    return model, inp_shape, out_shape
 
 
 def _train(args):
@@ -98,30 +127,13 @@ def _train(args):
     }
     assert dist.is_nccl_available()
     dist.init_process_group(**dist_init_kwargs)
-    nccl_world = dist.new_group(backend="nccl")
     device = torch.device(f"cuda:{LOCAL_RANK}")
 
-    # FP8 Configuration
+    # Initialize TE model
     fp8_format = Format.HYBRID
     fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
-
-    if not args.fp8_init:
-        # Build model context (FP8 init)
-        build_model_context = nullcontext
-        build_model_context_args = {}
-
-        from transformer_engine.pytorch import fp8_model_init
-
-        build_model_context = fp8_model_init
-        build_model_context_args["enabled"] = True
-
-        # Build the model with the specified context
-        with build_model_context(**build_model_context_args):
-            model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
-    else:
-        model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
-    # Move the model to the correct device
-
+    with te.fp8_model_init(enabled=args.fp8_init, recipe=fp8_recipe):
+        model, inp_shape, out_shape = _init_te_model(args)
     model.to(device)
 
     if LOCAL_RANK == 0:
@@ -132,13 +144,13 @@
     if LOCAL_RANK == 0:
         print(f"sharding-dims:{args.sharding_dims}")
     # Setup the sharding mesh for FSDP/HSDP
-    if args.sharding_dims == None:  # FSDP
+    if args.sharding_dims is None:  # FSDP
         mesh = DeviceMesh("cuda", device_ids)
     elif len(args.sharding_dims) == 1:
-        assert args.sharding_dims[0] == device_ids[-1] + 1
+        assert args.sharding_dims[0] == world_size
         mesh = DeviceMesh("cuda", device_ids)
     elif len(args.sharding_dims) == 2:  # HSDP
-        assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
+        assert args.sharding_dims[0] * args.sharding_dims[1] == world_size
         mesh = init_device_mesh(
             "cuda",
             (args.sharding_dims[0], args.sharding_dims[1]),
@@ -148,25 +160,42 @@
         assert False
 
     # Apply FSDP/HSDP
-    custom_attrs = save_custom_attrs(model)
-    for sub_module in model.modules():
-        if any(
-            isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
-        ):
-            fully_shard(sub_module, mesh=mesh)
-    fully_shard(model, mesh=mesh)
-    restore_custom_attrs(model, custom_attrs)
+    mp_policy = MixedPrecisionPolicy(
+        param_dtype=torch.bfloat16,
+        reduce_dtype=torch.float32,
+        output_dtype=torch.bfloat16,
+        cast_forward_inputs=True,
+    )
+    custom_attrs = _save_custom_attrs(model)
+    if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
+        # Composite modules require wrapping submodules bottom-up for the correct parameter grouping
+        sub_modules_to_wrap = [
+            te.Linear,
+            te.LayerNormLinear,
+            te.LayerNormMLP,
+        ]
+        for sub_module in model.modules():
+            if any(
+                isinstance(sub_module, sub_module_to_wrap)
+                for sub_module_to_wrap in sub_modules_to_wrap
+            ):
+                fully_shard(sub_module, mesh=mesh, mp_policy=None if args.fp8_init else mp_policy)
+    fully_shard(model, mesh=mesh, mp_policy=None if args.fp8_init else mp_policy)
+    _restore_custom_attrs(model, custom_attrs)
 
     optimizer = optim.Adam(model.parameters(), lr=1e-3)
 
     for iteration in range(args.iter):
         # Zero the parameter gradients
         optimizer.zero_grad()
-        input_data = torch.randn(args.batch_size, args.input_size).to(device)
-        output = model(input_data)
-        target = torch.randn(args.batch_size, args.output_size).to(device)
-        loss = F.mse_loss(output, target)
-        loss.backward()
+        input_data = torch.randn(inp_shape).to(device)
+        target = torch.randn(out_shape).to(device)
+        with torch.autograd.detect_anomaly():
+            with torch.amp.autocast(enabled=not args.fp8_init, device_type="cuda"):
+                with te.fp8_autocast(enabled=args.fp8_init, fp8_recipe=fp8_recipe):
+                    output = model(input_data)
+                loss = F.mse_loss(output, target)
+            loss.backward()
         optimizer.step()
         if LOCAL_RANK == 0:
             print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.")
diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py
index f5c186a3bc..196a33bd9a 100644
--- a/tests/pytorch/distributed/test_torch_fsdp2.py
+++ b/tests/pytorch/distributed/test_torch_fsdp2.py
@@ -3,24 +3,24 @@
 # See LICENSE for license information.
 
 import os
-import pytest
 import subprocess
 from pathlib import Path
-from transformer_engine.pytorch import torch_version
-from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+import pytest
 import torch
+import transformer_engine.pytorch as te
 
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+fp8_available, reason_for_no_fp8 = te.FP8GlobalStateManager.is_fp8_available()
 
 NUM_PROCS: int = torch.cuda.device_count()
 
 
-def _run_test(fp_init, sharding_dims):
+def _run_test(fp_init, sharding_dims, layer_type):
     test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
     test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]
 
+    test_cmd += ["--layer-type", layer_type.__name__]
     if fp_init:
         test_cmd += ["--fp8-init"]
     if len(sharding_dims) == 1:
@@ -29,16 +29,20 @@ def _run_test(fp_init, sharding_dims):
         test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
     else:
         assert False
-    result = subprocess.run(test_cmd, env=os.environ, check=True)
+    return subprocess.run(test_cmd, env=os.environ, check=True)
 
 
 @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
 @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
-@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
+@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
 @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
 @pytest.mark.parametrize("fp8_init", (False, True))
-def test_distributed(fp8_init, sharding_dims):
-
+@pytest.mark.parametrize(
+    "layer_type",
+    (te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, te.TransformerLayer),
+)
+def test_torch_fsdp2(fp8_init, sharding_dims, layer_type):
+    """Test Transformer Engine layers with PyTorch native FSDP2."""
     # Skip invalid configurations
     if torch.cuda.device_count() < 4:
         pytest.skip("FSDP2 test requires at least 4 GPUs")
@@ -46,13 +50,4 @@
     if fp8_init and not fp8_available:
         pytest.skip(reason_for_no_fp8)
 
-    _run_test(fp8_init, sharding_dims)
-
-
-def test_dummy() -> None:
-    """Dummy test
-
-    pytest returns exit code 5 if all tests are skipped.
-
-    """
-    pass
+    _run_test(fp8_init, sharding_dims, layer_type)
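
Usage note (illustrative, not part of the patch): the parametrized test launches run_fsdp2_model.py through torchrun exactly as _run_test assembles the command, so any single configuration can be reproduced by hand, for example on a node with 4 GPUs (an assumed example count):

    torchrun --nproc_per_node=4 tests/pytorch/distributed/run_fsdp2_model.py --layer-type TransformerLayer --sharding-dims 2 2 --fp8-init

Here --layer-type is matched case-insensitively against the names registered in _te_layer_type, --sharding-dims selects FSDP with one value or HSDP with two values whose product must equal the world size, and --fp8-init enables both fp8_model_init and fp8_autocast (in which case the bf16 MixedPrecisionPolicy is not passed to fully_shard).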