Commit 2abca38

[Core] Enable Scaling Down for Multi-Host TPU Replicas (#43470)
Signed-off-by: ryanaoleary <[email protected]>
1 parent 86f3e57 commit 2abca38

File tree

6 files changed: +470 -15 lines changed
python/ray/autoscaler/_private/kuberay/node_provider.py

Lines changed: 17 additions & 1 deletion
@@ -45,6 +45,9 @@
 
 RAY_HEAD_POD_NAME = os.getenv("RAY_HEAD_POD_NAME")
 
+# Key for GKE label that identifies which multi-host replica a pod belongs to
+REPLICA_INDEX_KEY = "replicaIndex"
+
 # Design:
 
 # Each modification the autoscaler wants to make is posted to the API server goal state
@@ -79,7 +82,10 @@ def node_data_from_pod(pod: Dict[str, Any]) -> NodeData:
     kind, type = kind_and_type(pod)
     status = status_tag(pod)
     ip = pod_ip(pod)
-    return NodeData(kind=kind, type=type, status=status, ip=ip)
+    replica_index = _replica_index_label(pod)
+    return NodeData(
+        kind=kind, type=type, replica_index=replica_index, status=status, ip=ip
+    )
 
 
 def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
@@ -96,6 +102,16 @@ def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
     return kind, type
 
 
+def _replica_index_label(pod: Dict[str, Any]) -> Optional[str]:
+    """Returns the replicaIndex label for a Pod in a multi-host TPU worker group.
+    The replicaIndex label is set by the GKE TPU Ray webhook and is of
+    the form {$WORKER_GROUP_NAME-$REPLICA_INDEX} where $REPLICA_INDEX
+    is an integer from 0 to Replicas-1.
+    """
+    labels = pod["metadata"]["labels"]
+    return labels.get(REPLICA_INDEX_KEY, None)
+
+
 def pod_ip(pod: Dict[str, Any]) -> NodeIP:
     return pod["status"].get("podIP", "IP not yet assigned")
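
To make the label plumbing above concrete, here is a minimal, self-contained sketch (not part of the commit) of how a pod's replicaIndex label flows into the provider. The pod dict, worker-group name, and IP below are made up for illustration; in practice the label is written by the GKE TPU Ray webhook.

# Hypothetical pod dict standing in for a Kubernetes Pod object; only the fields
# that node_data_from_pod() and _replica_index_label() read are shown.
pod = {
    "metadata": {
        "labels": {
            "replicaIndex": "tpu-group-0",  # written by the GKE TPU Ray webhook
        }
    },
    "status": {"podIP": "10.0.0.5"},
}

REPLICA_INDEX_KEY = "replicaIndex"

# Mirrors _replica_index_label(): a missing label yields None.
replica_index = pod["metadata"]["labels"].get(REPLICA_INDEX_KEY, None)
print(replica_index)  # tpu-group-0

# A pod outside a multi-host TPU worker group carries no replicaIndex label,
# so replica_index stays None and NodeData.replica_index keeps its default.
cpu_pod = {"metadata": {"labels": {}}, "status": {}}
print(cpu_pod["metadata"]["labels"].get(REPLICA_INDEX_KEY, None))  # None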

python/ray/autoscaler/batching_node_provider.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
     NODE_KIND_HEAD,
     TAG_RAY_NODE_KIND,
     TAG_RAY_NODE_STATUS,
+    TAG_RAY_REPLICA_INDEX,
     TAG_RAY_USER_NODE_TYPE,
 )
 
@@ -43,6 +44,8 @@ class NodeData:
     Attributes:
         kind: Whether the node is the head or a worker.
         type: The user-defined type of the node.
+        replica_index: An identifier for nodes in a replica of a TPU worker group.
+            This value is set as a Pod label by a GKE webhook when TPUs are requested
         ip: Cluster-internal ip of the node. ip can be None if the ip
             has not yet been assigned.
         status: The status of the node. You must adhere to the following semantics
@@ -58,6 +61,7 @@ class NodeData:
     type: NodeType
     ip: Optional[NodeIP]
     status: NodeStatus
+    replica_index: Optional[str] = None
 
 
 class BatchingNodeProvider(NodeProvider):
@@ -116,6 +120,9 @@ def __init__(
 
         self.scale_request = ScaleRequest()
 
+        # Initialize map of replica indices to nodes in that replica
+        self.replica_index_to_nodes = defaultdict(list[str])
+
     def get_node_data(self) -> Dict[NodeID, NodeData]:
         """Queries cluster manager for node info. Returns a mapping from node id to
         NodeData.
@@ -160,6 +167,12 @@ def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
             workers_to_delete=set(),  # No workers to delete yet
         )
         all_nodes = list(self.node_data_dict.keys())
+        self.replica_index_to_nodes.clear()
+        for node_id in all_nodes:
+            replica_index = self.node_data_dict[node_id].replica_index
+            # Only add node to map if it belongs to a multi-host podslice
+            if replica_index is not None:
+                self.replica_index_to_nodes[replica_index].append(node_id)
         # Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
         # TAG_RAY_USER_NODE_TYPE.
         # The autoscaler only uses tag_filters={},
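
As a rough illustration of the bookkeeping above, the following standalone sketch (not from the commit) rebuilds a replica_index_to_nodes map the way non_terminated_nodes() does on each call; the node ids and replica indices are hypothetical. Clearing and rebuilding the map from the current pod list keeps it consistent with whatever pods exist at that autoscaler iteration.

from collections import defaultdict

# Hypothetical node_id -> replica_index pairs, standing in for
# self.node_data_dict[node_id].replica_index.
node_replica = {
    "kuberay-tpu-worker-a": "tpu-group-0",
    "kuberay-tpu-worker-b": "tpu-group-0",
    "kuberay-tpu-worker-c": "tpu-group-1",
    "kuberay-cpu-worker-d": None,  # not part of a multi-host podslice
}

replica_index_to_nodes = defaultdict(list)
replica_index_to_nodes.clear()  # the provider clears and rebuilds the map each call
for node_id, replica_index in node_replica.items():
    if replica_index is not None:
        replica_index_to_nodes[replica_index].append(node_id)

print(dict(replica_index_to_nodes))
# {'tpu-group-0': ['kuberay-tpu-worker-a', 'kuberay-tpu-worker-b'],
#  'tpu-group-1': ['kuberay-tpu-worker-c']}
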
@@ -187,11 +200,14 @@ def _cur_num_workers(self, node_data_dict: Dict[str, Any]):
 
     def node_tags(self, node_id: str) -> Dict[str, str]:
         node_data = self.node_data_dict[node_id]
-        return {
+        tags = {
             TAG_RAY_NODE_KIND: node_data.kind,
             TAG_RAY_NODE_STATUS: node_data.status,
             TAG_RAY_USER_NODE_TYPE: node_data.type,
         }
+        if node_data.replica_index is not None:
+            tags[TAG_RAY_REPLICA_INDEX] = node_data.replica_index
+        return tags
 
     def internal_ip(self, node_id: str) -> str:
         return self.node_data_dict[node_id].ip
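
For clarity, a small sketch (not from the commit) of the tag dictionary node_tags() now produces. The NodeData stand-in and the kind/status key strings below are illustrative; only the user-node-type and replica-index key strings match the constants shown in the tags.py diff further down.

from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class FakeNodeData:  # stand-in for batching_node_provider.NodeData
    kind: str
    type: str
    status: str
    replica_index: Optional[str] = None

def node_tags_sketch(node_data: FakeNodeData) -> Dict[str, str]:
    tags = {
        "ray-node-kind": node_data.kind,      # key string illustrative
        "ray-node-status": node_data.status,  # key string illustrative
        "ray-user-node-type": node_data.type,
    }
    # The replica-index tag is only attached for nodes in a multi-host replica.
    if node_data.replica_index is not None:
        tags["ray-replica-index"] = node_data.replica_index
    return tags

print(node_tags_sketch(FakeNodeData("worker", "tpu-group", "up-to-date", "tpu-group-0")))
print(node_tags_sketch(FakeNodeData("worker", "cpu-group", "up-to-date")))
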
@@ -230,6 +246,20 @@ def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
                 f"{node_type}. Skipping termination request."
             )
 
+        # Terminate node
         self.scale_request.desired_num_workers[node_type] -= 1
         self.scale_request.workers_to_delete.add(node_id)
+
+        # Scale down all nodes in replica if node_id is part of a multi-host podslice
+        tags = self.node_tags(node_id)
+        if TAG_RAY_REPLICA_INDEX in tags:
+            node_replica_index = tags[TAG_RAY_REPLICA_INDEX]
+            for worker_id in self.replica_index_to_nodes[node_replica_index]:
+                # Check if worker has already been scheduled to delete
+                if worker_id not in self.scale_request.workers_to_delete:
+                    self.scale_request.workers_to_delete.add(worker_id)
+                    logger.info(
+                        f"Autoscaler terminating node {worker_id} "
+                        f"in multi-host replica {node_replica_index}."
+                    )
         self.scale_change_needed = True
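
The net effect of the terminate_node() change is that scaling down any one worker of a multi-host replica schedules the whole replica for deletion, so the podslice is removed as a unit. A minimal sketch (not from the commit), with hypothetical node ids and a hand-written replica map standing in for the provider's bookkeeping:

# Hypothetical replica map, as non_terminated_nodes() would build it, and the set
# of workers already scheduled for deletion, as kept in the ScaleRequest.
replica_index_to_nodes = {
    "tpu-group-0": ["worker-a", "worker-b", "worker-c", "worker-d"],
}
node_replica_index = {
    node_id: "tpu-group-0" for node_id in replica_index_to_nodes["tpu-group-0"]
}
workers_to_delete = set()

def terminate_node_sketch(node_id: str) -> None:
    # Terminate the requested node.
    workers_to_delete.add(node_id)
    # If it belongs to a multi-host replica, schedule every sibling in that
    # replica for deletion as well.
    replica = node_replica_index.get(node_id)
    if replica is not None:
        for worker_id in replica_index_to_nodes[replica]:
            if worker_id not in workers_to_delete:
                workers_to_delete.add(worker_id)

terminate_node_sketch("worker-b")
print(sorted(workers_to_delete))
# ['worker-a', 'worker-b', 'worker-c', 'worker-d']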

python/ray/autoscaler/tags.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
 # Tag for user defined node types (e.g., m4xl_spot). This is used for multi
 # node type clusters.
 TAG_RAY_USER_NODE_TYPE = "ray-user-node-type"
+# Tag for index of replica node belongs to. Used for multi-host worker groups.
+TAG_RAY_REPLICA_INDEX = "ray-replica-index"
 # Tag for autofilled node types for legacy cluster yamls without multi
 # node type defined in the cluster configs.
 NODE_TYPE_LEGACY_HEAD = "ray-legacy-head-node-type"
