[PyTorch] Update PyTorch FSDP2 test to cover all TE layer types #1777

Open
wants to merge 5 commits into main
Changes from 3 commits
151 changes: 90 additions & 61 deletions tests/pytorch/distributed/run_fsdp2_model.py
@@ -7,59 +7,69 @@
import os
import sys
import argparse

import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn, optim
from torch import optim
from torch.distributed import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh
from contextlib import nullcontext


class SimpleNet(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(SimpleNet, self).__init__()
self.fc1 = te.Linear(input_size, hidden_size)
self.fc2 = te.Linear(hidden_size, output_size)

def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling


def save_custom_attrs(module):
def _save_custom_attrs(module):
custom_attrs = {}
for name, param in module.named_parameters():
attrs = vars(param)
custom_attrs[name] = {k: v for k, v in attrs.items()}
return custom_attrs


def restore_custom_attrs(module, custom_attrs):
def _restore_custom_attrs(module, custom_attrs):
for name, param in module.named_parameters():
if name in custom_attrs:
for attr_name, attr_value in custom_attrs[name].items():
setattr(param, attr_name, attr_value)
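# The save/restore pair above appears intended to preserve custom attributes that
# Transformer Engine stashes on parameter objects (captured via vars(param)):
# fully_shard() swaps module parameters for DTensor-backed ones, so attributes set
# on the original parameter objects would otherwise be lost.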


def _te_layer_type(layer_name):
te_layer_types = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
te.MultiheadAttention,
te.TransformerLayer,
]
te_layer_names = [layer.__name__ for layer in te_layer_types]
te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types))
if layer_name.lower() not in te_layer_map.keys():
raise argparse.ArgumentTypeError(
f'"{layer_name}" is not a valid Transformer Engine layer, '
f"please choose layer from {te_layer_names}."
)
return te_layer_map[layer_name.lower()]
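# Example: "--layer-type TransformerLayer" (any capitalization) resolves to
# te.TransformerLayer, while an unrecognized name raises ArgumentTypeError and
# lists the valid layer names.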


def _parse_args(argv=None, namespace=None):
parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()")
parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model")
parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size")
parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model")
parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model")
parser.add_argument(
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
"--layer-type",
type=_te_layer_type,
default=te.TransformerLayer,
help="Transformer Engine layer type",
)
parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads")
parser.add_argument("--head-dim", type=int, default=64, help="Attention head size")
parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input")
parser.add_argument("--seq-length", type=int, default=1024, help="Sequence length of input")
parser.add_argument(
"--iter", type=int, default=10, help="Number of iterations for forward pass"
"--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8."
)
parser.add_argument("--iter", type=int, default=3, help="Number of iterations for forward pass")
parser.add_argument("--seed", type=int, default=42, help="RNG seed.")
# Adding hsdp_dim as a list argument, comma-separated
parser.add_argument(
@@ -74,7 +84,26 @@ def _parse_args(argv=None, namespace=None):
return args


sub_modules_to_wrap = [te.Linear]
def _init_te_model(config):
hidden_size = config.num_heads * config.head_dim
args = [hidden_size, hidden_size]
inp_shape = [config.seq_length, config.batch_size, hidden_size]
out_shape = [config.seq_length, config.batch_size, hidden_size]
kwargs = {
"params_dtype": torch.bfloat16,
}
if config.layer_type == te.LayerNormLinear:
args[1] *= 3 # QKV projection
out_shape[-1] *= 3
elif config.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
args[1] *= 4 # FFN hidden size
args.append(config.num_heads)
kwargs["fuse_qkv_params"] = True
if config.layer_type is te.MultiheadAttention:
kwargs["input_layernorm"] = True

model = config.layer_type(*args, **kwargs)
return model, inp_shape, out_shape
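# Example shapes, assuming the parser defaults above (num_heads=8, head_dim=64,
# batch_size=16, seq_length=1024):
#   hidden_size = 8 * 64 = 512
#   inp_shape = out_shape = [1024, 16, 512]            # [seq, batch, hidden]
#   te.LayerNormLinear: args[1] and out_shape[-1] become 3 * 512 = 1536 (fused QKV)
#   te.MultiheadAttention / te.TransformerLayer: args[1] becomes 4 * 512 = 2048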


def _train(args):
@@ -98,30 +127,13 @@ def _train(args):
}
assert dist.is_nccl_available()
dist.init_process_group(**dist_init_kwargs)
nccl_world = dist.new_group(backend="nccl")
device = torch.device(f"cuda:{LOCAL_RANK}")

# FP8 Configuration
# Initialize TE model
fp8_format = Format.HYBRID
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

if not args.fp8_init:
# Build model context (FP8 init)
build_model_context = nullcontext
build_model_context_args = {}

from transformer_engine.pytorch import fp8_model_init

build_model_context = fp8_model_init
build_model_context_args["enabled"] = True

# Build the model with the specified context
with build_model_context(**build_model_context_args):
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
else:
model = SimpleNet(args.input_size, args.hidden_size, args.output_size)
# Move the model to the correct device

with te.fp8_model_init(enabled=args.fp8_init, recipe=fp8_recipe):
model, inp_shape, out_shape = _init_te_model(args)
model.to(device)

if LOCAL_RANK == 0:
@@ -132,13 +144,13 @@ def _train(args):
if LOCAL_RANK == 0:
print(f"sharding-dims:{args.sharding_dims}")
# Setup the sharding mesh for FSDP/HSDP
if args.sharding_dims == None: # FSDP
if args.sharding_dims is None: # FSDP
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 1:
assert args.sharding_dims[0] == device_ids[-1] + 1
assert args.sharding_dims[0] == world_size
mesh = DeviceMesh("cuda", device_ids)
elif len(args.sharding_dims) == 2: # HSDP
assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1
assert args.sharding_dims[0] * args.sharding_dims[1] == world_size
mesh = init_device_mesh(
"cuda",
(args.sharding_dims[0], args.sharding_dims[1]),
@@ -148,24 +160,41 @@
assert False

# Apply FSDP/HSDP
custom_attrs = save_custom_attrs(model)
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh)
fully_shard(model, mesh=mesh)
restore_custom_attrs(model, custom_attrs)
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
output_dtype=torch.bfloat16,
cast_forward_inputs=True,
)
custom_attrs = _save_custom_attrs(model)
if args.layer_type in [te.MultiheadAttention, te.TransformerLayer]:
# Composite modules require wrapping submodules bottom-up for the correct parameter grouping
sub_modules_to_wrap = [
te.Linear,
te.LayerNormLinear,
te.LayerNormMLP,
]
for sub_module in model.modules():
if any(
isinstance(sub_module, sub_module_to_wrap)
for sub_module_to_wrap in sub_modules_to_wrap
):
fully_shard(sub_module, mesh=mesh, mp_policy=None if args.fp8_init else mp_policy)
fully_shard(model, mesh=mesh, mp_policy=None if args.fp8_init else mp_policy)
_restore_custom_attrs(model, custom_attrs)
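# A minimal sketch of the bottom-up fully_shard() pattern applied above, shown on a
# plain torch model (illustrative names only; with no mesh argument, fully_shard()
# builds one from the default process group):
#
#   import torch.nn as nn
#   from torch.distributed._composable.fsdp import fully_shard
#
#   toy = nn.Sequential(nn.Linear(512, 512), nn.Linear(512, 512))
#   for block in toy:    # shard each child first -> one FSDP parameter group per block
#       fully_shard(block)
#   fully_shard(toy)     # then shard the root to pick up any remaining parameters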

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for iteration in range(args.iter):
# Zero the parameter gradients
optimizer.zero_grad()
input_data = torch.randn(args.batch_size, args.input_size).to(device)
output = model(input_data)
target = torch.randn(args.batch_size, args.output_size).to(device)
loss = F.mse_loss(output, target)
input_data = torch.randn(inp_shape).to(device)
target = torch.randn(out_shape).to(device)
with torch.autograd.detect_anomaly():
with torch.amp.autocast(enabled=not args.fp8_init, device_type="cuda"):
with te.fp8_autocast(enabled=args.fp8_init, fp8_recipe=fp8_recipe):
output = model(input_data)
loss = F.mse_loss(output, target)
loss.backward()
optimizer.step()
if LOCAL_RANK == 0:
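For reference, a sketch of how the updated run_fsdp2_model.py would be launched by hand (flag spellings follow the parser above; the two --sharding-dims values mirror the HSDP case exercised by the test file below):

torchrun --nproc_per_node=4 tests/pytorch/distributed/run_fsdp2_model.py --layer-type layernormmlp --fp8-init --sharding-dims 2 2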
33 changes: 14 additions & 19 deletions tests/pytorch/distributed/test_torch_fsdp2.py
@@ -3,24 +3,24 @@
# See LICENSE for license information.

import os
import pytest
import subprocess
from pathlib import Path
from transformer_engine.pytorch import torch_version
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager

import pytest
import torch

import transformer_engine.pytorch as te

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_available, reason_for_no_fp8 = te.FP8GlobalStateManager.is_fp8_available()

NUM_PROCS: int = torch.cuda.device_count()


def _run_test(fp_init, sharding_dims):
def _run_test(fp_init, sharding_dims, layer_type):
test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py"
test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)]

test_cmd += ["--layer-type", layer_type.__name__]
if fp_init:
test_cmd += ["--fp8-init"]
if len(sharding_dims) == 1:
@@ -29,30 +29,25 @@ def _run_test(fp_init, sharding_dims):
test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])]
else:
assert False
result = subprocess.run(test_cmd, env=os.environ, check=True)
return subprocess.run(test_cmd, env=os.environ, check=True)


@pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs")
@pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs")
@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+")
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2]))
@pytest.mark.parametrize("fp8_init", (False, True))
def test_distributed(fp8_init, sharding_dims):

@pytest.mark.parametrize(
"layer_type",
(te.Linear, te.LayerNormLinear, te.LayerNormMLP, te.MultiheadAttention, te.TransformerLayer),
)
def test_torch_fsdp2(fp8_init, sharding_dims, layer_type):
"""Test a Transformer Engine Linear layer with PyTorch native FSDP2."""
# Skip invalid configurations
if torch.cuda.device_count() < 4:
pytest.skip("FSDP2 test requires at least 4 GPUs")

if fp8_init and not fp8_available:
pytest.skip(reason_for_no_fp8)

_run_test(fp8_init, sharding_dims)


def test_dummy() -> None:
"""Dummy test

pytest returns exit code 5 if all tests are skipped.

"""
pass
_run_test(fp8_init, sharding_dims, layer_type)
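As a usage note, the expanded parametrization (2 FP8 settings x 2 sharding layouts x 5 layer types) runs under a standard pytest invocation, for example:

python -m pytest -v tests/pytorch/distributed/test_torch_fsdp2.py

FP8 cases are skipped on hardware without FP8 support, and the whole module is skipped with fewer than 4 GPUs, an odd GPU count, or PyTorch older than 2.4.0.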