From ebfb17d45de39c1569317e50c03f5d5302d7ae81 Mon Sep 17 00:00:00 2001
From: Mehant Kammakomati
Date: Tue, 11 Feb 2025 17:51:55 +0530
Subject: [PATCH] fix: support grad clipping for TP

Signed-off-by: Mehant Kammakomati
---
 .../models/granite/configuration_granite.py |  8 +-
 src/transformers/pytorch_utils.py           | 95 ++++++++++++++++++-
 src/transformers/trainer.py                 | 10 +-
 3 files changed, 105 insertions(+), 8 deletions(-)

diff --git a/src/transformers/models/granite/configuration_granite.py b/src/transformers/models/granite/configuration_granite.py
index fc651a94e1bd..3ac1bb350529 100644
--- a/src/transformers/models/granite/configuration_granite.py
+++ b/src/transformers/models/granite/configuration_granite.py
@@ -117,10 +117,14 @@ class GraniteConfig(PretrainedConfig):
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.o_proj": "rowwise_output_dtensor",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.down_proj": "rowwise_output_dtensor",
+        "embed_tokens": "replicateparallel_output_dtensor",
+        "layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
+        "layers.*.input_layernorm": "replicateparallel_output_dtensor",
+        "norm": "replicateparallel_output_dtensor",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py
index c36adffd9722..dc95ba9d5e53 100644
--- a/src/transformers/pytorch_utils.py
+++ b/src/transformers/pytorch_utils.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 import inspect
-from functools import lru_cache, wraps
+from functools import lru_cache, partial, wraps
 from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -45,9 +45,16 @@
 _torch_distributed_available = torch.distributed.is_available()
 
 if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor import (
+        DeviceMesh,
+        DTensor,
+        Placement,
+        Replicate,
+        distribute_module,
+    )
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
+        ParallelStyle,
         RowwiseParallel,
     )
 
@@ -344,8 +351,82 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     return torch.isin(elements, test_elements)
 
 
-# TODO need to add the __repr__ that shows that it is a colwise parallel
-# See https://github.com/pytorch/pytorch/issues/145726
+class ReplicateParallel(ParallelStyle):
+    """
+    Replicate an nn.Module.
+    Users can compose it with other parallel styles like RowwiseParallel to achieve a fully distributed model.
+    A fully distributed model is needed for gradient clipping.
+
+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of the input tensor for the nn.Module; this is used to annotate the input tensor so
+            that it becomes a DTensor. If not specified, we assume the input tensor to be replicated.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output for the nn.Module; this is used to ensure the output of the nn.Module
+            has the user-desired layout. If not specified, we assume the output tensor to be replicated.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+
+    Returns:
+        A :class:`ParallelStyle` object that represents replication of the nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module, ReplicateParallel
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> ...
+        >>> m = Model(...)  # m is an nn.Module that contains a "w1" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input and output of the "w1" Linear will be converted to replicated DTensors
+        >>>
+        >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
+        >>> ...
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Replicate(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # nn.Linear and nn.Embedding take a single input; since this style simply
+        # replicates, support could be extended to other modules.
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+            input_tensor = input_tensor.trigger_wait()
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+        )
+
+
 def translate_to_torch_parallel_style(style: str):
     """
     In model configurations, we use a neutral type (string) to specify parallel
@@ -359,10 +440,16 @@ def translate_to_torch_parallel_style(style: str):
         return ColwiseParallel()
     elif style == "rowwise":
         return RowwiseParallel()
+    elif style == "rowwise_output_dtensor":
+        return RowwiseParallel(use_local_output=False)
     elif style == "colwise_rep":
         return ColwiseParallel(output_layouts=Replicate())
     elif style == "rowwise_rep":
         return RowwiseParallel(input_layouts=Replicate())
+    elif style == "replicateparallel":
+        return ReplicateParallel()
+    elif style == "replicateparallel_output_dtensor":
+        return ReplicateParallel(use_local_output=False)
     else:
         raise ValueError(f"Unsupported parallel style value: {style}")
 
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 4d958d30214a..bd8b97c4a80a 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -234,6 +234,7 @@
     AutocastKwargs,
     DistributedDataParallelKwargs,
     DistributedType,
+    TorchTensorParallelPlugin,
     load_fsdp_model,
     load_fsdp_optimizer,
     save_fsdp_model,
@@ -2309,7 +2310,9 @@ def _inner_training_loop(
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2364,7 +2367,10 @@ def _inner_training_loop(
                 if self.use_apex:
                     model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                    if delay_optimizer_creation:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
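
The plan changes above all serve one goal: after parallelization, every parameter of the model (attention and MLP projections via the colwise/rowwise styles, and now embeddings and norms via ReplicateParallel) is a DTensor on the same device mesh, so torch.nn.utils.clip_grad_norm_ can compute a single global gradient norm across TP ranks. The following minimal sketch shows that flow outside the Trainer; it is illustrative only and not taken from the patch. It assumes two GPUs launched with torchrun, PyTorch >= 2.5 (matching the version guard in pytorch_utils.py), and a hypothetical ToyMLP module standing in for a real model and its tp_plan.

# Illustrative sketch only, not code from this patch. Assumptions: two GPUs,
# launched with `torchrun --nproc-per-node 2 tp_clip_sketch.py`, PyTorch >= 2.5.
# `ToyMLP` and its layer names are hypothetical stand-ins for a real model.
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class ToyMLP(nn.Module):
    def __init__(self, hidden: int = 256):
        super().__init__()
        self.up_proj = nn.Linear(hidden, 4 * hidden, bias=False)
        self.down_proj = nn.Linear(4 * hidden, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


def main():
    torch.manual_seed(0)  # keep weights and inputs identical on every rank
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    mesh = init_device_mesh("cuda", (int(os.environ["WORLD_SIZE"]),))

    model = ToyMLP().cuda()
    # `use_local_output=False` keeps the row-wise output as a DTensor, which is
    # what the new "rowwise_output_dtensor" style translates to in the patch.
    model = parallelize_module(
        model,
        mesh,
        {
            "up_proj": ColwiseParallel(),
            "down_proj": RowwiseParallel(use_local_output=False),
        },
    )

    x = torch.randn(8, 256, device="cuda")
    out = model(x)  # a replicated DTensor
    loss = out.to_local().pow(2).mean()
    loss.backward()

    # Every parameter (and therefore every gradient) is a DTensor on the same
    # mesh, so one global grad norm is well defined and standard clipping works
    # across TP ranks; mixing DTensor and plain-tensor grads would not.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    print(f"rank {mesh.get_rank()}: total grad norm = {total_norm}")


if __name__ == "__main__":
    main()

Inside the Trainer the same effect comes from the tp_plan strings: translate_to_torch_parallel_style maps "rowwise_output_dtensor" and "replicateparallel_output_dtensor" to the corresponding styles with use_local_output=False, as added in the pytorch_utils.py hunk above.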