
Commit 716c354

Merge branch 'master' into kk-self-play

2 parents: 69e1e6e + b738755

14 files changed: +362 −176 lines changed

python/ray/train/constants.py

Lines changed: 7 additions & 0 deletions
@@ -125,6 +125,12 @@ def _v2_migration_warnings_enabled() -> bool:
     "TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE"
 )

+# Seconds to wait for torch process group to shut down.
+# Shutting down a healthy torch process group, which we may want to do for reasons
+# like restarting a group of workers if an async checkpoint upload fails, can hang.
+# This is a workaround until we figure out how to avoid this hang.
+TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = "TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"
+DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S = 30

 # NOTE: When adding a new environment variable, please track it in this list.
 TRAIN_ENV_VARS = {
@@ -137,6 +143,7 @@ def _v2_migration_warnings_enabled() -> bool:
     RAY_TRAIN_COUNT_PREEMPTION_AS_FAILURE,
     RAY_TRAIN_ENABLE_STATE_TRACKING,
     TUNE_ONLY_STORE_CHECKPOINT_SCORE_ATTRIBUTE,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
 }

 # Key for AIR Checkpoint metadata in TrainingResult metadata
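Because the constant's value doubles as the environment variable's name, the shutdown wait can be tuned without code changes. A minimal sketch of overriding it before training starts (the 5-second value is illustrative, not a recommendation):

import os

# Must be set before Ray Train reads it during backend shutdown.
# env_integer() parses the string value back into an int.
os.environ["TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S"] = "5"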

python/ray/train/tests/test_backend.py

Lines changed: 19 additions & 0 deletions
@@ -28,6 +28,7 @@
 from ray.train.constants import (
     ENABLE_SHARE_CUDA_VISIBLE_DEVICES_ENV,
     ENABLE_SHARE_NEURON_CORES_ACCELERATOR_ENV,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
     TRAIN_ENABLE_WORKER_SPREAD_ENV,
 )
 from ray.train.torch import TorchConfig
@@ -364,6 +365,24 @@ def check_process_group():
     assert not any(e.finish_training())


+@pytest.mark.parametrize(
+    "init_method, timeout_s", [("env", 5), ("tcp", 5), ("env", 0), ("tcp", 0)]
+)
+def test_torch_process_group_shutdown_timeout(
+    ray_start_2_cpus, monkeypatch, init_method, timeout_s
+):
+    monkeypatch.setenv(TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S, timeout_s)
+    torch_config = TorchConfig(backend="gloo", init_method=init_method)
+    e = BackendExecutor(torch_config, num_workers=2)
+    e.start()
+
+    _start_training(e, lambda: 1)
+    assert e.finish_training() == [1, 1]
+
+    # Verify that we do not raise an exception even if we time out
+    e._backend.on_shutdown(e.worker_group, e._backend_config)
+
+
 @pytest.mark.parametrize(
     "worker_results",
     [
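The timeout_s=0 cases make ray.get time out immediately, so each init method is exercised on both the clean-shutdown path and the timed-out path; the final on_shutdown call checks that a timeout surfaces as a logged warning rather than an exception.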

python/ray/train/torch/config.py

Lines changed: 18 additions & 2 deletions
@@ -10,10 +10,16 @@

 import ray
 from ray._common.network_utils import build_address
+from ray._private import ray_constants
 from ray.air._internal.device_manager import register_custom_torch_dist_backend
+from ray.exceptions import GetTimeoutError
 from ray.train._internal.utils import get_address_and_port
 from ray.train._internal.worker_group import WorkerGroup
 from ray.train.backend import Backend, BackendConfig
+from ray.train.constants import (
+    DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+    TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+)
 from ray.util import PublicAPI

 logger = logging.getLogger(__name__)
@@ -202,11 +208,21 @@ def set_env_vars(addr, port):
         else:
             raise RuntimeError("Distributed torch is not available.")

-    def on_shutdown(self, worker_group: WorkerGroup, backend_config: TorchConfig):
-        worker_group.execute(
+    def on_shutdown(self, worker_group, backend_config):
+        futures = worker_group.execute_async(
             _shutdown_torch,
             destroy_process_group=len(worker_group) > 1,
         )
+        timeout_s = ray_constants.env_integer(
+            TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+            DEFAULT_TORCH_PROCESS_GROUP_SHUTDOWN_TIMEOUT_S,
+        )
+        try:
+            ray.get(futures, timeout=timeout_s)
+        except GetTimeoutError:
+            logger.warning(
+                f"Torch process group shutdown timed out after {timeout_s} seconds"
+            )

     def on_training_start(
         self, worker_group: WorkerGroup, backend_config: BackendConfig
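The fix follows a general Ray pattern: launch the potentially hanging calls asynchronously, bound the wait with ray.get(..., timeout=...), and treat GetTimeoutError as non-fatal. A self-contained sketch of that pattern outside of Ray Train (the task body and the 5-second timeout are illustrative stand-ins, not Ray Train internals):

import time

import ray
from ray.exceptions import GetTimeoutError

ray.init()


@ray.remote
def shutdown_worker(hang: bool) -> str:
    # Stand-in for a shutdown call that may block indefinitely,
    # e.g. destroying a torch process group.
    if hang:
        time.sleep(3600)
    return "ok"


futures = [shutdown_worker.remote(hang=(i == 0)) for i in range(2)]
try:
    # Bound the wait so one hung worker cannot block teardown forever.
    ray.get(futures, timeout=5)
except GetTimeoutError:
    print("shutdown timed out after 5 seconds; continuing")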

src/ray/core_worker/BUILD.bazel

Lines changed: 19 additions & 1 deletion
@@ -26,6 +26,7 @@ ray_cc_library(
         ":experimental_mutable_object_provider",
         ":future_resolver",
         ":generator_waiter",
+        ":grpc_service",
         ":memory_store",
         ":object_recovery_manager",
         ":plasma_store_provider",
@@ -46,7 +47,6 @@ ray_cc_library(
         "//src/ray/pubsub:subscriber",
         "//src/ray/raylet_client:raylet_client_lib",
         "//src/ray/rpc:core_worker_client",
-        "//src/ray/rpc:core_worker_server",
         "//src/ray/rpc:metrics_agent_client",
         "//src/ray/stats:stats_lib",
         "//src/ray/util:container_util",
@@ -65,6 +65,24 @@
     ],
 )

+ray_cc_library(
+    name = "grpc_service",
+    srcs = [
+        "grpc_service.cc",
+    ],
+    hdrs = [
+        "grpc_service.h",
+    ],
+    visibility = [":__subpackages__"],
+    deps = [
+        "//src/ray/common:asio",
+        "//src/ray/protobuf:core_worker_cc_grpc",
+        "//src/ray/protobuf:core_worker_cc_proto",
+        "//src/ray/rpc:grpc_server",
+        "//src/ray/rpc:server_call",
+    ],
+)
+
 ray_cc_library(
     name = "shutdown_coordinator",
     srcs = [
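The new grpc_service target backfills the //src/ray/rpc:core_worker_server dependency removed above: the core worker library now depends on the extracted service definition rather than the full worker server target.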

src/ray/core_worker/core_worker.h

Lines changed: 0 additions & 11 deletions
@@ -53,21 +53,10 @@
 #include "ray/pubsub/publisher.h"
 #include "ray/pubsub/subscriber.h"
 #include "ray/raylet_client/raylet_client_interface.h"
-#include "ray/rpc/worker/core_worker_server.h"
 #include "ray/util/process.h"
 #include "ray/util/shared_lru.h"
 #include "src/ray/protobuf/pubsub.pb.h"

-/// The set of gRPC handlers and their associated level of concurrency. If you want to
-/// add a new call to the worker gRPC server, do the following:
-/// 1) Add the rpc to the CoreWorkerService in core_worker.proto, e.g., "ExampleCall"
-/// 2) Add a new macro to RAY_CORE_WORKER_DECLARE_RPC_HANDLERS
-///    in core_worker_server.h,
-//     e.g. "DECLARE_VOID_RPC_SERVICE_HANDLER_METHOD(ExampleCall)"
-/// 3) Add a new macro to RAY_CORE_WORKER_RPC_HANDLERS in core_worker_server.h, e.g.
-///    "RPC_SERVICE_HANDLER(CoreWorkerService, ExampleCall, 1)"
-/// 4) Add a method to the CoreWorker class below: "CoreWorker::HandleExampleCall"
-
 namespace ray::core {

 JobID GetProcessJobID(const CoreWorkerOptions &options);

src/ray/core_worker/core_worker_process.cc

Lines changed: 2 additions & 1 deletion
@@ -255,7 +255,8 @@ std::shared_ptr<CoreWorker> CoreWorkerProcessImpl::CreateCoreWorker(
   // Start RPC server after all the task receivers are properly initialized and we have
   // our assigned port from the raylet.
   core_worker_server->RegisterService(
-      std::make_unique<rpc::CoreWorkerGrpcService>(io_service_, *service_handler_),
+      std::make_unique<rpc::CoreWorkerGrpcService>(
+          io_service_, *service_handler_, /*max_active_rpcs_per_handler_=*/-1),
       false /* token_auth */);
   core_worker_server->Run();

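The new third constructor argument makes the per-handler concurrency cap explicit; -1 here presumably preserves the previous behavior of not limiting the number of active RPCs per handler.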
src/ray/core_worker/core_worker_process.h

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
 #include <string>

 #include "ray/core_worker/core_worker_options.h"
+#include "ray/core_worker/grpc_service.h"
 #include "ray/rpc/metrics_agent_client.h"
 #include "ray/util/mutex_protected.h"

src/ray/core_worker/grpc_service.cc

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+// Copyright 2025 The Ray Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ray/core_worker/grpc_service.h"
+
+#include <memory>
+#include <vector>
+
+namespace ray {
+namespace rpc {
+
+void CoreWorkerGrpcService::InitServerCallFactories(
+    const std::unique_ptr<grpc::ServerCompletionQueue> &cq,
+    std::vector<std::unique_ptr<ServerCallFactory>> *server_call_factories,
+    const ClusterID &cluster_id) {
+  /// TODO(vitsai): Remove this when auth is implemented for node manager.
+  /// Disable gRPC server metrics since it incurs too high cardinality.
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, PushTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          ActorCallArgWaitComplete,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          RayletNotifyGCSRestart,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          GetObjectStatus,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          WaitForActorRefDeleted,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          PubsubLongPolling,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          PubsubCommandBatch,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          UpdateObjectLocationBatch,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          GetObjectLocationsOwner,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          ReportGeneratorItemReturns,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, KillActor, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, CancelTask, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          RemoteCancelTask,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          RegisterMutableObjectReader,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          GetCoreWorkerStats,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, LocalGC, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, DeleteObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, SpillObjects, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          RestoreSpilledObjects,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          DeleteSpilledObjects,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          PlasmaObjectReady,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(
+      CoreWorkerService, Exit, max_active_rpcs_per_handler_, AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          AssignObjectOwner,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          NumPendingTasks,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+  RPC_SERVICE_HANDLER_CUSTOM_AUTH_SERVER_METRICS_DISABLED(CoreWorkerService,
+                                                          FreeActorObject,
+                                                          max_active_rpcs_per_handler_,
+                                                          AuthType::NO_AUTH);
+}
+
+}  // namespace rpc
+}  // namespace ray
