Add test cases for RemoteConnector retry behavior #3504

Open · wants to merge 2 commits into base: main

This change extracts the anonymous RetryableAction used by invokeRemoteServiceWithRetry into a named static class, RetryableActionExtension, whose inputs are grouped into a RetryableActionExtensionArgs builder, so the retry logic (tryAction/shouldRetry) can be unit tested directly. A new test class covers the retry-attempt cap, the throttling-only retry condition, and the unlimited (-1) setting.
RemoteConnectorExecutor.java
@@ -48,8 +48,12 @@
 import org.opensearch.ml.common.output.model.ModelTensors;
 import org.opensearch.ml.common.transport.MLTaskResponse;
 import org.opensearch.script.ScriptService;
+import org.opensearch.threadpool.ThreadPool;

+import lombok.Builder;

 public interface RemoteConnectorExecutor {

     public String RETRY_EXECUTOR = "opensearch_ml_predict_remote";

     default void executeAction(String action, MLInput mlInput, ActionListener<MLTaskResponse> actionListener) {
@@ -253,38 +257,23 @@ default void invokeRemoteServiceWithRetry(
        ExecutionContext executionContext,
        ActionListener<Tuple<Integer, ModelTensors>> actionListener
    ) {
-        final RetryableAction<Tuple<Integer, ModelTensors>> invokeRemoteModelAction = new RetryableAction<>(
+        final RetryableAction<Tuple<Integer, ModelTensors>> invokeRemoteModelAction = new RetryableActionExtension(
             getLogger(),
             getClient().threadPool(),
             TimeValue.timeValueMillis(getConnectorClientConfig().getRetryBackoffMillis()),
             TimeValue.timeValueSeconds(getConnectorClientConfig().getRetryTimeoutSeconds()),
             actionListener,
             getRetryBackoffPolicy(getConnectorClientConfig()),
-            RETRY_EXECUTOR
-        ) {
-            int retryTimes = 0;
-
-            @Override
-            public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
-                // the listener here is RetryingListener
-                // If the request success, or can not retry, will call delegate listener
-                invokeRemoteService(action, mlInput, parameters, payload, executionContext, listener);
-            }
-
-            @Override
-            public boolean shouldRetry(Exception e) {
-                Throwable cause = ExceptionsHelper.unwrapCause(e);
-                Integer maxRetryTimes = getConnectorClientConfig().getMaxRetryTimes();
-                boolean shouldRetry = cause instanceof RemoteConnectorThrottlingException;
-                if (++retryTimes > maxRetryTimes && maxRetryTimes != -1) {
-                    shouldRetry = false;
-                }
-                if (shouldRetry) {
-                    getLogger().debug(String.format(Locale.ROOT, "The %d-th retry for invoke remote model", retryTimes), e);
-                }
-                return shouldRetry;
-            }
-        };
+            RetryableActionExtensionArgs
+                .builder()
+                .connectionExecutor(this)
+                .mlInput(mlInput)
+                .action(action)
+                .parameters(parameters)
+                .executionContext(executionContext)
+                .payload(payload)
+                .build()
+        );
         invokeRemoteModelAction.run();
     };

@@ -296,4 +285,56 @@ void invokeRemoteService(
        ExecutionContext executionContext,
        ActionListener<Tuple<Integer, ModelTensors>> actionListener
    );

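    // Extracted from the former anonymous RetryableAction so the retry behavior can be unit tested directly.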
    static class RetryableActionExtension extends RetryableAction<Tuple<Integer, ModelTensors>> {
        private final RetryableActionExtensionArgs args;
        int retryTimes = 0;

        RetryableActionExtension(
            Logger logger,
            ThreadPool threadPool,
            TimeValue initialDelay,
            TimeValue timeoutValue,
            ActionListener<Tuple<Integer, ModelTensors>> listener,
            BackoffPolicy backoffPolicy,
            RetryableActionExtensionArgs args
        ) {
            super(logger, threadPool, initialDelay, timeoutValue, listener, backoffPolicy, RETRY_EXECUTOR);
            this.args = args;
        }

        @Override
        public void tryAction(ActionListener<Tuple<Integer, ModelTensors>> listener) {
            // The listener here is the RetryingListener; if the request succeeds or cannot be
            // retried, it calls the delegate listener.
            args.connectionExecutor
                .invokeRemoteService(args.action, args.mlInput, args.parameters, args.payload, args.executionContext, listener);
        }

        @Override
        public boolean shouldRetry(Exception e) {
            Throwable cause = ExceptionsHelper.unwrapCause(e);
            Integer maxRetryTimes = args.connectionExecutor.getConnectorClientConfig().getMaxRetryTimes();
            boolean shouldRetry = cause instanceof RemoteConnectorThrottlingException;
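            // A maxRetryTimes of -1 means no cap: throttling errors are retried until the action's timeout.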
            if (++retryTimes > maxRetryTimes && maxRetryTimes != -1) {
                shouldRetry = false;
            }
            if (shouldRetry) {
                args.connectionExecutor
                    .getLogger()
                    .debug(String.format(Locale.ROOT, "The %d-th retry for invoke remote model", retryTimes), e);
            }
            return shouldRetry;
        }
    }

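    // Value object grouping the per-request state that RetryableActionExtension needs from the executor.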
    @Builder
    class RetryableActionExtensionArgs {
        private final RemoteConnectorExecutor connectionExecutor;
        private final MLInput mlInput;
        private final String action;
        private final Map<String, String> parameters;
        private final ExecutionContext executionContext;
        private final String payload;
    }
}
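For context on the retry schedule driven by the code above: the BackoffPolicy handed to the retry action comes from getRetryBackoffPolicy(getConnectorClientConfig()). Below is a minimal sketch of iterating such a policy, assuming it maps to OpenSearch's stock org.opensearch.action.bulk.BackoffPolicy factories; the 200 ms initial delay and 3-retry cap are illustrative values, not defaults from this change:

    import org.opensearch.action.bulk.BackoffPolicy;
    import org.opensearch.common.unit.TimeValue;

    public class BackoffSketch {
        public static void main(String[] args) {
            // BackoffPolicy implements Iterable<TimeValue>; each element is the delay
            // expected before the next retry attempt is scheduled.
            BackoffPolicy policy = BackoffPolicy.exponentialBackoff(TimeValue.timeValueMillis(200), 3);
            for (TimeValue delay : policy) {
                System.out.println("retry after " + delay); // delays grow with each attempt
            }
        }
    }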
RemoteConnectorExecutor_RetryableActionExtensionTest.java (new file)
@@ -0,0 +1,111 @@
package org.opensearch.ml.engine.algorithms.remote;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import org.apache.logging.log4j.Logger;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.opensearch.action.bulk.BackoffPolicy;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.connector.ConnectorClientConfig;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor.RetryableActionExtension;
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor.RetryableActionExtensionArgs;
import org.opensearch.threadpool.ThreadPool;

@RunWith(MockitoJUnitRunner.class)
public class RemoteConnectorExecutor_RetryableActionExtensionTest {

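    // Safety cap so a regression that always retries cannot loop forever in these tests.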
    private static final int TEST_ATTEMPT_LIMIT = 10;

    @Mock
    Logger logger;
    @Mock
    ActionListener<Tuple<Integer, ModelTensors>> listener;
    @Mock
    BackoffPolicy backoffPolicy;
    @Mock
    ConnectorClientConfig connectorClientConfig;
    @Mock
    RemoteConnectorExecutor connectionExecutor;

    // A real ThreadPool is required because RetryableAction schedules retries on it.
    ThreadPool threadPool;
    RetryableActionExtension retryableAction;

    @Before
    public void setup() {
        when(connectionExecutor.getConnectorClientConfig()).thenReturn(connectorClientConfig);
        when(connectionExecutor.getLogger()).thenReturn(logger);
        var args = RetryableActionExtensionArgs.builder()
            .action("action")
            .connectionExecutor(connectionExecutor)
            .mlInput(mock(MLInput.class))
            .parameters(Map.of())
            .executionContext(mock(ExecutionContext.class))
            .payload("payload")
            .build();
        var settings = Settings.builder().put("node.name", "test").build();
        threadPool = new ThreadPool(settings);
        retryableAction = new RetryableActionExtension(
            logger,
            threadPool,
            TimeValue.timeValueMillis(5),
            TimeValue.timeValueMillis(500),
            listener,
            backoffPolicy,
            args
        );
    }

    @After
    public void teardown() {
        // Terminate the real ThreadPool created in setup so test runs do not leak threads.
        ThreadPool.terminate(threadPool, 1, TimeUnit.SECONDS);
    }

    @Test
    public void test_ShouldRetry_hitLimitOnRetries() {
        var attempts = retryAttempts(-1, this::createThrottleException);

        assertThat(attempts, equalTo(TEST_ATTEMPT_LIMIT));
    }

    @Test
    @SuppressWarnings("unchecked")
    public void test_ShouldRetry_OnlyOnThrottleExceptions() {
        var exceptions = mock(Supplier.class);
        when(exceptions.get())
            .thenReturn(createThrottleException())
            .thenReturn(createThrottleException())
            .thenReturn(new RuntimeException()); // stop retrying on the 3rd exception
        var attempts = retryAttempts(-1, exceptions);

        assertThat(attempts, equalTo(3));
    }

    @Test
    public void test_ShouldRetry_stopAtMaxAttempts() {
        var attempts = retryAttempts(3, this::createThrottleException);

        assertThat(attempts, equalTo(4));
    }

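    // Calls shouldRetry with a fresh exception until it returns false or the safety cap is hit,
    // returning the number of shouldRetry invocations made.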
    private int retryAttempts(int maxAttempts, Supplier<Exception> exception) {
        when(connectorClientConfig.getMaxRetryTimes()).thenReturn(maxAttempts);
        int attempt = 0;
        boolean shouldRetry;
        do {
            shouldRetry = retryableAction.shouldRetry(exception.get());
            attempt++;
        } while (attempt < TEST_ATTEMPT_LIMIT && shouldRetry);
        return attempt;
    }

    private RemoteConnectorThrottlingException createThrottleException() {
        return new RemoteConnectorThrottlingException("Throttle", RestStatus.TOO_MANY_REQUESTS);
    }
}