fix: support grad clipping for TP through replicating non-sharded modules #36132

Status: Open · wants to merge 1 commit into base: main
8 changes: 6 additions & 2 deletions src/transformers/models/granite/configuration_granite.py
@@ -117,10 +117,14 @@ class GraniteConfig(PretrainedConfig):
"layers.*.self_attn.q_proj": "colwise",
"layers.*.self_attn.k_proj": "colwise",
"layers.*.self_attn.v_proj": "colwise",
"layers.*.self_attn.o_proj": "rowwise",
"layers.*.self_attn.o_proj": "rowwise_output_dtensor",
"layers.*.mlp.gate_proj": "colwise",
"layers.*.mlp.up_proj": "colwise",
"layers.*.mlp.down_proj": "rowwise",
"layers.*.mlp.down_proj": "rowwise_output_dtensor",
"embed_tokens": "replicateparallel_output_dtensor",
"layers.*.post_attention_layernorm": "replicateparallel_output_dtensor",
"layers.*.input_layernorm": "replicateparallel_output_dtensor",
"norm": "replicateparallel_output_dtensor",
Contributor comment on lines +120 to +127:
Thanks for extending the configs here.
I wonder if some of these settings would be more interesting to training than to inference?
(On the other hand, I don't know much about HF's user profile -- training more or inference more?)
If some of the settings are specific to training, is it possible to separate them out? Or, shall we make the config somehow customizable at run time?

     }
     base_model_pp_plan = {
         "embed_tokens": (["input_ids"], ["inputs_embeds"]),
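On the reviewer's question about making the plan customizable at run time: one conceivable approach is to override `base_model_tp_plan` on the loaded config before the model is sharded, keeping the cheap local-output styles for inference and the DTensor-output styles for training. The snippet below is only a minimal sketch, not part of this PR; it assumes that an instance-level `base_model_tp_plan` is honored by `from_pretrained(..., tp_plan="auto")`, and the checkpoint name is a placeholder.

```python
# Hypothetical runtime override of the TP plan (not part of this PR).
# Assumes from_pretrained(..., tp_plan="auto") reads config.base_model_tp_plan.
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "ibm-granite/granite-3.0-8b-instruct"  # placeholder checkpoint name
config = AutoConfig.from_pretrained(model_id)

training = True
if training:
    # Training wants every parameter/gradient to be a DTensor so a single
    # clip_grad_norm_ call can compute the global norm.
    config.base_model_tp_plan = {
        **config.base_model_tp_plan,
        "layers.*.self_attn.o_proj": "rowwise_output_dtensor",
        "layers.*.mlp.down_proj": "rowwise_output_dtensor",
        "norm": "replicateparallel_output_dtensor",
    }
else:
    # Inference can keep the cheaper local-output styles.
    config.base_model_tp_plan = {
        **config.base_model_tp_plan,
        "layers.*.self_attn.o_proj": "rowwise",
        "layers.*.mlp.down_proj": "rowwise",
    }

model = AutoModelForCausalLM.from_pretrained(model_id, config=config, tp_plan="auto")
```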
95 changes: 91 additions & 4 deletions src/transformers/pytorch_utils.py
@@ -14,7 +14,7 @@
 from __future__ import annotations

 import inspect
-from functools import lru_cache, wraps
+from functools import lru_cache, partial, wraps
 from typing import Callable, List, Optional, Set, Tuple, Union

 import torch
@@ -45,9 +45,16 @@
 _torch_distributed_available = torch.distributed.is_available()

 if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor import (
+        DeviceMesh,
+        DTensor,
+        Placement,
+        Replicate,
+        distribute_module,
+    )
     from torch.distributed.tensor.parallel import (
         ColwiseParallel,
+        ParallelStyle,
         RowwiseParallel,
     )

@@ -344,8 +351,82 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
     return torch.isin(elements, test_elements)


-# TODO need to add the __repr__ that shows that it is a colwise parallel
-# See https://github.com/pytorch/pytorch/issues/145726
Contributor comment on lines -347 to -348:
nit: keep this TODO?

+class ReplicateParallel(ParallelStyle):
+    """
+    Replicate a nn.Module.
+    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
+    A fully distributed model is needed for gradient clipping.
Contributor comment on lines +354 to +358:
@weifengpy @wz337 @tianyu-l
I wonder if there is anything we can do on DTensor side so that users don't have to annotate the entire model to perform gradient clipping?


+    Keyword Args:
+        input_layouts (Placement, optional):
+            The DTensor layout of the input tensor for the nn.Module, used to annotate the input tensor to
+            become a DTensor. If not specified, we assume the input tensor to be replicated.
+        output_layouts (Placement, optional):
+            The DTensor layout of the output of the nn.Module, used to ensure the output of the nn.Module
+            has the user-desired layout. If not specified, we assume the output tensor to be replicated.
+        use_local_output (bool, optional):
+            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
+    Returns:
+        A :class:`ParallelStyle` object that represents replication of the nn.Module.
+
+    Example::
+        >>> # xdoctest: +SKIP(failing)
+        >>> from torch.distributed.tensor.parallel import parallelize_module
+        >>> from torch.distributed.device_mesh import init_device_mesh
+        >>> from transformers.pytorch_utils import ReplicateParallel
+        >>> ...
+        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
+        >>> tp_mesh = init_device_mesh("cuda", (8,))
+        >>>
+        >>> # By default, the input and output of the "w1" Linear will be converted to replicated DTensor
+        >>>
+        >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
+        >>> ...
+
+    """

+    def __init__(
+        self,
+        *,
+        input_layouts: Optional[Placement] = None,
+        output_layouts: Optional[Placement] = None,
+        use_local_output: bool = True,
+    ):
+        super().__init__()
+        self.input_layouts = (input_layouts or Replicate(),)
+        self.output_layouts = (output_layouts or Replicate(),)
+        self.desired_input_layouts = (Replicate(),)
+        self.use_local_output = use_local_output
+
+    @staticmethod
+    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
+        # nn.Linear and nn.Embedding have a single input;
+        # we may extend support to other modules since the module is replicated
+        input_tensor = inputs[0]
+        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
+            input_tensor = input_tensor.trigger_wait()
+        if not isinstance(input_tensor, DTensor):
+            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)
+
+        if input_layouts != desired_input_layouts:
+            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
+        return input_tensor
+
+    @staticmethod
+    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
+        if outputs.placements != output_layouts:
+            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
+        return outputs.to_local() if use_local_output else outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            None,
+            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
+            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
+        )


 def translate_to_torch_parallel_style(style: str):
     """
     In model configurations, we use a neutral type (string) to specify parallel
@@ -359,10 +440,16 @@ def translate_to_torch_parallel_style(style: str):
         return ColwiseParallel()
     elif style == "rowwise":
         return RowwiseParallel()
+    elif style == "rowwise_output_dtensor":
+        return RowwiseParallel(use_local_output=False)
     elif style == "colwise_rep":
         return ColwiseParallel(output_layouts=Replicate())
     elif style == "rowwise_rep":
         return RowwiseParallel(input_layouts=Replicate())
+    elif style == "replicateparallel":
+        return ReplicateParallel()
+    elif style == "replicateparallel_output_dtensor":
+        return ReplicateParallel(use_local_output=False)
     else:
         raise ValueError(f"Unsupported parallel style value: {style}")

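For readers unfamiliar with why the non-sharded modules have to be annotated at all: once every parameter-holding submodule in the plan resolves to some `ParallelStyle`, all parameters (including layer norms and embeddings) become DTensors, so a single `torch.nn.utils.clip_grad_norm_` call can compute the global gradient norm without mixing DTensor and plain-tensor gradients. The sketch below is illustrative only, not from this PR's tests; it assumes a `torchrun` launch on CUDA devices and that `ReplicateParallel` is importable from `transformers.pytorch_utils` as added in this diff.

```python
# Illustrative sketch: compose the existing colwise/rowwise styles with the new
# ReplicateParallel so that *all* parameters are DTensors.
# Run with: torchrun --nproc-per-node=<N> this_script.py
import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

from transformers.pytorch_utils import ReplicateParallel  # added by this PR


class Block(nn.Module):
    def __init__(self, dim: int = 128):
        super().__init__()
        self.up = nn.Linear(dim, 4 * dim, bias=False)
        self.down = nn.Linear(4 * dim, dim, bias=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm(self.down(self.up(x)))


mesh = init_device_mesh("cuda", (int(os.environ.get("WORLD_SIZE", "1")),))
block = parallelize_module(
    Block().cuda(),
    mesh,
    {
        "up": ColwiseParallel(),    # "colwise"
        "down": RowwiseParallel(),  # "rowwise"
        # Without this entry, norm.weight and norm.bias would stay plain tensors
        # and the gradient set would mix DTensors with regular tensors.
        "norm": ReplicateParallel(),  # "replicateparallel"
    },
)

out = block(torch.randn(8, 128, device="cuda"))
out.sum().backward()
# Every grad is now a DTensor, so the global norm is well defined across ranks.
torch.nn.utils.clip_grad_norm_(block.parameters(), max_norm=1.0)
```

In Trainer-managed models the same composition is expressed through the `base_model_tp_plan` strings, which `translate_to_torch_parallel_style` maps onto these objects.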
10 changes: 8 additions & 2 deletions src/transformers/trainer.py
@@ -234,6 +234,7 @@
     AutocastKwargs,
     DistributedDataParallelKwargs,
     DistributedType,
+    TorchTensorParallelPlugin,
     load_fsdp_model,
     load_fsdp_optimizer,
     save_fsdp_model,
@@ -2309,7 +2310,9 @@ def _inner_training_loop(
             else:
                 debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

-        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled
+        delay_optimizer_creation = (
+            is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled or self.is_tp_enabled
+        )

         # We need to reset the scheduler, as its parameters may be different on subsequent calls
         if self._created_lr_scheduler:
@@ -2364,7 +2367,10 @@
                 if self.use_apex:
                     model = self.accelerator.prepare(self.model)
                 else:
-                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
+                    if delay_optimizer_creation:
+                        self.optimizer = self.accelerator.prepare(self.optimizer)
+                    else:
+                        model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
             else:
                 # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                 model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
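A plausible reading of the `delay_optimizer_creation` change: the optimizer should only capture parameters after they have reached their final (DTensor) form, the same reason creation is already delayed for FSDP. Below is a rough stand-alone illustration of that ordering constraint in plain PyTorch, not Trainer internals; it assumes a `torchrun` launch on CUDA.

```python
# Illustrative only: why the optimizer should be created after TP sharding.
# parallelize_module swaps nn.Parameter objects for DTensor parameters, so an
# optimizer created earlier would keep updating tensors the model no longer owns.
import os

import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, parallelize_module

mesh = init_device_mesh("cuda", (int(os.environ.get("WORLD_SIZE", "1")),))
model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).cuda()

# Too early: these parameter objects are about to be replaced.
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

model = parallelize_module(model, mesh, {"0": ColwiseParallel()})

# Correct order (the pattern delay_optimizer_creation enforces in the Trainer):
# the optimizer now holds the DTensor parameters the backward pass actually populates.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```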