[Random] Random state management (#38)
Co-authored-by: chhzh123 <[email protected]>
comaniac and chhzh123 authored Feb 3, 2023
1 parent 3001c8a commit d93764c
Showing 17 changed files with 627 additions and 24 deletions.
5 changes: 3 additions & 2 deletions ci/task_unit_test.sh
@@ -11,9 +11,10 @@ nvidia-smi -L
echo "Running unit tests..."
# -r: redirect the output of local rank 1 to None so that
# only local rank 0's output is printed to the console.
torchrun --nproc_per_node 2 -r 1:1 -m pytest tests
# -p "no:randomly": disable randomly plugin for sharding tests.
torchrun --nproc_per_node 2 -r 1:1 -m pytest -p "no:randomly" tests

echo "Downloading test data..."
bash benchmark/download_benchmark_dataset.sh
echo "Running end-to-end tests..."
python3 -m pytest -s tests/end2end.py
python3 -m pytest -s -p "no:randomly" tests/end2end.py
18 changes: 16 additions & 2 deletions conftest.py
@@ -1,8 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import random

import numpy as np
import pytest

import torch
from torch import distributed as dist

@@ -18,7 +19,6 @@ def pytest_collection_modifyitems(items):
@pytest.fixture(scope="session")
def init_dist(request):
"""Initialize the distributed group once in the entire test session."""
torch.manual_seed(9999)
try:
dist.init_process_group(backend="nccl")
except Exception as err:
@@ -31,3 +31,17 @@ def destory_dist():
pass

request.addfinalizer(destory_dist)


@pytest.fixture(scope="function", autouse=True)
def random_seed():
"""Set random seed to 1) make the tests deterministic, and 2) make every
device generate the same weights for tensor parallelism tests.
Note that if you run pytest with "randomly" plugin enabled, this fixture
will have no effect. You can disable the plugin with
pytest -p "no:randomly" ...
"""
random.seed(9999)
np.random.seed(9999)
torch.manual_seed(9999)
19 changes: 9 additions & 10 deletions examples/gpt/deepspeed_hf.py
@@ -8,13 +8,13 @@
import deepspeed
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_global_rank
from deepspeed.utils import RepeatingLoader
from transformers import GPTNeoForCausalLM, AutoConfig

import slapo
from slapo import set_random_seed
from slapo.logger import get_logger
from slapo.op.cross_entropy import ParallelCrossEntropy
from slapo.op import ParallelCrossEntropy
from slapo.utils.report import report_memory

from model import schedule_model
@@ -127,6 +127,7 @@ def train(args):
delay_init=enable_pipeline,
sequence_parallel=args.sequence_parallel,
)
tp_rank = sch.rank

loss_fct = ParallelCrossEntropy(group=group)

@@ -188,16 +189,14 @@ def loss_fn(outputs, labels):
report_memory(msg="After building model")

if args.disable_pipeline or args.sequence_parallel:
random_seed = 2013 + dist.get_rank()
set_random_seed(2013, model.mpu.get_data_parallel_rank(), None, tp_rank)
else:
random_seed = (
2013
+ 100 * model.mpu.get_pipe_parallel_rank()
+ 10 * model.mpu.get_data_parallel_rank()
set_random_seed(
2013,
model.mpu.get_data_parallel_rank(),
model.mpu.get_pipe_parallel_rank(),
tp_rank,
)
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)

# for now always use seq_length 1024
# TODO: make the dataloader generic to different sequence length
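In other words, the hand-rolled seed arithmetic above (2013 + 100 * pipe rank + 10 * data-parallel rank) is replaced by a single helper that takes the base seed plus the per-dimension ranks, with None for dimensions that do not apply. A minimal sketch of the call pattern, using placeholder rank variables (the script itself obtains them from model.mpu and sch.rank):

import slapo

# Placeholder values for illustration only.
dp_rank, pp_rank, tp_rank = 0, 0, 0
pipeline_disabled = True

if pipeline_disabled:
    # No pipeline stage dimension: the seed varies by DP and TP rank only.
    slapo.set_random_seed(2013, dp_rank, None, tp_rank)
else:
    # Full 3D parallelism: the seed varies by DP, PP, and TP rank.
    slapo.set_random_seed(2013, dp_rank, pp_rank, tp_rank)
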
4 changes: 4 additions & 0 deletions examples/gpt/megatron_hf.py
@@ -60,6 +60,10 @@ def get_model(
delay_init=delay_init,
)
model, _ = slapo.build(sch, init_weights=model._init_weights)
# Note 1: We assume no DP and PP in this script.
# Note 2: This overrides Megatron random seed management, so we only use
# this script for benchmarking.
slapo.set_random_seed(2013, None, None, sch.rank)
report_memory()

elif impl == "torchscript":
2 changes: 1 addition & 1 deletion examples/gpt/model.py
@@ -51,7 +51,7 @@ def schedule_model(
# if MP group > 1.
attn_path, out_proj_name = "h.N.attn.attention", "out_proj"
if disable_flash_attn:
logger.info("Disabled Flash Attention", rank=0)
logger.info("Disabled Flash Attention", ranks=0)
cnt = replace_and_shard_attention(
sch[prefix],
config,
19 changes: 18 additions & 1 deletion examples/gpt/schedule.py
@@ -11,6 +11,7 @@
from slapo import init_empty_weights
from slapo.pattern import call_module
from slapo.op.linear import FusedQKV
from slapo import init_empty_weights, get_cuda_rng_tracker


def trace_attention(sch, config, attn_path="h.N.attn.attention"):
@@ -89,7 +90,7 @@ def replace_and_shard_attention(
sequence_parallel=False,
):
from epoi.inject.policy.gpt import InjectHFGPTAttentionPolicy
from epoi.ops.xformers_attn import GenericSelfAttention
from epoi.ops.xformers_attn import GenericSelfAttention, MemoryEfficientAttentionOp

class SelfAttention(nn.Module):
"""A wrapper to align the original GPTNeoAttention forward signature."""
@@ -112,6 +113,13 @@ def forward(
# is present_key_value and only used in inference.
return outputs[:1]

class MemoryEfficientAttentionWithRNGOp(MemoryEfficientAttentionOp):
def forward(self, query_layer, key_layer, value_layer, attention_mask, p):
with get_cuda_rng_tracker().fork():
return super().forward(
query_layer, key_layer, value_layer, attention_mask, p
)

num_layers, num_heads, hidden_size = (
config.num_layers,
config.num_heads,
@@ -163,13 +171,22 @@ def pattern(x: torch.Tensor) -> torch.Tensor:
mode="fwd_post", sync_op_or_fn="reduce_scatter", axis=1
)
else:
# Shard qkv and output projection.
sub_sch["module.FusedQKV_0.fused_linear"].sync(
mode="bwd_post", sync_op_or_fn="all_reduce"
)
sub_sch["module.out_proj"].sync(
mode="fwd_post", sync_op_or_fn="all_reduce"
)

# In this case, the attention dropout in between has to
# use different random seeds.
new_op = MemoryEfficientAttentionWithRNGOp(
sub_sch["module"]["attn_op"].mod.attn_op_name,
sub_sch["module"]["attn_op"].mod.apply_causal_mask,
)
sub_sch["module"]["attn_op"].replace(new_op)

cnt += 1

return cnt
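For context, the new get_cuda_rng_tracker follows the Megatron-LM pattern: outside fork() every tensor-parallel rank shares the same CUDA RNG stream (so replicated weights initialize identically), while inside fork() each rank draws from its own tracked state. A rough sketch of the intended behavior, assuming seeds were set through slapo.set_random_seed and that tp_rank is a placeholder for this process's tensor-parallel rank:

import torch
from slapo import set_random_seed, get_cuda_rng_tracker

tp_rank = 0  # placeholder; normally obtained from sch.rank or the TP group
set_random_seed(2013, None, None, tp_rank)

# Identical on every TP rank (default stream shares the seed).
weight = torch.empty(16, 16, device="cuda").normal_()
with get_cuda_rng_tracker().fork():
    # Differs across TP ranks (per-rank tracked stream), e.g. dropout masks.
    mask = torch.rand(16, device="cuda")
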
2 changes: 2 additions & 0 deletions slapo/__init__.py
@@ -9,3 +9,5 @@
from .tracer import *
from .utils import *
from .version import __version__
from .random import set_random_seed, get_cuda_rng_tracker, is_random_seed_set
from .checkpoint import checkpoint
124 changes: 124 additions & 0 deletions slapo/checkpoint.py
@@ -0,0 +1,124 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# Modification: Megatron-LM.
# See https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
"""Model checkpoints and activation checkpointing with the consideration
of 3D parallelism and random states.
"""
import torch
from torch.utils.checkpoint import detach_variable
from torch.utils.checkpoint import checkpoint as torch_checkpoint

from .random import get_cuda_rng_tracker, is_random_seed_set, _set_cuda_rng_state


class CheckpointFunctionWithRNGTracker(torch.autograd.Function):
"""This function is adapted from torch.utils.checkpoint with
two main changes:
1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
2) the states in the model parallel tracker are also properly
tracked/set/reset.
"""

# pylint: disable=abstract-method, arguments-differ

@staticmethod
def forward(ctx, run_function, *args):
ctx.run_function = run_function

# Copy the rng states.
ctx.fwd_cpu_rng_state = torch.get_rng_state()
ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

with torch.no_grad():
outputs = run_function(*args)

# Save non-tensor inputs in ctx, keep a placeholder None for tensors
# to be filled out during the backward.
ctx.inputs = []
ctx.tensor_indices = []
tensor_inputs = []
for idx, arg in enumerate(args):
if torch.is_tensor(arg):
tensor_inputs.append(arg)
ctx.tensor_indices.append(idx)
ctx.inputs.append(None)
else:
ctx.inputs.append(arg)

# We detach the tensor inputs to make sure we hold a reference to
# the tensor data. This is needed because when pipelining is enabled,
# the tensor data may be released by the pipeline engine, which does
# not know that the tensor is used in the backward pass.
ctx.save_for_backward(*detach_variable(tuple(tensor_inputs)))

return outputs

@staticmethod
def backward(ctx, *args):
if not torch.autograd._is_checkpoint_valid():
raise RuntimeError(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
# Copy the list to avoid modifying original list.
inputs = list(ctx.inputs)
tensor_indices = ctx.tensor_indices
tensors = ctx.saved_tensors

# Fill in inputs with appropriate saved tensors.
for idx, tidx in enumerate(tensor_indices):
inputs[tidx] = tensors[idx]

# Store the current states.
bwd_cpu_rng_state = torch.get_rng_state()
bwd_cuda_rng_state = torch.cuda.get_rng_state()
bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

# Set the states to what they were before the forward pass.
torch.set_rng_state(ctx.fwd_cpu_rng_state)
_set_cuda_rng_state(ctx.fwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

# Compute the forward pass.
detached_inputs = detach_variable(tuple(inputs))
with torch.enable_grad():
outputs = ctx.run_function(*detached_inputs)

# Set the states back to what they were at the start of this function.
torch.set_rng_state(bwd_cpu_rng_state)
_set_cuda_rng_state(bwd_cuda_rng_state)
get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)

# Run backward() with only the tensors that require grad.
outputs_with_grad = []
args_with_grad = []
for idx, output in enumerate(outputs):
if torch.is_tensor(output) and output.requires_grad:
outputs_with_grad.append(output)
args_with_grad.append(args[idx])
torch.autograd.backward(outputs_with_grad, args_with_grad)
grads = tuple(
inp.grad if isinstance(inp, torch.Tensor) else None
for inp in detached_inputs
)
return (None,) + grads


def checkpoint(function, *args, use_reentrant=True, **kwargs):
"""Checkpoint a model or part of the model. See PyTorch checkpoint
for details about behaviors and arguments. The only difference is
when the random seed is set by Slapo, the checkpoint function will
also track the random states and restore them properly.
TODO: The implementation in Megatron-LM has a mode to distribute
the saved activations across model parallel groups to further reduce
the memory footprint. This is not implemented here yet.
"""
if not is_random_seed_set():
return torch_checkpoint(function, *args, use_reentrant=use_reentrant, **kwargs)
return CheckpointFunctionWithRNGTracker.apply(function, *args)
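A minimal sketch of how this entry point might be used inside a module (purely illustrative; slapo.checkpoint falls back to torch.utils.checkpoint when Slapo has not seeded the RNGs, and otherwise routes through CheckpointFunctionWithRNGTracker):

import torch
from torch import nn
import slapo

class Block(nn.Module):
    def __init__(self, hidden):
        super().__init__()
        self.ffn = nn.Sequential(
            nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(0.1)
        )

    def forward(self, x):
        # Recompute self.ffn in the backward pass instead of storing its
        # activations; the CPU/CUDA RNG states and the tensor-parallel RNG
        # tracker are restored so the recomputed dropout mask matches the
        # forward pass.
        return slapo.checkpoint(self.ffn, x)
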
2 changes: 2 additions & 0 deletions slapo/op/__init__.py
@@ -1,3 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Custom Ops."""
from .cross_entropy import ParallelCrossEntropy
from .dropout import DropoutWithTensorParallel
24 changes: 24 additions & 0 deletions slapo/op/dropout.py
@@ -0,0 +1,24 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Dropout module."""

from torch import nn

from ..random import get_cuda_rng_tracker


class DropoutWithTensorParallel(nn.Dropout):
"""The dropout that supposed to be used in parallel region.
In parallel region means the original input tensor is partitioned
due to tensor parallelism or sequence parallelism. In this case,
we need to make sure the dropout on each device in the same
tensor parallel group has DIFFERENT random seed; otherwise each
partitioned tensor will have the same dropout mask, which may hurt
the convergence.
"""

# pylint: disable=redefined-builtin

def forward(self, input):
with get_cuda_rng_tracker().fork():
return super().forward(input)
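For example, a schedule could swap an existing dropout in a sharded region for this variant so that each tensor-parallel rank draws an independent mask. A hypothetical sketch following the .replace() pattern used in examples/gpt/schedule.py above (sch and the "mlp.dropout" path are placeholders, not part of this commit):

from slapo.op import DropoutWithTensorParallel

# Hypothetical submodule path inside an already-sharded block.
sch["mlp.dropout"].replace(DropoutWithTensorParallel(p=0.1))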