     NODE_KIND_HEAD,
     TAG_RAY_NODE_KIND,
     TAG_RAY_NODE_STATUS,
+    TAG_RAY_REPLICA_INDEX,
     TAG_RAY_USER_NODE_TYPE,
 )
 
@@ -43,6 +44,8 @@ class NodeData:
     Attributes:
         kind: Whether the node is the head or a worker.
         type: The user-defined type of the node.
+        replica_index: An identifier for nodes in a replica of a TPU worker group.
+            This value is set as a Pod label by a GKE webhook when TPUs are requested
         ip: Cluster-internal ip of the node. ip can be None if the ip
             has not yet been assigned.
         status: The status of the node. You must adhere to the following semantics
@@ -58,6 +61,7 @@ class NodeData:
     type: NodeType
     ip: Optional[NodeIP]
     status: NodeStatus
+    replica_index: Optional[str] = None
 
 
 class BatchingNodeProvider(NodeProvider):
@@ -116,6 +120,9 @@ def __init__(
 
         self.scale_request = ScaleRequest()
 
+        # Initialize map of replica indices to nodes in that replica
+        self.replica_index_to_nodes = defaultdict(list[str])
+
     def get_node_data(self) -> Dict[NodeID, NodeData]:
         """Queries cluster manager for node info. Returns a mapping from node id to
         NodeData.
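A quick aside on the `defaultdict(list[str])` factory added above: in Python 3.9+ a subscripted generic such as `list[str]` is itself callable and returns an ordinary empty list, so this behaves the same as `defaultdict(list)` while documenting the intended element type. A minimal standalone check (the replica and node IDs here are made up):

```python
from collections import defaultdict

# list[str] is a types.GenericAlias; calling it yields a plain empty list,
# so it is a valid default_factory (Python 3.9+).
replica_index_to_nodes = defaultdict(list[str])
replica_index_to_nodes["replica-0"].append("worker-a")  # hypothetical IDs
replica_index_to_nodes["replica-0"].append("worker-b")
assert replica_index_to_nodes["replica-0"] == ["worker-a", "worker-b"]
assert replica_index_to_nodes["unknown-replica"] == []
```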
@@ -160,6 +167,12 @@ def non_terminated_nodes(self, tag_filters: Dict[str, str]) -> List[str]:
             workers_to_delete=set(),  # No workers to delete yet
         )
         all_nodes = list(self.node_data_dict.keys())
+        self.replica_index_to_nodes.clear()
+        for node_id in all_nodes:
+            replica_index = self.node_data_dict[node_id].replica_index
+            # Only add node to map if it belongs to a multi-host podslice
+            if replica_index is not None:
+                self.replica_index_to_nodes[replica_index].append(node_id)
         # Support filtering by TAG_RAY_NODE_KIND, TAG_RAY_NODE_STATUS, and
         # TAG_RAY_USER_NODE_TYPE.
         # The autoscaler only uses tag_filters={},
@@ -187,11 +200,14 @@ def _cur_num_workers(self, node_data_dict: Dict[str, Any]):
 
     def node_tags(self, node_id: str) -> Dict[str, str]:
         node_data = self.node_data_dict[node_id]
-        return {
+        tags = {
             TAG_RAY_NODE_KIND: node_data.kind,
             TAG_RAY_NODE_STATUS: node_data.status,
             TAG_RAY_USER_NODE_TYPE: node_data.type,
         }
+        if node_data.replica_index is not None:
+            tags[TAG_RAY_REPLICA_INDEX] = node_data.replica_index
+        return tags
 
     def internal_ip(self, node_id: str) -> str:
         return self.node_data_dict[node_id].ip
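For illustration only, the sketch below mirrors the updated node_tags() logic: the replica tag is attached only when the node carries a replica index. The tag-key strings are local stand-ins rather than Ray's actual tag constants, and all values are hypothetical:

```python
# Stand-in tag keys for illustration; the real string values come from the tag
# constants imported at the top of the module, not from these placeholders.
TAG_RAY_NODE_KIND = "example/node-kind"
TAG_RAY_NODE_STATUS = "example/node-status"
TAG_RAY_USER_NODE_TYPE = "example/user-node-type"
TAG_RAY_REPLICA_INDEX = "example/replica-index"


def node_tags_sketch(kind, status, node_type, replica_index=None):
    # Mirror of the updated node_tags(): build the base tag dict, then add the
    # replica index only for nodes that belong to a multi-host podslice.
    tags = {
        TAG_RAY_NODE_KIND: kind,
        TAG_RAY_NODE_STATUS: status,
        TAG_RAY_USER_NODE_TYPE: node_type,
    }
    if replica_index is not None:
        tags[TAG_RAY_REPLICA_INDEX] = replica_index
    return tags


# A hypothetical multi-host TPU worker reports four tags; a regular worker
# reports only the original three.
print(node_tags_sketch("worker", "up-to-date", "tpu-group", "replica-0"))
print(node_tags_sketch("worker", "up-to-date", "cpu-group"))
```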
@@ -230,6 +246,20 @@ def terminate_node(self, node_id: str) -> Optional[Dict[str, Any]]:
                 f"{node_type}. Skipping termination request."
             )
 
+        # Terminate node
         self.scale_request.desired_num_workers[node_type] -= 1
         self.scale_request.workers_to_delete.add(node_id)
+
+        # Scale down all nodes in replica if node_id is part of a multi-host podslice
+        tags = self.node_tags(node_id)
+        if TAG_RAY_REPLICA_INDEX in tags:
+            node_replica_index = tags[TAG_RAY_REPLICA_INDEX]
+            for worker_id in self.replica_index_to_nodes[node_replica_index]:
+                # Check if worker has already been scheduled to delete
+                if worker_id not in self.scale_request.workers_to_delete:
+                    self.scale_request.workers_to_delete.add(worker_id)
+                    logger.info(
+                        f"Autoscaler terminating node {worker_id} "
+                        f"in multi-host replica {node_replica_index}."
+                    )
         self.scale_change_needed = True
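To make the end-to-end behavior concrete, here is a minimal, self-contained sketch (not the actual BatchingNodeProvider) of the bookkeeping this diff adds: nodes sharing a replica index are grouped together, and terminating any one of them schedules every node in that replica for deletion. The class name and node IDs below are invented for illustration:

```python
from collections import defaultdict
from typing import Dict, Optional, Set


class ReplicaAwareScaleDown:
    """Toy model of the replica-index bookkeeping added in this diff."""

    def __init__(self, replica_index_by_node: Dict[str, Optional[str]]):
        # Mirrors replica_index_to_nodes: group nodes that belong to the same
        # multi-host podslice replica; standalone nodes are left out of the map.
        self.replica_index_by_node = replica_index_by_node
        self.replica_index_to_nodes = defaultdict(list)
        for node_id, replica_index in replica_index_by_node.items():
            if replica_index is not None:
                self.replica_index_to_nodes[replica_index].append(node_id)
        self.workers_to_delete: Set[str] = set()

    def terminate_node(self, node_id: str) -> None:
        # Schedule the requested node for deletion.
        self.workers_to_delete.add(node_id)
        # If it belongs to a multi-host replica, schedule its peers as well.
        replica_index = self.replica_index_by_node.get(node_id)
        if replica_index is not None:
            for worker_id in self.replica_index_to_nodes[replica_index]:
                self.workers_to_delete.add(worker_id)


# Hypothetical cluster: two workers in replica "replica-0" plus one standalone worker.
provider = ReplicaAwareScaleDown(
    {"worker-a": "replica-0", "worker-b": "replica-0", "worker-c": None}
)
provider.terminate_node("worker-a")
assert provider.workers_to_delete == {"worker-a", "worker-b"}
provider.terminate_node("worker-c")
assert provider.workers_to_delete == {"worker-a", "worker-b", "worker-c"}
```

In the real provider, the grouping is rebuilt on every non_terminated_nodes() call (see the hunk above), so it reflects the most recent node data fetched by the autoscaler.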