Skip to content

Commit

Permalink
remainig sdk client changes for search (opensearch-project#3522)
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <[email protected]>
(cherry picked from commit 5432f25)
  • Loading branch information
dhrubo-os committed Feb 12, 2025
1 parent f9a8a70 commit e2fb636
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
package org.opensearch.ml.action.agents;

import static org.opensearch.ml.action.handler.MLSearchHandler.wrapRestActionListener;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;

import org.opensearch.OpenSearchStatusException;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.ActionFilters;
Expand All @@ -16,6 +16,7 @@
import org.opensearch.common.inject.Inject;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.CommonValue;
Expand All @@ -25,6 +26,8 @@
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.utils.TenantAwareHelper;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -76,11 +79,6 @@ private void search(SearchRequest request, String tenantId, ActionListener<Searc
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
shouldQuery.should(QueryBuilders.termQuery(MLAgent.IS_HIDDEN_FIELD, false));

// For multi-tenancy
if (tenantId != null) {
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
}

// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));

Expand All @@ -91,7 +89,32 @@ private void search(SearchRequest request, String tenantId, ActionListener<Searc
queryBuilder.filter(shouldQuery);

request.source().query(queryBuilder);
client.search(request, wrappedListener);
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.tenantId(tenantId)
.build();

sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
if (throwable != null) {
Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
log.error("Failed to search agent", cause);
wrappedListener.onFailure(cause);
} else {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Agent search complete: {}", searchResponse.getHits().getTotalHits());
wrappedListener.onResponse(searchResponse);
} catch (Exception e) {
log.error("Failed to parse model search response", e);
wrappedListener
.onFailure(
new OpenSearchStatusException("Failed to parse model search response", RestStatus.INTERNAL_SERVER_ERROR)
);
}
}
});
} catch (Exception e) {
log.error("failed to search the agent index", e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import static org.opensearch.core.rest.RestStatus.BAD_REQUEST;
import static org.opensearch.core.rest.RestStatus.INTERNAL_SERVER_ERROR;
import static org.opensearch.ml.common.CommonValue.TENANT_ID_FIELD;
import static org.opensearch.ml.utils.RestActionUtils.wrapListenerToHandleSearchIndexNotFound;

import java.util.ArrayList;
Expand Down Expand Up @@ -43,6 +42,9 @@
import org.opensearch.ml.common.exception.MLResourceNotFoundException;
import org.opensearch.ml.helper.ModelAccessControlHelper;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.client.SearchDataObjectRequest;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.SearchHits;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
Expand Down Expand Up @@ -77,10 +79,11 @@ public MLSearchHandler(

/**
* Fetch all the models from the model group index, and then create a combined query to model version index.
* @param sdkClient sdkclient a wrapper of the client
* @param request
* @param actionListener
*/
public void search(SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
public void search(SdkClient sdkClient, SearchRequest request, String tenantId, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
ActionListener<SearchResponse> listener = wrapRestActionListener(actionListener, "Fail to search model version");
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
Expand Down Expand Up @@ -114,11 +117,6 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
// Add a should clause to include documents where IS_HIDDEN_FIELD is false
shouldQuery.should(QueryBuilders.termQuery(MLModel.IS_HIDDEN_FIELD, false));

// For multi-tenancy
if (tenantId != null) {
shouldQuery.should(QueryBuilders.termQuery(TENANT_ID_FIELD, tenantId));
}

// Add a should clause to include documents where IS_HIDDEN_FIELD does not exist or is null
shouldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLModel.IS_HIDDEN_FIELD)));

Expand All @@ -132,10 +130,29 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
request.source().fetchSource(rebuiltFetchSourceContext);
final ActionListener<SearchResponse> doubleWrapperListener = ActionListener
.wrap(wrappedListener::onResponse, e -> wrapListenerToHandleSearchIndexNotFound(e, wrappedListener));
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, doubleWrapperListener);
} else if (!clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {
client.search(request, doubleWrapperListener);
if (modelAccessControlHelper.skipModelAccessControl(user)
|| !clusterService.state().metadata().hasIndex(CommonValue.ML_MODEL_GROUP_INDEX)) {

SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.tenantId(tenantId)
.build();
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
doubleWrapperListener.onResponse(searchResponse);
} catch (Exception e) {
doubleWrapperListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
doubleWrapperListener.onFailure(e);
}
});
} else {
SearchSourceBuilder sourceBuilder = modelAccessControlHelper.createSearchSourceBuilder(user);
SearchRequest modelGroupSearchRequest = new SearchRequest();
Expand All @@ -154,17 +171,54 @@ public void search(SearchRequest request, String tenantId, ActionListener<Search
Arrays.stream(r.getHits().getHits()).forEach(hit -> { modelGroupIds.add(hit.getId()); });

request.source().query(rewriteQueryBuilder(request.source().query(), modelGroupIds));
client.search(request, doubleWrapperListener);
} else {
log.debug("No model group found");
request.source().query(rewriteQueryBuilder(request.source().query(), null));
client.search(request, doubleWrapperListener);
}
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(request.indices())
.searchSourceBuilder(request.source())
.tenantId(tenantId)
.build();
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((sr, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(sr.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
doubleWrapperListener.onResponse(searchResponse);
} catch (Exception e) {
doubleWrapperListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
doubleWrapperListener.onFailure(e);
}
});
}, e -> {
log.error("Fail to search model groups!", e);
wrappedListener.onFailure(e);
});
client.search(modelGroupSearchRequest, modelGroupSearchActionListener);
SearchDataObjectRequest searchDataObjectRequest = SearchDataObjectRequest
.builder()
.indices(modelGroupSearchRequest.indices())
.searchSourceBuilder(modelGroupSearchRequest.source())
.tenantId(tenantId)
.build();
sdkClient.searchDataObjectAsync(searchDataObjectRequest).whenComplete((r, throwable) -> {
if (throwable == null) {
try {
SearchResponse searchResponse = SearchResponse.fromXContent(r.parser());
log.info("Model search complete: {}", searchResponse.getHits().getTotalHits());
modelGroupSearchActionListener.onResponse(searchResponse);
} catch (Exception e) {
modelGroupSearchActionListener.onFailure(e);
}
} else {
Exception e = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class);
modelGroupSearchActionListener.onFailure(e);
}
});
}
} catch (Exception e) {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@ protected void doExecute(Task task, MLSearchActionRequest request, ActionListene
if (!TenantAwareHelper.validateTenantId(mlFeatureEnabledSetting, tenantId, actionListener)) {
return;
}
mlSearchHandler.search(request, tenantId, actionListener);
mlSearchHandler.search(sdkClient, request, tenantId, actionListener);
}
}
Loading

0 comments on commit e2fb636

Please sign in to comment.