diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java index e9df1b2652..89dc846485 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportSearchAgentAction.java @@ -6,8 +6,8 @@ package org.opensearch.ml.action.agents; import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener; -import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; +import org.opensearch.OpenSearchStatusException; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.action.support.ActionFilters; @@ -16,6 +16,7 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.ml.common.CommonValue; @@ -25,6 +26,8 @@ import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.TenantAwareHelper; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -76,11 +79,6 @@ private void search(SearchRequest request, String tenantId, ActionListener { + if (throwable != null) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + log.error("Failed to search agent", cause); + wrappedListener.onFailure(cause); + } else { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Agent search complete: {}", searchResponse.getHits().getTotalHits()); + wrappedListener.onResponse(searchResponse); + } catch (Exception e) { + log.error("Failed to parse model search response", e); + wrappedListener + .onFailure( + new OpenSearchStatusException("Failed to parse model search response", RestStatus.INTERNAL_SERVER_ERROR) + ); + } + } + }); } catch (Exception e) { log.error("failed to search the agent index", e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java index 5e02caa14a..b11a8b4a16 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/action/handler/MLSearchHandler.java @@ -7,7 +7,6 @@ import static org.opensearch.core.rest.RestStatus.BAD_REQUEST; import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR; -import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound; import java.util.ArrayList; @@ -43,6 +42,9 @@ import org.opensearch.ml.common.exception.MLResourceNotFoundException; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.utils.RestActionUtils; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.SearchDataObjectRequest; +import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.FetchSourceContext; @@ -77,10 +79,11 @@ public MLSearchHandler( /** * Fetch all the models from the model group index, and then create a combined query to model version index. + * @param sdkClient sdkclient a wrapper of the client * @param request * @param actionListener */ - public void search(SearchRequest request, String tenantId, ActionListener actionListener) { + public void search(SdkClient sdkClient, SearchRequest request, String tenantId, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); ActionListener listener = wrapRestActionListener(actionListener, "Fail to search model version"); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { @@ -114,11 +117,6 @@ public void search(SearchRequest request, String tenantId, ActionListener doubleWrapperListener = ActionListener .wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener)); - if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, doubleWrapperListener); - } else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { - client.search(request, doubleWrapperListener); + if (modelAccessControlHelper.skipModelAccessControl(user) + || !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) { + + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .tenantId(tenantId) + .build(); + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + doubleWrapperListener.onResponse(searchResponse); + } catch (Exception e) { + doubleWrapperListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + doubleWrapperListener.onFailure(e); + } + }); } else { SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user); SearchRequest modelGroupSearchRequest = new SearchRequest(); @@ -154,17 +171,54 @@ public void search(SearchRequest request, String tenantId, ActionListener { modelGroupIds.add(hit.getId()); }); request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds)); - client.search(request, doubleWrapperListener); } else { log.debug("No model group found"); request.source().query(rewriteQueryBuilder(request.source().query(), null)); - client.search(request, doubleWrapperListener); } + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(request.indices()) + .searchSourceBuilder(request.source()) + .tenantId(tenantId) + .build(); + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((sr, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + doubleWrapperListener.onResponse(searchResponse); + } catch (Exception e) { + doubleWrapperListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + doubleWrapperListener.onFailure(e); + } + }); }, e -> { log.error("Fail to search model groups!", e); wrappedListener.onFailure(e); }); - client.search(modelGroupSearchRequest, modelGroupSearchActionListener); + SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest + .builder() + .indices(modelGroupSearchRequest.indices()) + .searchSourceBuilder(modelGroupSearchRequest.source()) + .tenantId(tenantId) + .build(); + sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> { + if (throwable == null) { + try { + SearchResponse searchResponse = SearchResponse.fromXContent(r.parser()); + log.info("Model search complete: {}", searchResponse.getHits().getTotalHits()); + modelGroupSearchActionListener.onResponse(searchResponse); + } catch (Exception e) { + modelGroupSearchActionListener.onFailure(e); + } + } else { + Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + modelGroupSearchActionListener.onFailure(e); + } + }); } } catch (Exception e) { log.error(e.getMessage(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java index 862222b001..bc313c0205 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/SearchModelTransportAction.java @@ -48,6 +48,6 @@ protected void doExecute(Task task, MLSearchActionRequest request, ActionListene if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) { return; } - mlSearchHandler.search(request, tenantId, actionListener); + mlSearchHandler.search(sdkClient, request, tenantId, actionListener); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java index 7f81249689..9a37a43be4 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/agents/TransportSearchAgentActionTests.java @@ -6,7 +6,6 @@ package org.opensearch.ml.action.agents; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.*; import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD; @@ -107,14 +106,24 @@ public void testDoExecuteWithEmptyQuery() { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + // Capture the actual SearchRequest passed to client.search() + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SearchRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; - }).when(client).search(eq(mlSearchActionRequest), any()); + }).when(client).search(requestCaptor.capture(), any()); transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(client, times(1)).search(eq(mlSearchActionRequest), any()); + verify(client, times(1)).search(any(), any()); + + // Get the actual SearchRequest used in the method + SearchRequest actualRequest = requestCaptor.getValue(); + + // Validate that the query has been modified as expected + assertNotNull(actualRequest.source().query()); + assertTrue(actualRequest.source().query().toString().contains("is_hidden")); + // Use ArgumentCaptor to capture the SearchResponse ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); // Capture the response passed to actionListener.onResponse @@ -124,29 +133,44 @@ public void testDoExecuteWithEmptyQuery() { assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); assertEquals(searchResponse.getHits().getHits().length, capturedResponse.getHits().getHits().length); assertEquals(searchResponse.status(), capturedResponse.status()); - } @Test public void testDoExecuteWithNonEmptyQuery() { + // Create a search request with a MatchAllQuery SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); - sourceBuilder.query(QueryBuilders.matchAllQuery()); + sourceBuilder.query(QueryBuilders.matchAllQuery()); // Non-empty query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + // Capture the actual SearchRequest passed to client.search() + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SearchRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; - }).when(client).search(eq(mlSearchActionRequest), any()); + }).when(client).search(requestCaptor.capture(), any()); + // Execute the method transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(client, times(1)).search(eq(mlSearchActionRequest), any()); - // Use ArgumentCaptor to capture the SearchResponse + // Verify that client.search was called once + verify(client, times(1)).search(any(), any()); + + // Get the actual SearchRequest used in the method + SearchRequest actualRequest = requestCaptor.getValue(); + + // Validate that the original MatchAllQuery is included + assertNotNull(actualRequest.source().query()); + assertTrue(actualRequest.source().query().toString().contains("match_all")); + + // Validate that "is_hidden" filtering logic is applied + assertTrue(actualRequest.source().query().toString().contains("is_hidden")); + + // Capture and validate the SearchResponse ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); - // Capture the response passed to actionListener.onResponse verify(actionListener, times(1)).onResponse(responseCaptor.capture()); + // Assert that the captured response matches the expected values SearchResponse capturedResponse = responseCaptor.getValue(); assertEquals(searchResponse.getHits().getTotalHits(), capturedResponse.getHits().getTotalHits()); @@ -158,16 +182,34 @@ public void testDoExecuteWithNonEmptyQuery() { public void testDoExecuteOnFailure() { SearchRequest request = new SearchRequest("my_index"); MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + + // Capture the actual SearchRequest passed to client.search() + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SearchRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("test exception")); return null; - }).when(client).search(eq(mlSearchActionRequest), any()); + }).when(client).search(requestCaptor.capture(), any()); + // Execute the method transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(client, times(1)).search(eq(mlSearchActionRequest), any()); - verify(actionListener, times(1)).onFailure(any(Exception.class)); + // Verify that client.search was called once + verify(client, times(1)).search(any(), any()); + + // Get the actual SearchRequest used in the method + SearchRequest actualRequest = requestCaptor.getValue(); + assertNotNull(actualRequest.source()); + + // Validate that "is_hidden" filtering logic is applied + assertTrue(actualRequest.source().query().toString().contains("is_hidden")); + + // Verify that actionListener.onFailure was called with an exception + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener, times(1)).onFailure(exceptionCaptor.capture()); + + // Assert that the captured exception has the expected message + assertEquals("Fail to search agent", exceptionCaptor.getValue().getMessage()); } @Test @@ -176,15 +218,17 @@ public void testSearchWithHiddenField() { sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + // Capture the actual SearchRequest passed to client.search() + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SearchRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onResponse(searchResponse); return null; - }).when(client).search(eq(mlSearchActionRequest), any()); + }).when(client).search(requestCaptor.capture(), any()); transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(client, times(1)).search(eq(mlSearchActionRequest), any()); + verify(client, times(1)).search(any(), any()); // Use ArgumentCaptor to capture the SearchResponse ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(SearchResponse.class); // Capture the response passed to actionListener.onResponse @@ -202,11 +246,13 @@ public void testSearchException() { sourceBuilder.query(QueryBuilders.termQuery("field", "value")); // Simulate user query SearchRequest request = new SearchRequest("my_index").source(sourceBuilder); MLSearchActionRequest mlSearchActionRequest = new MLSearchActionRequest(request, null); + // Capture the actual SearchRequest passed to client.search() + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(SearchRequest.class); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); listener.onFailure(new Exception("failed to search the agent index")); return null; - }).when(client).search(eq(mlSearchActionRequest), any()); + }).when(client).search(requestCaptor.capture(), any()); transportSearchAgentAction.doExecute(null, mlSearchActionRequest, actionListener); diff --git a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java index cfa45ce820..31b14ed288 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/models/SearchModelTransportActionTests.java @@ -173,7 +173,7 @@ public void test_DoExecute_admin() { return null; }).when(client).search(any(), any()); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -186,7 +186,7 @@ public void test_DoExecute_addBackendRoles() throws IOException { }).when(client).search(any(), any()); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -198,7 +198,7 @@ public void test_DoExecute_addBackendRoles_without_groupIds() { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -210,7 +210,7 @@ public void test_DoExecute_addBackendRoles_exception() { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); } @@ -222,7 +222,7 @@ public void test_DoExecute_searchModel_before_model_creation_no_exception() { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); } @@ -255,7 +255,7 @@ public void test_DoExecute_searchModel_before_model_creation_empty_search() { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(0)).onFailure(any(IndexNotFoundException.class)); verify(actionListener, times(1)).onResponse(any(SearchResponse.class)); @@ -269,7 +269,7 @@ public void test_DoExecute_searchModel_MLResourceNotFoundException_exception() { }).when(client).search(any(), isA(ActionListener.class)); when(modelAccessControlHelper.skipModelAccessControl(any())).thenReturn(true); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(1)).search(any(), any()); verify(actionListener, times(1)).onFailure(any(OpenSearchStatusException.class)); } @@ -284,7 +284,7 @@ public void test_DoExecute_addBackendRoles_boolQuery() throws IOException { when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.boolQuery().must(QueryBuilders.matchQuery("name", "model_IT"))); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -298,7 +298,7 @@ public void test_DoExecute_addBackendRoles_termQuery() throws IOException { when(modelAccessControlHelper.createSearchSourceBuilder(any())).thenReturn(searchSourceBuilder); searchRequest.source().query(QueryBuilders.termQuery("name", "model_IT")); searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, null, actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, null, actionListener); verify(client, times(2)).search(any(), any()); } @@ -336,7 +336,7 @@ public void testDoExecute_MultiTenancyEnabled_TenantFilteringEnabled() throws In searchModelTransportAction.doExecute(null, mlSearchActionRequest, actionListener); - verify(mlSearchHandler).search(mlSearchActionRequest, "123456", actionListener); + verify(mlSearchHandler).search(sdkClient, mlSearchActionRequest, "123456", actionListener); verify(client, times(2)).search(any(), any()); }