From c5ceb48b55b797e860b3f45c2f88254b3db6712c Mon Sep 17 00:00:00 2001 From: Brian Flores Date: Thu, 13 Feb 2025 18:13:12 -0800 Subject: [PATCH] add edge case for models that are marked as not found in cache (#3523) The undeploy flow must now inspect the model undeploy response to check whether the model has been marked as not found on all nodes. Signed-off-by: Brian Flores --- .../TransportUndeployModelsAction.java | 16 ++++- .../TransportUndeployModelsActionTests.java | 59 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java index 3d1ee76996..de44cde042 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsAction.java @@ -6,10 +6,12 @@ package org.opensearch.ml.action.undeploy; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; import java.time.Instant; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import org.opensearch.ExceptionsHelper; @@ -198,7 +200,19 @@ private void undeployModels( * Having this change enables a check that this edge case occurs along with having access to the model id * allowing us to update the stale model index correctly to `UNDEPLOYED` since no nodes service the model. 
*/ - if (response.getNodes().isEmpty()) { + boolean modelNotFoundInNodesCache = response.getNodes().stream().allMatch(nodeResponse -> { + Map<String, String> status = nodeResponse.getModelUndeployStatus(); + if (status == null) + return false; + // Stream is used to catch all models edge case but only one is ever undeployed + boolean modelCacheMissForModelIds = Arrays.stream(modelIds).allMatch(modelId -> { + String modelStatus = status.get(modelId); + return modelStatus != null && modelStatus.equalsIgnoreCase(NOT_FOUND); + }); + + return modelCacheMissForModelIds; + }); + if (response.getNodes().isEmpty() || modelNotFoundInNodesCache) { bulkSetModelIndexToUndeploy(modelIds, listener, response); return; } diff --git a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java index 10c9ec0050..14d47ae10c 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/undeploy/TransportUndeployModelsActionTests.java @@ -18,10 +18,12 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; +import static org.opensearch.ml.common.CommonValue.NOT_FOUND; import static org.opensearch.ml.task.MLPredictTaskRunnerTests.USER_STRING; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -348,6 +350,63 @@ public void testHiddenModelSuccess() { verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); } + public void testDoExecute_bulkRequestFired_WhenModelNotFoundInAllNodes() { + MLModel mlModel = MLModel + .builder() + .user(User.parse(USER_STRING)) + .modelGroupId("111") + .version("111") + .name(this.modelIds[0]) + .modelId(this.modelIds[0]) + .algorithm(FunctionName.BATCH_RCF) + 
.content("content") + .totalChunks(2) + .isHidden(true) + .build(); + + // Mock MLModel manager response + doAnswer(invocation -> { + ActionListener<MLModel> listener = invocation.getArgument(4); + listener.onResponse(mlModel); + return null; + }).when(mlModelManager).getModel(any(), any(), any(), any(), isA(ActionListener.class)); + + doReturn(true).when(transportUndeployModelsAction).isSuperAdminUserWrapper(clusterService, client); + + List<MLUndeployModelNodeResponse> responseList = new ArrayList<>(); + + for (String nodeId : this.nodeIds) { + Map<String, String> stats = new HashMap<>(); + stats.put(this.modelIds[0], NOT_FOUND); + MLUndeployModelNodeResponse nodeResponse = mock(MLUndeployModelNodeResponse.class); + when(nodeResponse.getModelUndeployStatus()).thenReturn(stats); + responseList.add(nodeResponse); + } + + List failuresList = new ArrayList<>(); + MLUndeployModelNodesResponse nodesResponse = new MLUndeployModelNodesResponse(clusterName, responseList, failuresList); + + doAnswer(invocation -> { + ActionListener<MLUndeployModelNodesResponse> listener = invocation.getArgument(2); + listener.onResponse(nodesResponse); + return null; + }).when(client).execute(any(), any(), isA(ActionListener.class)); + + doAnswer(invocation -> { + ActionListener<BulkResponse> listener = invocation.getArgument(1); + listener.onResponse(mock(BulkResponse.class)); + return null; + }).when(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + + MLUndeployModelsRequest request = new MLUndeployModelsRequest(modelIds, nodeIds, null); + + transportUndeployModelsAction.doExecute(task, request, actionListener); + + // Verify that bulk request was fired because all nodes reported "not_found" + verify(client).bulk(any(BulkRequest.class), any(ActionListener.class)); + verify(actionListener).onResponse(any(MLUndeployModelsResponse.class)); + } + public void testHiddenModelPermissionError() { MLModel mlModel = MLModel .builder()