Skip to content

Commit

Permalink
add edge case for models that are marked as not found in cache (#3523)
Browse files Browse the repository at this point in the history
There is a code change that requires to check the response of the model undeploy response object to check that the model has been marked as not found on all nodes.

Signed-off-by: Brian Flores <[email protected]>
  • Loading branch information
brianf-aws authored Feb 14, 2025
1 parent 1b8b014 commit c5ceb48
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@
package org.opensearch.ml.action.undeploy;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;

import java.time.Instant;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.ExceptionsHelper;
Expand Down Expand Up @@ -198,7 +200,19 @@ private void undeployModels(
* Having this change enables a check that this edge case occurs along with having access to the model id
* allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model.
*/
if (response.getNodes().isEmpty()) {
boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> {
Map<String, String> status = nodeResponse.getModelUndeployStatus();
if (status == null)
return false;
// Stream is used to catch all models edge case but only one is ever undeployed
boolean modelCacheMissForModelIds = Arrays.stream(modelIds).allMatch(modelId -> {
String modelStatus = status.get(modelId);
return modelStatus != null && modelStatus.equalsIgnoreCase(NOT_FOUND);
});

return modelCacheMissForModelIds;
});
if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) {
bulkSetModelIndexToUndeploy(modelIds, listener, response);
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.NOT_FOUND;
import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

Expand Down Expand Up @@ -348,6 +350,63 @@ public void testHiddenModelSuccess() {
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
}

public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() {
MLModel mlModel = MLModel
.builder()
.user(User.parse(USER_STRING))
.modelGroupId("111")
.version("111")
.name(this.modelIds[0])
.modelId(this.modelIds[0])
.algorithm(FunctionName.BATCH_RCF)
.content("content")
.totalChunks(2)
.isHidden(true)
.build();

// Mock MLModel manager response
doAnswer(invocation -> {
ActionListener<MLModel> listener = invocation.getArgument(4);
listener.onResponse(mlModel);
return null;
}).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class));

doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client);

List<MLUndeployModelNodeResponse> responseList = new ArrayList<>();

for (String nodeId : this.nodeIds) {
Map<String, String> stats = new HashMap<>();
stats.put(this.modelIds[0], NOT_FOUND);
MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class);
when(nodeResponse.getModelUndeployStatus()).thenReturn(stats);
responseList.add(nodeResponse);
}

List<FailedNodeException> failuresList = new ArrayList<>();
MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList);

doAnswer(invocation -> {
ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2);
listener.onResponse(nodesResponse);
return null;
}).when(client).execute(any(), any(), isA(ActionListener.class));

doAnswer(invocation -> {
ActionListener<BulkResponse> listener = invocation.getArgument(1);
listener.onResponse(mock(BulkResponse.class));
return null;
}).when(client).bulk(any(BulkRequest.class), any(ActionListener.class));

MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null);

transportUndeployModelsAction.doExecute(task, request, actionListener);

// Verify that bulk request was fired because all nodes reported "not_found"
verify(client).bulk(any(BulkRequest.class), any(ActionListener.class));
verify(actionListener).onResponse(any(MLUndeployModelsResponse.class));
}

public void testHiddenModelPermissionError() {
MLModel mlModel = MLModel
.builder()
Expand Down

0 comments on commit c5ceb48

Please sign in to comment.