Commit
fix: support grad clipping for TP
Signed-off-by: Mehant Kammakomati <[email protected]>
kmehant committed Feb 20, 2025
1 parent 5412ff1 commit ebfb17d
Showing 3 changed files with 105 additions and 8 deletions.
8 changes: 6 additions & 2 deletions src/transformers/models/granite/configuration_granite.py
@@ -117,10 +117,14 @@ class GraniteConfig(PretrainedConfig):
         "layers.*.self_attn.q_proj": "colwise",
         "layers.*.self_attn.k_proj": "colwise",
         "layers.*.self_attn.v_proj": "colwise",
-        "layers.*.self_attn.o_proj": "rowwise",
+        "layers.*.self_attn.o_proj": "rowwise_output_dtensor",
         "layers.*.mlp.gate_proj": "colwise",
         "layers.*.mlp.up_proj": "colwise",
-        "layers.*.mlp.down_proj": "rowwise",
+        "layers.*.mlp.down_proj": "rowwise_output_dtensor",
+        "embed_tokens": "replicateparallel_output_dtensor",
+        "layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
+        "layers.*.input_layernorm": "replicateparallel_output_dtensor",
+        "norm": "replicateparallel_output_dtensor",
     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
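The new *_output_dtensor styles keep each module's output as a DTensor instead of converting it back to a plain local torch.Tensor, so activations stay in the same distributed representation as the DTensor parameters and their gradients; per the commit title and the new class docstring below, this is what allows a global gradient norm to be computed for clipping. A minimal sketch of how the plan strings above resolve, assuming torch >= 2.5 with torch.distributed available (see the translate_to_torch_parallel_style change below):

# Sketch only: how the new plan strings map to torch parallel styles.
from transformers.pytorch_utils import translate_to_torch_parallel_style

style = translate_to_torch_parallel_style("rowwise_output_dtensor")
# -> RowwiseParallel(use_local_output=False): the sharded linear's output stays a DTensor.

style = translate_to_torch_parallel_style("replicateparallel_output_dtensor")
# -> ReplicateParallel(use_local_output=False): replicated modules (embeddings, norms)
#    also return DTensors, so the whole tensor-parallel region is uniformly distributed.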
95 changes: 91 additions & 4 deletions src/transformers/pytorch_utils.py
@@ -14,7 +14,7 @@
 from __future__ import annotations
 
 import inspect
-from functools import lru_cache, wraps
+from functools import lru_cache, partial, wraps
 from typing import Callable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -45,9 +45,16 @@
 _torch_distributed_available = torch.distributed.is_available()
 
 if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor import (
+        DeviceMesh,
+        DTensor,
+        Placement,
+        Replicate,
+        distribute_module,
+    )
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
+        ParallelStyle,
         RowwiseParallel,
     )
 
@@ -344,8 +351,82 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     return torch.isin(elements, test_elements)
 
 
+# TODO need to add the __repr__ that shows that it is a replicate parallel
+# See https://github.com/pytorch/pytorch/issues/145726
+class ReplicateParallel(ParallelStyle):
+    """
+    Replicate a nn.Module.
+    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
+    A fully distributed model is needed for gradient clipping.
+
+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of the input tensor for the nn.Module; this is used to annotate the input tensor to
+            become a DTensor. If not specified, we assume the input tensor to be replicated.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output for the nn.Module; this is used to ensure the output of the nn.Module
+            has the user-desired layout. If not specified, we assume the output tensor to be replicated.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+
+    Returns:
+        A :class:`ParallelStyle` object that represents replication of the nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> from transformers.pytorch_utils import ReplicateParallel
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input and output of the "w1" Linear will be converted to replicated DTensors
+        >>>
+        >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
+        >>> ...
+    """
+
+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Replicate(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # nn.Linear and nn.Embedding have a single input; since this style replicates,
+        # support could be extended to other modules as well.
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+            input_tensor = input_tensor.trigger_wait()
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+        )
+
+
 def translate_to_torch_parallel_style(style: str):
     """
     In model configurations, we use a neutral type (string) to specify parallel
@@ -359,10 +440,16 @@ def translate_to_torch_parallel_style(style: str):
         return ColwiseParallel()
     elif style == "rowwise":
         return RowwiseParallel()
+    elif style == "rowwise_output_dtensor":
+        return RowwiseParallel(use_local_output=False)
     elif style == "colwise_rep":
         return ColwiseParallel(output_layouts=Replicate())
     elif style == "rowwise_rep":
         return RowwiseParallel(input_layouts=Replicate())
+    elif style == "replicateparallel":
+        return ReplicateParallel()
+    elif style == "replicateparallel_output_dtensor":
+        return ReplicateParallel(use_local_output=False)
     else:
         raise ValueError(f"Unsupported parallel style value: {style}")

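Below is a minimal usage sketch (not from the commit) of the new ReplicateParallel style combined with the stock colwise/rowwise styles. The toy module, mesh size, and tensors are hypothetical; it assumes torch >= 2.5 running under torchrun with 8 CUDA ranks. With use_local_output=False every submodule returns a DTensor, so, as the class docstring notes, the model is fully distributed and a global gradient norm can be taken for clipping:

# Sketch only: replicate a "norm" submodule while keeping DTensor outputs,
# then clip gradients globally. Run under torchrun with 8 processes; each
# rank is assumed to have selected its own CUDA device.
import torch
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

from transformers.pytorch_utils import ReplicateParallel  # added by this commit


class ToyModel(nn.Module):  # hypothetical module, stands in for a real decoder block
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(16, 64)
        self.down = nn.Linear(64, 16)
        self.norm = nn.LayerNorm(16)

    def forward(self, x):
        return self.norm(self.down(self.up(x)))


tp_mesh = init_device_mesh("cuda", (8,))  # 1-D tensor-parallel mesh over 8 ranks
model = parallelize_module(
    ToyModel(),
    tp_mesh,
    {
        "up": ColwiseParallel(use_local_output=False),      # output sharded DTensor
        "down": RowwiseParallel(use_local_output=False),    # output replicated DTensor
        "norm": ReplicateParallel(use_local_output=False),  # replicated params, DTensor output
    },
)

out = model(torch.randn(4, 16, device="cuda"))
out.sum().backward()
# With parameters and grads consistently DTensor, a global grad norm is well defined:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)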
10 changes: 8 additions & 2 deletions src/transformers/trainer.py
@@ -234,6 +234,7 @@
     AutocastKwargs,
     DistributedDataParallelKwargs,
     DistributedType,
+    TorchTensorParallelPlugin,
     load_fsdp_model,
     load_fsdp_optimizer,
     save_fsdp_model,
@@ -2309,7 +2310,9 @@ def _inner_training_loop(
         else:
             debug_overflow = DebugUnderflowOverflow(self.model)  # noqa
 
-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )
 
         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2364,7 +2367,10 @@ def _inner_training_loop(
             if self.use_apex:
                 model = self.accelerator.prepare(self.model)
             else:
-                model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                if delay_optimizer_creation:
+                    self.optimizer = self.accelerator.prepare(self.optimizer)
+                else:
+                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
         else:
             # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
             model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
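The two trainer hunks work together: when tensor parallelism is enabled, optimizer creation is delayed until after the model has been prepared (so the optimizer is built over the already-parallelized parameters), and then only the optimizer is passed through accelerator.prepare, so the model is not prepared twice. A simplified sketch of the resulting control flow; names approximate Trainer internals and create_optimizer is a hypothetical helper, not a verbatim excerpt:

# Sketch only: the preparation order the change enforces when TP is enabled.
delay_optimizer_creation = (
    is_sagemaker_mp_enabled or is_fsdp_xla_enabled or is_fsdp_enabled or is_tp_enabled
)

if delay_optimizer_creation:
    model = accelerator.prepare(model)           # TP plan applied; parameters become DTensors
    optimizer = create_optimizer(model)          # hypothetical; built from the prepared model's params
    optimizer = accelerator.prepare(optimizer)   # second hunk: prepare the optimizer alone
else:
    model, optimizer = accelerator.prepare(model, optimizer)  # previous behaviour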
