Conversation

@xinyuangui2 (Contributor) commented on Sep 4, 2025

This PR extends the Ray Train v2 local mode support (from #55487) to enable users to launch multiple local mode processes using torchrun for PyTorch distributed training. With this new feature, users can easily switch between torchrun and Ray Train without modifying their training code.

Note

Ray Data on multiple processes is not supported. This might need to wait for #55114 or similar components.

Key Changes

Multi-Process Local Mode Support

  • LocalTorchController: New controller that detects torchrun env variables and sets contexts accordingly
  • Torchrun Integration: Users can now launch multiple local mode processes using torchrun command
  • Environment Detection: Automatically detects torchrun environment variables and initializes distributed training (see the sketch below)
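
A minimal sketch of the environment detection, assuming the standard variables that torchrun exports for each worker process (the controller logic in this PR may differ in detail):

import os

# Variables that torchrun exports for every worker process.
_TORCH_DIST_ENV_VARS = ("RANK", "LOCAL_RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")

def is_torch_distributed_env_vars_set() -> bool:
    # True only when every variable needed for the env:// rendezvous is present.
    return all(var in os.environ for var in _TORCH_DIST_ENV_VARS)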

Usage Example

import os
import tempfile

import torch
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor, Normalize, Compose

import ray
from ray.train import Checkpoint, CheckpointConfig, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer
from ray.train.v2.api.config import FailureConfig
import ray.train.torch

def train_func():
    # Model, Loss, Optimizer
    model = resnet18(num_classes=10)
    model.conv1 = torch.nn.Conv2d(
        1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )
    # [1] Prepare model.
    model = ray.train.torch.prepare_model(model)
    criterion = CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    # Data
    transform = Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])
    data_dir = os.path.join(tempfile.gettempdir(), "data")
    train_data = FashionMNIST(root=data_dir, train=True, download=True, transform=transform)
    train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
    # [2] Prepare dataloader.
    train_loader = ray.train.torch.prepare_data_loader(train_loader)

    # Training
    for epoch in range(10):
        if ray.train.get_context().get_world_size() > 1:
            train_loader.sampler.set_epoch(epoch)

        for images, labels in train_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # [3] Report metrics and checkpoint.
        metrics = {"loss": loss.item(), "epoch": epoch}
        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
            torch.save(
                model.state_dict(),
                os.path.join(temp_checkpoint_dir, "model.pt")
            )
            ray.train.report(
                metrics,
                checkpoint=ray.train.Checkpoint.from_directory(temp_checkpoint_dir),
            )
        if ray.train.get_context().get_world_rank() == 0:
            print(metrics)

# Configuration for local mode
use_gpu = True
scaling_config = ScalingConfig(num_workers=0, use_gpu=use_gpu)  # Local mode
run_config = RunConfig(checkpoint_config=CheckpointConfig(num_to_keep=1))

# Note: Ray Data not supported with multiple processes in local mode
# For multi-process training, use PyTorch DataLoader as shown above

# Initialize the Trainer
trainer = TorchTrainer(
    train_loop_per_worker=train_func,
    scaling_config=scaling_config,
    run_config=run_config,
)

# Train the model
result = trainer.fit()

Running Options:

# Option 1: Single process local mode
RAY_TRAIN_V2_ENABLED=1 python test.py

# Option 2: Multi-process local mode with torchrun
RAY_TRAIN_V2_ENABLED=1 torchrun --standalone --nnodes=1 --nproc-per-node=4 test.py

# Option 3: Switch to distributed Ray Train (change to num_workers=4; see the sketch below)
# The same training code works across all modes!
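
A hedged sketch of the Option 3 switch, assuming only the scaling configuration changes (the worker count here is illustrative):

# Replace the local-mode ScalingConfig with a distributed one, e.g.:
scaling_config = ScalingConfig(num_workers=4, use_gpu=use_gpu)
# Then run without torchrun: RAY_TRAIN_V2_ENABLED=1 python test.py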

Checks

  • I've signed off every commit (by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests
    • This PR is not tested :(

Signed-off-by: xgui <[email protected]>
@xinyuangui2 changed the title from "[Train] Add PyTorch local mode support with distributed training capabilities" to "[Train] Add PyTorch local mode support for multi-process training with torchrun" on Sep 4, 2025
@xinyuangui2 marked this pull request as ready for review on September 4, 2025 20:49
@xinyuangui2 requested a review from a team as a code owner on September 4, 2025 20:49
The ray-gardener bot added the train (Ray Train Related Issue) label on Sep 5, 2025

logger = logging.getLogger(__name__)


def is_torch_dist_env_set() -> bool:
A reviewer (Contributor) commented on the snippet above:

nit: is_torch_distributed_env_vars_set() to align with https://github.com/ray-project/ray/blob/master/python/ray/train/torch/config.py#L143 ?

Does the CUDA_VISIBLE_DEVICES env var need to be set for a CUDA environment?

The author replied:

Good call. Let me change to align.

The author added in a follow-up:

On second thought, I think the current env variables are the minimum requirement for the processes to communicate with each other.
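
(For reference, a hedged sketch of how those variables feed the env:// rendezvous; the backend choice here is illustrative, not the PR's actual code:)

import torch
import torch.distributed as dist

# With init_method="env://", torch.distributed reads MASTER_ADDR, MASTER_PORT,
# RANK, and WORLD_SIZE from the environment that torchrun populates, so no
# extra variables are strictly required for the processes to find each other.
dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",
    init_method="env://",
)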
