diff --git a/ci/task_unit_test.sh b/ci/task_unit_test.sh index b069f9c5..9f08cd89 100644 --- a/ci/task_unit_test.sh +++ b/ci/task_unit_test.sh @@ -11,9 +11,10 @@ nvidia-smi -L echo "Running unit tests..." # -r: redirect the output of local rank 1 to None so that # only local rank 0's output is printed to the console. -torchrun --nproc_per_node 2 -r 1:1 -m pytest tests +# -p "no:randomly": disable randomly plugin for sharding tests. +torchrun --nproc_per_node 2 -r 1:1 -m pytest -p "no:randomly" tests echo "Downloading test data..." bash benchmark/download_benchmark_dataset.sh echo "Running end-to-end tests..." -python3 -m pytest -s tests/end2end.py +python3 -m pytest -s -p "no:randomly" tests/end2end.py diff --git a/conftest.py b/conftest.py index f22b9c08..d0dcbfa7 100644 --- a/conftest.py +++ b/conftest.py @@ -1,8 +1,9 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import random +import numpy as np import pytest - import torch from torch import distributed as dist @@ -18,7 +19,6 @@ def pytest_collection_modifyitems(items): @pytest.fixture(scope="session") def init_dist(request): """Initialize the distributed group once in the entire test session.""" - torch.manual_seed(9999) try: dist.init_process_group(backend="nccl") except Exception as err: @@ -31,3 +31,17 @@ def destory_dist(): pass request.addfinalizer(destory_dist) + + +@pytest.fixture(scope="function", autouse=True) +def random_seed(): + """Set random seed to 1) make the tests deterministic, and 2) make every + device generate the same weights for tensor parallelism tests. + + Note that if you run pytest with "randomly" plugin enabled, this fixture + will have no effect. You can disable the plugin with + pytest -p "no:randomly" ... 
+ """ + random.seed(9999) + np.random.seed(9999) + torch.manual_seed(9999) diff --git a/examples/gpt/deepspeed_hf.py b/examples/gpt/deepspeed_hf.py index b1b8a8ce..0f594b48 100644 --- a/examples/gpt/deepspeed_hf.py +++ b/examples/gpt/deepspeed_hf.py @@ -8,13 +8,13 @@ import deepspeed import torch import torch.distributed as dist -from torch.distributed.distributed_c10d import _get_global_rank from deepspeed.utils import RepeatingLoader from transformers import GPTNeoForCausalLM, AutoConfig import slapo +from slapo import set_random_seed from slapo.logger import get_logger -from slapo.op.cross_entropy import ParallelCrossEntropy +from slapo.op import ParallelCrossEntropy from slapo.utils.report import report_memory from model import schedule_model @@ -127,6 +127,7 @@ def train(args): delay_init=enable_pipeline, sequence_parallel=args.sequence_parallel, ) + tp_rank = sch.rank loss_fct = ParallelCrossEntropy(group=group) @@ -188,16 +189,14 @@ def loss_fn(outputs, labels): report_memory(msg="After building model") if args.disable_pipeline or args.sequence_parallel: - random_seed = 2013 + dist.get_rank() + set_random_seed(2013, model.mpu.get_data_parallel_rank(), None, tp_rank) else: - random_seed = ( - 2013 - + 100 * model.mpu.get_pipe_parallel_rank() - + 10 * model.mpu.get_data_parallel_rank() + set_random_seed( + 2013, + model.mpu.get_data_parallel_rank(), + model.mpu.get_pipe_parallel_rank(), + tp_rank, ) - random.seed(random_seed) - np.random.seed(random_seed) - torch.manual_seed(random_seed) # for now always use seq_length 1024 # TODO: make the dataloader generic to different sequence length diff --git a/examples/gpt/megatron_hf.py b/examples/gpt/megatron_hf.py index a23da218..ca93bf7f 100644 --- a/examples/gpt/megatron_hf.py +++ b/examples/gpt/megatron_hf.py @@ -60,6 +60,10 @@ def get_model( delay_init=delay_init, ) model, _ = slapo.build(sch, init_weights=model._init_weights) + # Note 1: We assume no DP and PP in this script. + # Note 2: This overrides Megatron random seed management, so we only use + # this script for benchmarking. + slapo.set_random_seed(2013, None, None, sch.rank) report_memory() elif impl == "torchscript": diff --git a/examples/gpt/model.py b/examples/gpt/model.py index 54da7843..13fbb333 100644 --- a/examples/gpt/model.py +++ b/examples/gpt/model.py @@ -51,7 +51,7 @@ def schedule_model( # if MP group > 1. attn_path, out_proj_name = "h.N.attn.attention", "out_proj" if disable_flash_attn: - logger.info("Disabled Flash Attention", rank=0) + logger.info("Disabled Flash Attention", ranks=0) cnt = replace_and_shard_attention( sch[prefix], config, diff --git a/examples/gpt/schedule.py b/examples/gpt/schedule.py index 5b5bc6cb..0a129e44 100644 --- a/examples/gpt/schedule.py +++ b/examples/gpt/schedule.py @@ -11,6 +11,7 @@ from slapo import init_empty_weights from slapo.pattern import call_module from slapo.op.linear import FusedQKV +from slapo import init_empty_weights, get_cuda_rng_tracker def trace_attention(sch, config, attn_path="h.N.attn.attention"): @@ -89,7 +90,7 @@ def replace_and_shard_attention( sequence_parallel=False, ): from epoi.inject.policy.gpt import InjectHFGPTAttentionPolicy - from epoi.ops.xformers_attn import GenericSelfAttention + from epoi.ops.xformers_attn import GenericSelfAttention, MemoryEfficientAttentionOp class SelfAttention(nn.Module): """A wrapper to align the original GPTNeoAttention forward signature.""" @@ -112,6 +113,13 @@ def forward( # is present_key_value and only used by in inference. 
return outputs[:1] + class MemoryEfficientAttentionWithRNGOp(MemoryEfficientAttentionOp): + def forward(self, query_layer, key_layer, value_layer, attention_mask, p): + with get_cuda_rng_tracker().fork(): + return super().forward( + query_layer, key_layer, value_layer, attention_mask, p + ) + num_layers, num_heads, hidden_size = ( config.num_layers, config.num_heads, @@ -163,6 +171,7 @@ def pattern(x: torch.Tensor) -> torch.Tensor: mode="fwd_post", sync_op_or_fn="reduce_scatter", axis=1 ) else: + # Shard qkv and output projection. sub_sch["module.FusedQKV_0.fused_linear"].sync( mode="bwd_post", sync_op_or_fn="all_reduce" ) @@ -170,6 +179,14 @@ def pattern(x: torch.Tensor) -> torch.Tensor: mode="fwd_post", sync_op_or_fn="all_reduce" ) + # In this case, the attention dropout in between has to + # use different random seeds. + new_op = MemoryEfficientAttentionWithRNGOp( + sub_sch["module"]["attn_op"].mod.attn_op_name, + sub_sch["module"]["attn_op"].mod.apply_causal_mask, + ) + sub_sch["module"]["attn_op"].replace(new_op) + cnt += 1 return cnt diff --git a/slapo/__init__.py b/slapo/__init__.py index 37e6b9ca..b06375d1 100644 --- a/slapo/__init__.py +++ b/slapo/__init__.py @@ -9,3 +9,5 @@ from .tracer import * from .utils import * from .version import __version__ +from .random import set_random_seed, get_cuda_rng_tracker, is_random_seed_set +from .checkpoint import checkpoint diff --git a/slapo/checkpoint.py b/slapo/checkpoint.py new file mode 100644 index 00000000..6f866edf --- /dev/null +++ b/slapo/checkpoint.py @@ -0,0 +1,124 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +# Modification: Megatron-LM. +# See https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py +"""Model checkpoints and activation checkpointing with the consideration +of 3D parallelism and random states. +""" +import torch +from torch.utils.checkpoint import detach_variable +from torch.utils.checkpoint import checkpoint as torch_checkpoint + +from .random import get_cuda_rng_tracker, is_random_seed_set, _set_cuda_rng_state + + +class CheckpointFunctionWithRNGTracker(torch.autograd.Function): + """This function is adapted from torch.utils.checkpoint with + two main changes: + 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` + 2) the states in the model parallel tracker are also properly + tracked/set/reset. + """ + + # pylint: disable=abstract-method, arguments-differ + + @staticmethod + def forward(ctx, run_function, *args): + ctx.run_function = run_function + + # Copy the rng states. + ctx.fwd_cpu_rng_state = torch.get_rng_state() + ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() + ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() + + with torch.no_grad(): + outputs = run_function(*args) + + # Save non-tensor inputs in ctx, keep a placeholder None for tensors + # to be filled out during the backward. + ctx.inputs = [] + ctx.tensor_indices = [] + tensor_inputs = [] + for idx, arg in enumerate(args): + if torch.is_tensor(arg): + tensor_inputs.append(arg) + ctx.tensor_indices.append(idx) + ctx.inputs.append(None) + else: + ctx.inputs.append(arg) + + # We detach the tensor inputs to make sure we hold a reference to + # the tensor data. This is needed because when pipeline is enabled, + # the tensor data may be released by the pipeline engine as it does + # not know that the tensor is used in the backward pass. 
+        ctx.save_for_backward(*detach_variable(tuple(tensor_inputs)))
+
+        return outputs
+
+    @staticmethod
+    def backward(ctx, *args):
+        if not torch.autograd._is_checkpoint_valid():
+            raise RuntimeError(
+                "Checkpointing is not compatible with .grad(), "
+                "please use .backward() if possible"
+            )
+        # Copy the list to avoid modifying the original list.
+        inputs = list(ctx.inputs)
+        tensor_indices = ctx.tensor_indices
+        tensors = ctx.saved_tensors
+
+        # Fill in inputs with the appropriate saved tensors.
+        for idx, tidx in enumerate(tensor_indices):
+            inputs[tidx] = tensors[idx]
+
+        # Store the current states.
+        bwd_cpu_rng_state = torch.get_rng_state()
+        bwd_cuda_rng_state = torch.cuda.get_rng_state()
+        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()
+
+        # Set the states to what they were before the forward pass.
+        torch.set_rng_state(ctx.fwd_cpu_rng_state)
+        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
+        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)
+
+        # Compute the forward pass.
+        detached_inputs = detach_variable(tuple(inputs))
+        with torch.enable_grad():
+            outputs = ctx.run_function(*detached_inputs)
+
+        # Set the states back to what they were at the start of this function.
+        torch.set_rng_state(bwd_cpu_rng_state)
+        _set_cuda_rng_state(bwd_cuda_rng_state)
+        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)
+
+        if isinstance(outputs, torch.Tensor):
+            outputs = (outputs,)
+
+        # Run backward() with only the tensors that require grad.
+        outputs_with_grad = []
+        args_with_grad = []
+        for idx, output in enumerate(outputs):
+            if torch.is_tensor(output) and output.requires_grad:
+                outputs_with_grad.append(output)
+                args_with_grad.append(args[idx])
+        torch.autograd.backward(outputs_with_grad, args_with_grad)
+        grads = tuple(
+            inp.grad if isinstance(inp, torch.Tensor) else None
+            for inp in detached_inputs
+        )
+        return (None,) + grads
+
+
+def checkpoint(function, *args, use_reentrant=True, **kwargs):
+    """Checkpoint a model or part of the model. See PyTorch's
+    torch.utils.checkpoint for details about behaviors and arguments.
+    The only difference is that when the random seed is set by Slapo,
+    the checkpoint function also tracks the random states and restores
+    them properly.
+
+    TODO: The implementation in Megatron-LM has a mode to distribute
+    the saved activations across model parallel groups to further reduce
+    the memory footprint. This is not implemented here yet.
+    """
+    if not is_random_seed_set():
+        return torch_checkpoint(function, *args, use_reentrant=use_reentrant, **kwargs)
+    return CheckpointFunctionWithRNGTracker.apply(function, *args)
diff --git a/slapo/op/__init__.py b/slapo/op/__init__.py
index 19b4229f..aa09d5cc 100644
--- a/slapo/op/__init__.py
+++ b/slapo/op/__init__.py
@@ -1,3 +1,5 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 """Custom Ops."""
+from .cross_entropy import ParallelCrossEntropy
+from .dropout import DropoutWithTensorParallel
diff --git a/slapo/op/dropout.py b/slapo/op/dropout.py
new file mode 100644
index 00000000..9e356e68
--- /dev/null
+++ b/slapo/op/dropout.py
@@ -0,0 +1,24 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""Dropout module."""
+
+from torch import nn
+
+from ..random import get_cuda_rng_tracker
+
+
+class DropoutWithTensorParallel(nn.Dropout):
+    """Dropout that should be used in a parallel region.
+    "In a parallel region" means the original input tensor is partitioned
+    due to tensor parallelism or sequence parallelism. In this case,
+    we need to make sure the dropout on each device in the same
+    tensor parallel group uses a DIFFERENT random seed; otherwise each
+    partition would get the same dropout mask, which may hurt
+    convergence.
+    """
+
+    # pylint: disable=redefined-builtin
+
+    def forward(self, input):
+        with get_cuda_rng_tracker().fork():
+            return super().forward(input)
diff --git a/slapo/random.py b/slapo/random.py
new file mode 100644
index 00000000..60a2526f
--- /dev/null
+++ b/slapo/random.py
@@ -0,0 +1,214 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+# Modification: Megatron-LM.
+# See https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
+"""Random seed and state management."""
+
+import contextlib
+import random
+
+import numpy as np
+import torch
+from torch.cuda import _lazy_call
+
+# Default name for the model parallel rng tracker.
+_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
+
+
+def _set_cuda_rng_state(new_state, device=-1):
+    """Sets the random number generator state of the current GPU.
+    This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
+    with a single change: the input state is not cloned. Cloning caused
+    major performance issues for 4+ GPU cases.
+
+    Parameters
+    ----------
+    new_state : torch.ByteTensor
+        The desired state.
+
+    device : int
+        The GPU device to set the state for. If -1, the current device is used.
+    """
+    if device == -1:
+        device = torch.device("cuda")
+    elif isinstance(device, str):
+        device = torch.device(device)
+    elif isinstance(device, int):
+        device = torch.device("cuda", device)
+
+    def cb():
+        idx = device.index
+        if idx is None:
+            idx = torch.cuda.current_device()
+        default_generator = torch.cuda.default_generators[idx]
+        default_generator.set_state(new_state)
+
+    _lazy_call(cb)
+
+
+class CudaRNGStatesTracker:
+    """Tracker for the cuda RNG states.
+    Using the `add` method, a cuda rng state is initialized based on
+    the input `seed` and is assigned to `name`. Later, by forking the
+    rng state, we can perform operations and return to our starting
+    cuda state.
+    """
+
+    def __init__(self):
+        # Map from a string name to the cuda rng state.
+        self.states_ = {}
+        # Seeds are just for bookkeeping and ensure no seed is set twice.
+        self.seeds_ = set()
+
+    def reset(self):
+        """Set to the initial state (no tracker)."""
+        self.states_ = {}
+        self.seeds_ = set()
+
+    def get_states(self):
+        """Get rng states. Copy the dictionary so we have direct
+        pointers to the states, not just a pointer to the dictionary."""
+        states = {}
+        for name in self.states_:
+            states[name] = self.states_[name]
+        return states
+
+    def set_states(self, states):
+        """Set the rng states. For efficiency purposes, we do not check
+        the seed sizes for compatibility."""
+        self.states_ = states
+
+    def add(self, name, seed):
+        """Track the rng state."""
+        # Check that the seed is not already used.
+        if seed in self.seeds_:
+            raise Exception(f"seed {seed} already exists")
+        self.seeds_.add(seed)
+        # Check that the state is not already defined.
+        if name in self.states_:
+            raise Exception(f"cuda rng state {name} already exists")
+        # Get the current rng state.
+        orig_rng_state = torch.cuda.get_rng_state()
+        # Set the new state and store it.
+        torch.cuda.manual_seed(seed)
+        self.states_[name] = torch.cuda.get_rng_state()
+        # Reset the rng state to what it was.
+        _set_cuda_rng_state(orig_rng_state)
+
+    @contextlib.contextmanager
+    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
+        """Fork the cuda rng state, perform operations, and exit with
+        the original state."""
+        # Check that we have added the state.
+        if name not in self.states_:
+            raise RuntimeError(
+                f"cuda rng state {name} is not added. "
+                "Did you call 'set_random_seed'?"
+            )
+        # Store the current rng state.
+        orig_cuda_rng_state = torch.cuda.get_rng_state()
+        # Set the rng state to the desired one.
+        _set_cuda_rng_state(self.states_[name])
+        # Do the stuff we wanted to do.
+        try:
+            yield
+        finally:
+            # Update the current rng state for later use.
+            self.states_[name] = torch.cuda.get_rng_state()
+            # And set the state to the original state we started with.
+            _set_cuda_rng_state(orig_cuda_rng_state)
+
+
+# RNG tracker object.
+_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()
+
+
+def get_cuda_rng_tracker():
+    """Get the cuda rng tracker."""
+    return _CUDA_RNG_STATE_TRACKER
+
+
+def model_parallel_cuda_manual_seed(seed, tp_rank):
+    """Initialize the model parallel cuda seed.
+    This function should be called after model parallelism is
+    initialized. Also, no torch.cuda.manual_seed should be called
+    after this function. Basically, this is a replacement for that
+    function.
+    Two sets of RNG states are tracked:
+    default state: This is for data parallelism and is the same among a
+        set of model parallel GPUs but different across different model
+        parallel groups. This is used, for example, for dropout in
+        non-tensor-model-parallel regions.
+    tensor-model-parallel state: This state is different among a set of
+        model parallel GPUs, but the same across data parallel groups.
+        This is used, for example, for dropout in model parallel regions.
+
+    Parameters
+    ----------
+    seed : int
+        Random seed.
+    tp_rank : int
+        Tensor model parallel rank.
+
+    Returns
+    -------
+    int
+        Tensor model parallel seed of this rank.
+    """
+    # 2718 is just for fun and any POSITIVE value will work.
+    tensor_model_parallel_seed = seed + 2718 + tp_rank
+
+    _CUDA_RNG_STATE_TRACKER.reset()
+    # Set the default state.
+    torch.cuda.manual_seed(seed)
+    # and the model parallel state.
+    _CUDA_RNG_STATE_TRACKER.add(
+        _MODEL_PARALLEL_RNG_TRACKER_NAME, tensor_model_parallel_seed
+    )
+    return tensor_model_parallel_seed
+
+
+def is_random_seed_set():
+    """Check if the random seed is set."""
+    return bool(_CUDA_RNG_STATE_TRACKER.get_states())
+
+
+def set_random_seed(seed=2013, dp_rank=None, pp_rank=None, tp_rank=None):
+    """Set the random seed for reproducibility.
+
+    Parameters
+    ----------
+    seed : int
+        Random seed. Default is 2013.
+    dp_rank : Optional[int]
+        Data parallel rank. Default is None, which means no data parallelism.
+    pp_rank : Optional[int]
+        Pipeline parallel rank. Default is None, which means no pipeline
+        parallelism.
+    tp_rank : Optional[int]
+        Tensor model parallel rank. Default is None, which means no tensor
+        parallelism.
+
+    Returns
+    -------
+    int
+        Random seed of this rank.
+    """
+    # Ensure each pipeline stage uses a different seed.
+    if pp_rank is not None:
+        seed += 100 * pp_rank
+
+    # Ensure each data parallel group uses a different seed.
+    if dp_rank is not None:
+        seed += 10 * dp_rank
+
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+
+    # In the above cases, devices in the same TP group share the same seed.
+    # However, in some cases we need them to have different seeds, so here
+    # we maintain a separate seed for each device in the TP group.
+    if torch.cuda.device_count() > 0 and tp_rank is not None:
+        model_parallel_cuda_manual_seed(seed, tp_rank)
+
+    return seed
diff --git a/slapo/schedule.py b/slapo/schedule.py
index fc6ea765..8549dcb8 100644
--- a/slapo/schedule.py
+++ b/slapo/schedule.py
@@ -16,11 +16,11 @@ import torch
 import torch.distributed as dist
 from torch import fx, nn
-from torch.utils import checkpoint # pylint: disable=unused-import
 import torch.nn.functional as F
 
+from .checkpoint import checkpoint as checkpoint_module
 from .logger import get_logger
 from .model_dialect import get_dialect_cls
 from .pipeline import (
@@ -987,7 +987,7 @@ def forward(self, *args, **kwargs):
                 ordered_args = order_args_fn(*args, **kwargs)
 
                 # Note: checkpoint cannot accept kwargs
-                return checkpoint.checkpoint(self.mod, *ordered_args)
+                return checkpoint_module(self.mod, *ordered_args)
 
         self.replace(CheckPointWrapper(self.mod))
 
diff --git a/tests/test_checkpoint.py b/tests/test_checkpoint.py
new file mode 100644
index 00000000..9496b49f
--- /dev/null
+++ b/tests/test_checkpoint.py
@@ -0,0 +1,84 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Test activation checkpointing. Note that this test has to be invoked by torchrun.
+See ci/task_unit_test.sh for an example.
+"""
+# pylint: disable=unused-argument
+
+import os
+
+import pytest
+import torch
+from torch import distributed as dist
+from torch.autograd import Variable
+
+from slapo import checkpoint, get_cuda_rng_tracker, set_random_seed
+from slapo.sharding import reduce_backward_grad, reduce_forward_output
+
+
+def test_activation_checkpoint_with_rng_states(init_dist):
+    world_size = dist.get_world_size()
+    full_size = 5 * world_size
+
+    tp_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(tp_rank)
+
+    class Model(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear1 = torch.nn.Linear(full_size, 5, bias=False)
+            self.dropout1 = torch.nn.Dropout(0.5)
+            self.linear2 = torch.nn.Linear(5, full_size, bias=False)
+            self.dropout2 = torch.nn.Dropout(0.5)
+
+        def orig_forward(self, x):
+            x = reduce_backward_grad(x, None)
+            x = self.linear1(x)
+            # The output of linear1 is partitioned, so we use different seeds.
+            with get_cuda_rng_tracker().fork():
+                x = self.dropout1(x)
+            x = self.linear2(x)
+            # The output of linear2 is a partial sum, so we use the same seed.
+            x = self.dropout2(x)
+            x = reduce_forward_output(x, None)
+            return x
+
+        def forward(self, x, enable_checkpoint):
+            if enable_checkpoint:
+                return checkpoint(self.orig_forward, x)
+            return self.orig_forward(x)
+
+    data = torch.randn((full_size, full_size), requires_grad=True).cuda(tp_rank)
+    dist.broadcast(data, src=0)
+    data = Variable(data, requires_grad=True)
+
+    model = Model().cuda(tp_rank)
+
+    def run(model, data, enable_checkpoint):
+        # 1. Run the model forward and backward.
+        out = model(data, enable_checkpoint)
+        out.mean().backward()
+        # 2. Retrieve gradients.
+        linear1_weight_grad = model.linear1.weight.grad.clone()
+        linear2_weight_grad = model.linear2.weight.grad.clone()
+        input_grad = data.grad.clone()
+        # 3. Clear gradients.
+        model.linear1.weight.grad = None
+        model.linear2.weight.grad = None
+        data.grad = None
+        return out, linear1_weight_grad, linear2_weight_grad, input_grad
+
+    # Run the model without activation checkpointing for reference.
+    set_random_seed(123, None, None, tp_rank)
+    refs = run(model, data, False)
+    # Run the model with activation checkpointing.
+    set_random_seed(123, None, None, tp_rank)
+    outs = run(model, data, True)
+
+    for ref, out in zip(refs, outs):
+        torch.testing.assert_close(out, ref)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/test_op.py b/tests/test_op.py
new file mode 100644
index 00000000..6b994904
--- /dev/null
+++ b/tests/test_op.py
@@ -0,0 +1,56 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+"""
+Test custom ops. Note that this test has to be invoked by torchrun since
+most custom ops are for tensor parallelism.
+"""
+# pylint: disable=unused-argument
+import os
+import pytest
+
+import torch
+from torch import nn
+from torch import distributed as dist
+
+from slapo import op
+from slapo.random import set_random_seed, get_cuda_rng_tracker
+
+
+def test_dropout(init_dist):
+    world_size = dist.get_world_size()
+    local_rank = int(os.environ["LOCAL_RANK"])
+    torch.cuda.set_device(local_rank)
+
+    data = torch.rand(10, 10).cuda(local_rank)
+    dist.broadcast(data, src=0)
+
+    get_cuda_rng_tracker().reset()
+
+    # The custom dropout should throw an error if set_random_seed is not called.
+    with pytest.raises(Exception):
+        op.DropoutWithTensorParallel(p=0.5)(data)
+
+    set_random_seed(123, tp_rank=local_rank)
+
+    # Assuming all devices are in the same TP group, the native dropout
+    # should produce the same output on all devices.
+    out = nn.Dropout(p=0.5)(data)
+    out_reduced = out.clone()
+    dist.all_reduce(out_reduced)
+    torch.testing.assert_close(
+        out * world_size,
+        out_reduced,
+        msg=lambda msg: f"output mismatch\n{msg}",
+    )
+
+    # The custom dropout should produce different outputs on different devices
+    # even if they are in the same TP group.
+    out = op.DropoutWithTensorParallel(p=0.5)(data)
+    out_reduced = out.clone()
+    dist.all_reduce(out_reduced)
+    with pytest.raises(Exception):
+        torch.testing.assert_close(out * world_size, out_reduced)
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/test_random.py b/tests/test_random.py
new file mode 100644
index 00000000..6850d2a0
--- /dev/null
+++ b/tests/test_random.py
@@ -0,0 +1,65 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+# Modification: Megatron-LM.
+# See https://github.com/NVIDIA/Megatron-LM/blob/main/tests/tensor_parallel/test_random.py
+"""
+Test random state management. Note that this test has to be invoked by torchrun.
+See ci/task_unit_test.sh for an example.
+"""
+
+import os
+
+import pytest
+import torch
+
+from slapo.random import (
+    _CUDA_RNG_STATE_TRACKER,
+    CudaRNGStatesTracker,
+    get_cuda_rng_tracker,
+    model_parallel_cuda_manual_seed,
+)
+
+
+def test_cuda_rng_states_tracker():
+    rng_tracker = CudaRNGStatesTracker()
+    rng_tracker.set_states({"state1": 1234})
+    assert rng_tracker.get_states()["state1"] == 1234
+
+    rng_tracker.reset()
+    assert not rng_tracker.get_states()
+
+    seed = 1111
+    rng_tracker.add("state2", seed)
+    with pytest.raises(Exception):
+        assert rng_tracker.add("state3", seed)
+
+    with pytest.raises(Exception):
+        assert rng_tracker.add("state2", 111)
+
+    assert rng_tracker.get_states()["state2"] is not None
+    with pytest.raises(Exception):
+        assert ()
+
+    rng_tracker.fork("state2")
+    torch.cuda.manual_seed(seed)
+    rng_state = torch.cuda.get_rng_state()
+    assert torch.equal(rng_tracker.get_states()["state2"], rng_state)
+
+
+def test_model_parallel_seed():
+    assert torch.cuda.initial_seed() != 123
+
+    local_rank = int(os.environ["LOCAL_RANK"])
+    tp_seed = model_parallel_cuda_manual_seed(123, tp_rank=local_rank)
+    assert _CUDA_RNG_STATE_TRACKER.get_states()["model-parallel-rng"] is not None
+
+    # Outside the context, the seed should be the same.
+    assert torch.cuda.initial_seed() == 123
+
+    # Inside the context, the seed should be different.
+    with get_cuda_rng_tracker().fork():
+        assert torch.cuda.initial_seed() == tp_seed
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])
diff --git a/tests/test_shard.py b/tests/test_shard.py
index ddf09434..c917ce7f 100644
--- a/tests/test_shard.py
+++ b/tests/test_shard.py
@@ -2,8 +2,8 @@
 # SPDX-License-Identifier: Apache-2.0
 
 """
-Test sharding primitive. Note that this test has to be invoked by torchrun. For example:
-torchrun --nproc_per_node 2 -m pytest test_shard.py
+Test sharding primitive. Note that this test has to be invoked by torchrun.
+See ci/task_unit_test.sh for an example.
 """
 # pylint: disable=unused-argument
 import os
diff --git a/tests/test_shard_sync_op.py b/tests/test_shard_sync_op.py
index ea69c364..fc1cd3cc 100644
--- a/tests/test_shard_sync_op.py
+++ b/tests/test_shard_sync_op.py
@@ -3,7 +3,7 @@
 """
 Test sync ops for sharding. Note that this test has to be invoked by torchrun.
-For example: torchrun --nproc_per_node 2 -m pytest test_shard_sync_op.py
+See ci/task_unit_test.sh for an example.
 """
 # pylint: disable=unused-argument
 import copy
@@ -20,9 +20,6 @@ def init_model_and_data(local_rank):
     model = nn.Linear(10, 10).cuda(local_rank)
-    # Make sure all devices have the same model.
-    dist.broadcast(model.weight.data, src=0)
-    dist.broadcast(model.bias.data, src=0)
     ref_model = copy.deepcopy(model)
     data = torch.randn((10, 10), requires_grad=True).cuda(local_rank)
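
Appendix (not part of the patch): a minimal sketch of how the new random-state APIs compose. It assumes a single tensor-parallel group spanning all ranks launched by torchrun; the module and file names here (TinyParallelBlock, example.py) are hypothetical, and the linear layers are not actually sharded; only the RNG handling introduced by this patch is exercised.

# example.py -- illustrative only; launch with: torchrun --nproc_per_node 2 example.py
import os

import torch
from torch import nn
import torch.distributed as dist

import slapo
from slapo.op import DropoutWithTensorParallel


class TinyParallelBlock(nn.Module):
    """Hypothetical block mimicking a tensor-parallel MLP's RNG needs."""

    def __init__(self, hidden):
        super().__init__()
        self.fc1 = nn.Linear(hidden, hidden)
        # Pretend fc1's output is partitioned across TP ranks, so its dropout
        # must draw from the per-rank ("model-parallel") RNG state.
        self.drop1 = DropoutWithTensorParallel(p=0.1)
        self.fc2 = nn.Linear(hidden, hidden)
        # Pretend fc2's output is a partial sum with the same layout on every
        # rank, so the default (shared) RNG state is the right one here.
        self.drop2 = nn.Dropout(p=0.1)

    def _body(self, x):
        return self.drop2(self.fc2(self.drop1(self.fc1(x))))

    def forward(self, x):
        # slapo.checkpoint saves and restores both the default and the
        # model-parallel RNG states, so the recomputed forward pass produces
        # the same dropout masks as the original one.
        return slapo.checkpoint(self._body, x)


def main():
    dist.init_process_group(backend="nccl")
    tp_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(tp_rank)

    # No DP/PP in this sketch; only the TP rank differentiates the seeds.
    slapo.set_random_seed(2013, dp_rank=None, pp_rank=None, tp_rank=tp_rank)
    assert slapo.is_random_seed_set()

    model = TinyParallelBlock(hidden=32).cuda(tp_rank)
    x = torch.randn(4, 32, device=f"cuda:{tp_rank}", requires_grad=True)
    model(x).sum().backward()


if __name__ == "__main__":
    main()

Run this way, drop1 should produce a different mask on each rank while drop2 produces the same mask on all ranks, and the checkpointed recomputation reproduces both, which is the behavior the tests above (test_op.py and test_checkpoint.py) verify.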