[Core] Enable Scaling Down for Multi-Host TPU Replicas (#43470)
Signed-off-by: ryanaoleary <[email protected]>
ryanaoleary authored Jul 3, 2024
1 parent 86f3e57 commit 2abca38
Showing 6 changed files with 470 additions and 15 deletions.
18 changes: 17 additions & 1 deletion python/ray/autoscaler/_private/kuberay/node_provider.py
@@ -45,6 +45,9 @@

RAY_HEAD_POD_NAME = os.getenv("RAY_HEAD_POD_NAME")

# Key for GKE label that identifies which multi-host replica a pod belongs to
REPLICA_INDEX_KEY = "replicaIndex"

# Design:

# Each modification the autoscaler wants to make is posted to the API server goal state
@@ -79,7 +82,10 @@ def node_data_from_pod(pod: Dict[str, Any]) -> NodeData:
kind, type = kind_and_type(pod)
status = status_tag(pod)
ip = pod_ip(pod)
return NodeData(kind=kind, type=type, status=status, ip=ip)
replica_index = _replica_index_label(pod)
return NodeData(
kind=kind, type=type, replica_index=replica_index, status=status, ip=ip
)


def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
@@ -96,6 +102,16 @@ def kind_and_type(pod: Dict[str, Any]) -> Tuple[NodeKind, NodeType]:
return kind, type


def _replica_index_label(pod: Dict[str, Any]) -> Optional[str]:
"""Returns the replicaIndex label for a Pod in a multi-host TPU worker group.
The replicaIndex label is set by the GKE TPU Ray webhook and is of
the form {$WORKER_GROUP_NAME-$REPLICA_INDEX} where $REPLICA_INDEX
is an integer from 0 to Replicas-1.
"""
labels = pod["metadata"]["labels"]
return labels.get(REPLICA_INDEX_KEY, None)


def pod_ip(pod: Dict[str, Any]) -> NodeIP:
return pod["status"].get("podIP", "IP not yet assigned")

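To illustrate the new helper: a minimal, self-contained sketch of how the replica index is resolved from a pod manifest (the pod dicts are hypothetical, but the label follows the form described in the docstring above):

from typing import Any, Dict, Optional

REPLICA_INDEX_KEY = "replicaIndex"

def _replica_index_label(pod: Dict[str, Any]) -> Optional[str]:
    labels = pod["metadata"]["labels"]
    return labels.get(REPLICA_INDEX_KEY, None)

# Hypothetical pod in the first replica of a worker group named "tpu-group",
# carrying the label the GKE TPU Ray webhook would inject.
tpu_pod = {"metadata": {"labels": {"replicaIndex": "tpu-group-0"}}}
print(_replica_index_label(tpu_pod))  # tpu-group-0

# Pods outside a multi-host TPU slice carry no replicaIndex label.
cpu_pod = {"metadata": {"labels": {}}}
print(_replica_index_label(cpu_pod))  # None
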
32 changes: 31 additions & 1 deletion python/ray/autoscaler/batching_node_provider.py
@@ -14,6 +14,7 @@
NODE_KIND_HEAD,
TAG_RAY_NODE_KIND,
TAG_RAY_NODE_STATUS,
TAG_RAY_REPLICA_INDEX,
TAG_RAY_USER_NODE_TYPE,
)

@@ -43,6 +44,8 @@ class NodeData:
Attributes:
kind: Whether the node is the head or a worker.
type: The user-defined type of the node.
replica_index: An identifier for nodes in a replica of a TPU worker group.
This value is set as a Pod label by a GKE webhook when TPUs are requested.
ip: Cluster-internal ip of the node. ip can be None if the ip
has not yet been assigned.
status: The status of the node. You must adhere to the following semantics
Expand All @@ -58,6 +61,7 @@ class NodeData:
type: NodeType
ip: Optional[NodeIP]
status: NodeStatus
replica_index: Optional[str] = None


class BatchingNodeProvider(NodeProvider):
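
A short sketch of how the extended dataclass is populated (the NodeKind/NodeType/NodeIP/NodeStatus aliases are simplified to plain str here; the concrete values are illustrative):

from dataclasses import dataclass
from typing import Optional

@dataclass
class NodeData:
    kind: str
    type: str
    ip: Optional[str]
    status: str
    replica_index: Optional[str] = None  # stays None for single-host nodes

# Worker in replica "tpu-group-0" of a multi-host TPU worker group.
tpu_worker = NodeData(
    kind="worker",
    type="tpu-group",
    ip="10.0.0.5",
    status="up-to-date",
    replica_index="tpu-group-0",
)

# Ordinary CPU worker; replica_index defaults to None.
cpu_worker = NodeData(kind="worker", type="cpu-group", ip="10.0.0.6", status="up-to-date")
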
@@ -116,6 +120,9 @@ def __init__(

self.scale_request = ScaleRequest()

# Initialize map of replica indices to nodes in that replica
self.replica_index_to_nodes = defaultdict(list[str])

def get_node_data(self) -> Dict[NodeID, NodeData]:
"""Queries cluster manager for node info. Returns a mapping from node id to
NodeData.
@@ -160,6 +167,12 @@ def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
workers_to_delete=set(), # No workers to delete yet
)
all_nodes = list(self.node_data_dict.keys())
self.replica_index_to_nodes.clear()
for node_id in all_nodes:
replica_index = self.node_data_dict[node_id].replica_index
# Only add node to map if it belongs to a multi-host podslice
if replica_index is not None:
self.replica_index_to_nodes[replica_index].append(node_id)
# Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
# TAG_RAY_USER_NODE_TYPE.
# The autoscaler only uses tag_filters={},
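
The effect of this bookkeeping can be sketched in isolation (node ids and replica names are hypothetical):

from collections import defaultdict

# node id -> replica_index, as it would come back via get_node_data().
replica_of = {
    "pod-a": "tpu-group-0",
    "pod-b": "tpu-group-0",
    "pod-c": None,  # single-host worker: no replicaIndex label
}

replica_index_to_nodes = defaultdict(list)
for node_id, replica_index in replica_of.items():
    # Only nodes belonging to a multi-host podslice carry a replica index.
    if replica_index is not None:
        replica_index_to_nodes[replica_index].append(node_id)

print(dict(replica_index_to_nodes))  # {'tpu-group-0': ['pod-a', 'pod-b']}
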
@@ -187,11 +200,14 @@ def _cur_num_workers(self, node_data_dict: Dict[str, Any]):

def node_tags(self, node_id: str) -> Dict[str, str]:
node_data = self.node_data_dict[node_id]
return {
tags = {
TAG_RAY_NODE_KIND: node_data.kind,
TAG_RAY_NODE_STATUS: node_data.status,
TAG_RAY_USER_NODE_TYPE: node_data.type,
}
if node_data.replica_index is not None:
tags[TAG_RAY_REPLICA_INDEX] = node_data.replica_index
return tags

def internal_ip(self, node_id: str) -> str:
return self.node_data_dict[node_id].ip
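
A runnable sketch of the resulting tag set for a multi-host worker, mirroring the node_tags() logic above (the tag constants are real and importable from ray.autoscaler.tags once this commit lands; the values passed in are illustrative):

from typing import Dict, Optional

from ray.autoscaler.tags import (
    TAG_RAY_NODE_KIND,
    TAG_RAY_NODE_STATUS,
    TAG_RAY_REPLICA_INDEX,
    TAG_RAY_USER_NODE_TYPE,
)

def tags_for(kind: str, status: str, node_type: str,
             replica_index: Optional[str] = None) -> Dict[str, str]:
    tags = {
        TAG_RAY_NODE_KIND: kind,
        TAG_RAY_NODE_STATUS: status,
        TAG_RAY_USER_NODE_TYPE: node_type,
    }
    # The replica tag is only attached for nodes in a multi-host podslice.
    if replica_index is not None:
        tags[TAG_RAY_REPLICA_INDEX] = replica_index
    return tags

print(tags_for("worker", "up-to-date", "tpu-group", "tpu-group-0"))
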
@@ -230,6 +246,20 @@ def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
f"{node_type}. Skipping termination request."
)

# Terminate node
self.scale_request.desired_num_workers[node_type] -= 1
self.scale_request.workers_to_delete.add(node_id)

# Scale down all nodes in replica if node_id is part of a multi-host podslice
tags = self.node_tags(node_id)
if TAG_RAY_REPLICA_INDEX in tags:
node_replica_index = tags[TAG_RAY_REPLICA_INDEX]
for worker_id in self.replica_index_to_nodes[node_replica_index]:
# Check if the worker has already been scheduled for deletion
if worker_id not in self.scale_request.workers_to_delete:
self.scale_request.workers_to_delete.add(worker_id)
logger.info(
f"Autoscaler terminating node {worker_id} "
f"in multi-host replica {node_replica_index}."
)
self.scale_change_needed = True
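
Putting the pieces together, a standalone sketch of the new scale-down behavior (ScaleRequest is reduced to the two fields used here; ids are hypothetical):

from dataclasses import dataclass, field
from typing import Dict, Set

@dataclass
class ScaleRequest:
    desired_num_workers: Dict[str, int] = field(default_factory=dict)
    workers_to_delete: Set[str] = field(default_factory=set)

replica_index_to_nodes = {"tpu-group-0": ["pod-a", "pod-b"]}
scale_request = ScaleRequest(desired_num_workers={"tpu-group": 2})

def terminate_node(node_id: str, node_type: str, replica_index: str) -> None:
    scale_request.desired_num_workers[node_type] -= 1
    scale_request.workers_to_delete.add(node_id)
    # Terminating one worker of a multi-host podslice queues the whole replica.
    for worker_id in replica_index_to_nodes[replica_index]:
        scale_request.workers_to_delete.add(worker_id)

terminate_node("pod-a", "tpu-group", "tpu-group-0")
print(scale_request.workers_to_delete)  # {'pod-a', 'pod-b'}

Note that, as in the diff, only the node the autoscaler asked to terminate decrements desired_num_workers; its replica peers are merely queued in workers_to_delete.
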
2 changes: 2 additions & 0 deletions python/ray/autoscaler/tags.py
@@ -13,6 +13,8 @@
# Tag for user defined node types (e.g., m4xl_spot). This is used for multi
# node type clusters.
TAG_RAY_USER_NODE_TYPE = "ray-user-node-type"
# Tag for the index of the replica a node belongs to. Used for multi-host worker groups.
TAG_RAY_REPLICA_INDEX = "ray-replica-index"
# Tag for autofilled node types for legacy cluster yamls without multi
# node type defined in the cluster configs.
NODE_TYPE_LEGACY_HEAD = "ray-legacy-head-node-type"