fix: support grad clipping for TP through replicating non-sharded modules #36132
base: main
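Under tensor parallelism, only the modules covered by the TP plan get DTensor parameters and gradients; every other module keeps plain torch.Tensor parameters. Gradient clipping computes a single global norm over all gradients, and PyTorch cannot combine DTensor and plain-tensor gradients in that computation, so clipping fails unless every parameter is a DTensor. Replicating the non-sharded modules closes that gap. The sketch below illustrates the mixed-parameter situation this PR addresses; the Block model, its layer names, and the 2-way mesh are made-up examples, and a torchrun --nproc_per_node=2 launch is assumed.

# Minimal sketch of the problem (hypothetical model; run under torchrun --nproc_per_node=2).
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(16, 32)    # will be column-sharded
        self.down = nn.Linear(32, 16)  # will be row-sharded
        self.norm = nn.LayerNorm(16)   # not in the TP plan -> stays a plain torch.Tensor

    def forward(self, x):
        return self.norm(self.down(torch.relu(self.up(x))))


mesh = init_device_mesh("cuda", (2,))
model = Block().cuda()
# Only "up" and "down" are parallelized; "norm" keeps regular torch.Tensor parameters.
parallelize_module(model, mesh, {"up": ColwiseParallel(), "down": RowwiseParallel()})

model(torch.randn(4, 16, device="cuda")).sum().backward()

# The gradients are now a mix of DTensors ("up", "down") and plain tensors ("norm"),
# so the total-norm computation inside clip_grad_norm_ cannot combine them. Replicating
# "norm" as well (the point of this PR) makes every gradient a DTensor.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)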
@@ -14,7 +14,7 @@
from __future__ import annotations

import inspect
-from functools import lru_cache, wraps
+from functools import lru_cache, partial, wraps
from typing import Callable, List, Optional, Set, Tuple, Union

import torch
@@ -45,9 +45,16 @@
_torch_distributed_available = torch.distributed.is_available()

if is_torch_greater_or_equal("2.5") and _torch_distributed_available:
-    from torch.distributed.tensor import Replicate
+    from torch.distributed.tensor import (
+        DeviceMesh,
+        DTensor,
+        Placement,
+        Replicate,
+        distribute_module,
+    )
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        ParallelStyle,
        RowwiseParallel,
    )
@@ -344,8 +351,82 @@ def isin_mps_friendly(elements: torch.Tensor, test_elements: torch.Tensor | int)
    return torch.isin(elements, test_elements)

-# TODO need to add the __repr__ that shows that it is a colwise parallel
-# See https://github.com/pytorch/pytorch/issues/145726
Review comment on lines -347 to -348: nit: keep this TODO?
class ReplicateParallel(ParallelStyle):
    """
    Replicate a nn.Module.
    Users can compose it together with other parallel styles like RowwiseParallel to achieve a fully distributed model.
    A fully distributed model is needed for gradient clipping.
Review comment on lines +354 to +358: @weifengpy @wz337 @tianyu-l

    Keyword Args:
        input_layouts (Placement, optional):
            The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to
            become a DTensor. If not specified, we assume the input tensor to be replicated.
        output_layouts (Placement, optional):
            The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module
            with the user desired layout. If not specified, we assume the output tensor to be replicated.
        use_local_output (bool, optional):
            Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True.
    Returns:
        A :class:`ParallelStyle` object that represents replication of nn.Module.

    Example::
        >>> # xdoctest: +SKIP(failing)
        >>> from torch.distributed.tensor.parallel import parallelize_module, ReplicateParallel
        >>> from torch.distributed.device_mesh import init_device_mesh
        >>> ...
        >>> m = Model(...)  # m is a nn.Module that contains a "w1" nn.Linear submodule
        >>> tp_mesh = init_device_mesh("cuda", (8,))
        >>>
        >>> # By default, the input and output of the "w1" Linear will be converted to Replicated DTensor
        >>>
        >>> replicated_mod = parallelize_module(m, tp_mesh, {"w1": ReplicateParallel()})
        >>> ...

    """

    def __init__(
        self,
        *,
        input_layouts: Optional[Placement] = None,
        output_layouts: Optional[Placement] = None,
        use_local_output: bool = True,
    ):
        super().__init__()
        self.input_layouts = (input_layouts or Replicate(),)
        self.output_layouts = (output_layouts or Replicate(),)
        self.desired_input_layouts = (Replicate(),)
        self.use_local_output = use_local_output

    @staticmethod
    def _prepare_input_fn(input_layouts, desired_input_layouts, mod, inputs, device_mesh):
        # since nn.Linear and nn.Embedding have a single input;
        # we may extend support to other modules since the module is replicated.
        input_tensor = inputs[0]
        if isinstance(input_tensor, torch.distributed._functional_collectives.AsyncCollectiveTensor):
            input_tensor = input_tensor.trigger_wait()
        if not isinstance(input_tensor, DTensor):
            input_tensor = DTensor.from_local(input_tensor, device_mesh, input_layouts, run_check=False)

        if input_layouts != desired_input_layouts:
            input_tensor = input_tensor.redistribute(placements=desired_input_layouts, async_op=True)
        return input_tensor

    @staticmethod
    def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
        if outputs.placements != output_layouts:
            outputs = outputs.redistribute(placements=output_layouts, async_op=True)
        return outputs.to_local() if use_local_output else outputs

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            None,
            partial(self._prepare_input_fn, self.input_layouts, self.desired_input_layouts),
            partial(self._prepare_output_fn, self.output_layouts, self.use_local_output),
        )

def translate_to_torch_parallel_style(style: str):
    """
    In model configurations, we use a neutral type (string) to specify parallel
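The docstring example in the ReplicateParallel class above covers a single module. The sketch below (hypothetical MLP and layer names, torchrun launch on 2 GPUs assumed, and ReplicateParallel taken from the class defined in this diff) shows the composition this PR is after: every submodule gets a parallel style, so all parameters become DTensors and the global gradient norm is well defined.

# Sketch: compose ReplicateParallel with the sharded styles so clipping only sees DTensors.
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = nn.LayerNorm(16)
        self.up = nn.Linear(16, 32)
        self.down = nn.Linear(32, 16)

    def forward(self, x):
        return self.down(torch.relu(self.up(self.norm(x))))


mesh = init_device_mesh("cuda", (2,))
model = MLP().cuda()
plan = {
    "norm": ReplicateParallel(),  # replicated DTensor parameters (class defined above)
    "up": ColwiseParallel(),      # sharded on the output dimension
    "down": RowwiseParallel(),    # sharded on the input dimension
}
parallelize_module(model, mesh, plan)

model(torch.randn(4, 16, device="cuda")).sum().backward()
# All gradients are DTensors on the same mesh, so a single global norm can be computed.
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)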
@@ -359,10 +440,16 @@ def translate_to_torch_parallel_style(style: str):
        return ColwiseParallel()
    elif style == "rowwise":
        return RowwiseParallel()
    elif style == "rowwise_output_dtensor":
        return RowwiseParallel(use_local_output=False)
    elif style == "colwise_rep":
        return ColwiseParallel(output_layouts=Replicate())
    elif style == "rowwise_rep":
        return RowwiseParallel(input_layouts=Replicate())
    elif style == "replicateparallel":
        return ReplicateParallel()
    elif style == "replicateparallel_output_dtensor":
        return ReplicateParallel(use_local_output=False)
    else:
        raise ValueError(f"Unsupported parallel style value: {style}")
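These strings are the neutral identifiers a model config's TP plan uses; the function above turns them into concrete ParallelStyle objects. A small sketch of that translation step follows (the layer names are hypothetical, and how transformers consumes the resulting dict is not shown here):

# Hypothetical string-based plan, resolved into ParallelStyle objects.
tp_plan = {
    "input_layernorm": "replicateparallel",  # new style added in this PR
    "mlp.up_proj": "colwise",
    "mlp.down_proj": "rowwise",
    "lm_head": "colwise_rep",
}
resolved = {name: translate_to_torch_parallel_style(style) for name, style in tp_plan.items()}
# "resolved" can then be handed to torch.distributed.tensor.parallel.parallelize_module.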
Review comment:
Thanks for extending the configs here.
I wonder if some of these settings would be more interesting to training than to inference?
(On the other hand, I don't know much about HF's user profile -- training more or inference more?)
If some of the settings are specific to training, is it possible to separate them out? Or, shall we make the config somehow customizable at run time?
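Purely as an illustration of the reviewer's suggestion (none of this is part of the PR): training-only styles could live in a separate plan that is selected at run time, for example:

# Hypothetical split between inference and training plans; names are made up.
INFERENCE_TP_PLAN = {
    "mlp.up_proj": "colwise",
    "mlp.down_proj": "rowwise",
}
TRAINING_TP_PLAN = {
    **INFERENCE_TP_PLAN,
    # Replicating non-sharded modules mainly matters for gradient clipping, i.e. training.
    "input_layernorm": "replicateparallel",
}


def select_tp_plan(training: bool) -> dict:
    return TRAINING_TP_PLAN if training else INFERENCE_TP_PLAN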