From b5cd6a2a58de2ba7df6c940116170dbf758aa743 Mon Sep 17 00:00:00 2001 From: Dhrubo Saha Date: Wed, 12 Feb 2025 11:33:23 -0800 Subject: [PATCH] applying sdkclient changes to config index (#3521) * applying sdkclient changes to config index Signed-off-by: Dhrubo Saha * addressed comments Signed-off-by: Dhrubo Saha --------- Signed-off-by: Dhrubo Saha (cherry picked from commit eccfd4484c79555868022c761091ce649e999ab9) --- .../ml/engine/encryptor/EncryptorImpl.java | 379 ++++++++++++++---- .../engine/encryptor/EncryptorImplTest.java | 247 +++++------- .../ml/plugin/MachineLearningPlugin.java | 3 +- 3 files changed, 392 insertions(+), 237 deletions(-) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 42864e7519..2036767c5b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -12,6 +12,7 @@ import static org.opensearch.ml.common.MLConfig.CREATE_TIME_FIELD; import static org.opensearch.ml.common.utils.StringUtils.hashString; +import java.io.IOException; import java.nio.charset.StandardCharsets; import java.security.SecureRandom; import java.time.Instant; @@ -24,25 +25,30 @@ import javax.crypto.spec.SecretKeySpec; -import org.apache.commons.lang3.exception.ExceptionUtils; +import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; -import org.opensearch.action.DocWriteRequest; -import org.opensearch.action.get.GetRequest; -import org.opensearch.action.index.IndexRequest; -import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.get.GetResponse; +import org.opensearch.action.index.IndexResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.Strings; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.remote.metadata.client.GetDataObjectRequest; +import org.opensearch.remote.metadata.client.GetDataObjectResponse; +import org.opensearch.remote.metadata.client.PutDataObjectRequest; +import org.opensearch.remote.metadata.client.PutDataObjectResponse; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.fetch.subphase.FetchSourceContext; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CommitmentPolicy; import com.amazonaws.encryptionsdk.CryptoResult; import com.amazonaws.encryptionsdk.jce.JceMasterKey; -import com.google.common.collect.ImmutableMap; import lombok.extern.log4j.Log4j2; @@ -53,6 +59,7 @@ public class EncryptorImpl implements Encryptor { "The ML encryption master key has not been initialized yet. Please retry after waiting for 10 seconds."; private ClusterService clusterService; private Client client; + private SdkClient sdkClient; private final Map tenantMasterKeys; private MLIndicesHandler mlIndicesHandler; @@ -60,11 +67,11 @@ public class EncryptorImpl implements Encryptor { // assigning some random string so that it can't be duplicate public static final String DEFAULT_TENANT_ID = "03000200-0400-0500-0006-000700080009"; - public EncryptorImpl(ClusterService clusterService, Client client, MLIndicesHandler mlIndicesHandler) { + public EncryptorImpl(ClusterService clusterService, Client client, SdkClient sdkClient, MLIndicesHandler mlIndicesHandler) { this.tenantMasterKeys = new ConcurrentHashMap<>(); this.clusterService = clusterService; this.client = client; - + this.sdkClient = sdkClient; this.mlIndicesHandler = mlIndicesHandler; } @@ -121,90 +128,22 @@ private void initMasterKey(String tenantId) { if (tenantMasterKeys.containsKey(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID))) { return; } + String masterKeyId = MASTER_KEY; + if (tenantId != null) { + masterKeyId = MASTER_KEY + "_" + hashString(tenantId); + } AtomicReference exceptionRef = new AtomicReference<>(); CountDownLatch latch = new CountDownLatch(1); - mlIndicesHandler.initMLConfigIndex(ActionListener.wrap(r -> { - if (!r) { - exceptionRef.set(new RuntimeException("No response to create ML Config index")); - latch.countDown(); - } else { - String masterKeyId = MASTER_KEY; - if (tenantId != null) { - masterKeyId = MASTER_KEY + "_" + hashString(tenantId); - } - final String MASTER_KEY_ID = masterKeyId; - GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); - try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getRequest, ActionListener.wrap(getResponse -> { - if (getResponse == null || !getResponse.isExists()) { - IndexRequest indexRequest = new IndexRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); - final String generatedMasterKey = generateMasterKey(); - - ImmutableMap.Builder mapBuilder = ImmutableMap.builder(); - mapBuilder.put(MASTER_KEY_ID, generatedMasterKey); - mapBuilder.put(CREATE_TIME_FIELD, Instant.now().toEpochMilli()); - if (tenantId != null) { - mapBuilder.put(TENANT_ID_FIELD, tenantId); - } - indexRequest.source(mapBuilder.build()); - indexRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - indexRequest.opType(DocWriteRequest.OpType.CREATE); - client.index(indexRequest, ActionListener.wrap(indexResponse -> { - this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), generatedMasterKey); - log.info("ML encryption master key initialized successfully"); - latch.countDown(); - }, e -> { - - if (ExceptionUtils.getRootCause(e) instanceof VersionConflictEngineException) { - GetRequest getMasterKeyRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY_ID); - try ( - ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext() - ) { - client.get(getMasterKeyRequest, ActionListener.wrap(getMasterKeyResponse -> { - if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) { - this.tenantMasterKeys - .put( - Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), - (String) getMasterKeyResponse.getSourceAsMap().get(MASTER_KEY_ID) - ); - log.info("ML encryption master key already initialized, no action needed"); - latch.countDown(); - } else { - exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); - latch.countDown(); - } - }, error -> { - log.debug("Failed to get ML encryption master key", e); - exceptionRef.set(error); - latch.countDown(); - })); - } - } else { - log.debug("Failed to index ML encryption master key", e); - exceptionRef.set(e); - latch.countDown(); - } - })); - } else { - final String masterKey = (String) getResponse.getSourceAsMap().get(MASTER_KEY_ID); - this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), masterKey); - log.info("ML encryption master key already initialized, no action needed"); - latch.countDown(); - } - }, e -> { - log.debug("Failed to get ML encryption master key from config index", e); - exceptionRef.set(e); - latch.countDown(); - })); - } - } - }, e -> { - log.debug("Failed to init ML config index", e); - exceptionRef.set(e); - latch.countDown(); - })); + mlIndicesHandler.initMLConfigIndex(createInitMLConfigIndexListener(exceptionRef, latch, tenantId, masterKeyId)); + waitForLatch(latch); + checkMasterKeyInitialization(tenantId, exceptionRef); + } + private void waitForLatch(CountDownLatch latch) { try { + // TODO: we need to find a better way to depend on the listener rather than waiting for a fixed time + // sometimes it may be take more than 1 seconds in multi-tenancy case where we need to + // create index, create a master key and then perform the prediction. boolean completed = latch.await(3, SECONDS); if (!completed) { throw new MLException("Fetching master key timed out."); @@ -212,9 +151,11 @@ private void initMasterKey(String tenantId) { } catch (InterruptedException e) { throw new IllegalStateException(e); } + } + private void checkMasterKeyInitialization(String tenantId, AtomicReference exceptionRef) { if (exceptionRef.get() != null) { - log.debug("Failed to init master key", exceptionRef.get()); + log.debug("Failed to init master key for tenant {}", tenantId, exceptionRef.get()); if (exceptionRef.get() instanceof RuntimeException) { throw (RuntimeException) exceptionRef.get(); } else { @@ -225,4 +166,264 @@ private void initMasterKey(String tenantId) { throw new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR); } } + + private ActionListener createInitMLConfigIndexListener( + AtomicReference exceptionRef, + CountDownLatch latch, + String tenantId, + String masterKeyId + ) { + return ActionListener + .wrap( + r -> handleInitMLConfigIndexSuccess(exceptionRef, latch, tenantId, masterKeyId), + e -> handleInitMLConfigIndexFailure(exceptionRef, latch, masterKeyId, e) + ); + } + + private void handleInitMLConfigIndexSuccess( + AtomicReference exceptionRef, + CountDownLatch latch, + String tenantId, + String masterKeyId + ) { + FetchSourceContext fetchSourceContext = new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY); + GetDataObjectRequest getDataObjectRequest = createGetDataObjectRequest(tenantId, fetchSourceContext); + + try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + sdkClient + .getDataObjectAsync(getDataObjectRequest) + .whenComplete( + (response, throwable) -> handleGetDataObjectResponse( + tenantId, + masterKeyId, + context, + response, + throwable, + exceptionRef, + latch + ) + ); + } + } + + private void handleInitMLConfigIndexFailure( + AtomicReference exceptionRef, + CountDownLatch latch, + String masterKeyId, + Exception e + ) { + log.debug("Failed to init ML config index", e); + exceptionRef.set(new RuntimeException("No response to create ML Config index")); + latch.countDown(); + } + + private void handleGetDataObjectResponse( + String tenantId, + String masterKeyId, + ThreadContext.StoredContext context, + GetDataObjectResponse response, + Throwable throwable, + AtomicReference exceptionRef, + CountDownLatch latch + ) { + log.debug("Completed Get MASTER_KEY Request, for tenant id:{}", tenantId); + + if (throwable != null) { + handleGetDataObjectFailure(throwable, exceptionRef, latch); + } else { + handleGetDataObjectSuccess(response, tenantId, masterKeyId, exceptionRef, latch, context); + } + context.restore(); + } + + private void handleGetDataObjectFailure(Throwable throwable, AtomicReference exceptionRef, CountDownLatch latch) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + log.debug("Failed to get ML encryption master key from config index", cause); + exceptionRef.set(cause); + latch.countDown(); + } + + private void handleGetDataObjectSuccess( + GetDataObjectResponse response, + String tenantId, + String masterKeyId, + AtomicReference exceptionRef, + CountDownLatch latch, + ThreadContext.StoredContext context + ) { + try { + GetResponse getMasterKeyResponse = response.parser() == null ? null : GetResponse.fromXContent(response.parser()); + if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) { + this.tenantMasterKeys + .put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) response.source().get(masterKeyId)); + log.info("ML encryption master key already initialized, no action needed"); + latch.countDown(); + } else { + initializeNewMasterKey(tenantId, masterKeyId, exceptionRef, latch, context); + } + } catch (Exception e) { + log.debug("Failed to get ML encryption master key from config index", e); + exceptionRef.set(e); + latch.countDown(); + } + } + + private void initializeNewMasterKey( + String tenantId, + String masterKeyId, + AtomicReference exceptionRef, + CountDownLatch latch, + ThreadContext.StoredContext context + ) { + final String generatedMasterKey = generateMasterKey(); + sdkClient + .putDataObjectAsync(createPutDataObjectRequest(tenantId, masterKeyId, generatedMasterKey)) + .whenComplete((putDataObjectResponse, throwable1) -> { + try { + handlePutDataObjectResponse( + tenantId, + masterKeyId, + context, + putDataObjectResponse, + throwable1, + exceptionRef, + latch, + generatedMasterKey + ); + } catch (IOException e) { + log.debug("Failed to index ML encryption master key to config index", e); + exceptionRef.set(e); + latch.countDown(); + } + }); + } + + private PutDataObjectRequest createPutDataObjectRequest(String tenantId, String masterKeyId, String generatedMasterKey) { + return PutDataObjectRequest + .builder() + .tenantId(tenantId) + .index(ML_CONFIG_INDEX) + .id(masterKeyId) + .overwriteIfExists(false) + .dataObject( + Map + .of( + MASTER_KEY, + generatedMasterKey, + CREATE_TIME_FIELD, + Instant.now().toEpochMilli(), + TENANT_ID_FIELD, + Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID) + ) + ) + .build(); + } + + private void handlePutDataObjectResponse( + String tenantId, + String masterKeyId, + ThreadContext.StoredContext context, + PutDataObjectResponse putDataObjectResponse, + Throwable throwable, + AtomicReference exceptionRef, + CountDownLatch latch, + String generatedMasterKey + ) throws IOException { + context.restore(); + + if (throwable != null) { + handlePutDataObjectFailure(tenantId, masterKeyId, context, throwable, exceptionRef, latch); + } else { + IndexResponse indexResponse = IndexResponse.fromXContent(putDataObjectResponse.parser()); + log.info("Master key creation result: {}, Master key id: {}", indexResponse.getResult(), indexResponse.getId()); + this.tenantMasterKeys.put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), generatedMasterKey); + log.info("ML encryption master key initialized successfully"); + latch.countDown(); + } + } + + private void handlePutDataObjectFailure( + String tenantId, + String masterKeyId, + ThreadContext.StoredContext context, + Throwable throwable, + AtomicReference exceptionRef, + CountDownLatch latch + ) { + Exception cause = SdkClientUtils.unwrapAndConvertToException(throwable, OpenSearchStatusException.class); + if (cause instanceof VersionConflictEngineException) { + handleVersionConflict(tenantId, masterKeyId, context, exceptionRef, latch); + } else { + log.debug("Failed to index ML encryption master key to config index", cause); + exceptionRef.set(cause); + latch.countDown(); + } + } + + private void handleVersionConflict( + String tenantId, + String masterKeyId, + ThreadContext.StoredContext context, + AtomicReference exceptionRef, + CountDownLatch latch + ) { + sdkClient + .getDataObjectAsync( + createGetDataObjectRequest(tenantId, new FetchSourceContext(true, Strings.EMPTY_ARRAY, Strings.EMPTY_ARRAY)) + ) + .whenComplete((response, throwable) -> { + try { + handleVersionConflictResponse(tenantId, masterKeyId, context, response, throwable, exceptionRef, latch); + } catch (IOException e) { + log.debug("Failed to get ML encryption master key from config index", e); + exceptionRef.set(e); + latch.countDown(); + } + }); + } + + private GetDataObjectRequest createGetDataObjectRequest(String tenantId, FetchSourceContext fetchSourceContext) { + String masterKeyId = MASTER_KEY; + if (tenantId != null) { + masterKeyId = MASTER_KEY + "_" + hashString(tenantId); + } + return GetDataObjectRequest + .builder() + .index(ML_CONFIG_INDEX) + .id(masterKeyId) + .tenantId(tenantId) + .fetchSourceContext(fetchSourceContext) + .build(); + } + + private void handleVersionConflictResponse( + String tenantId, + String masterKeyId, + ThreadContext.StoredContext context, + GetDataObjectResponse response1, + Throwable throwable2, + AtomicReference exceptionRef, + CountDownLatch latch + ) throws IOException { + context.restore(); + log.debug("Completed Get config item"); + + if (throwable2 != null) { + Exception cause1 = SdkClientUtils.unwrapAndConvertToException(throwable2, OpenSearchStatusException.class); + log.debug("Failed to get ML encryption master key from config index", cause1); + exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); + latch.countDown(); + } else { + GetResponse getMasterKeyResponse = response1.parser() == null ? null : GetResponse.fromXContent(response1.parser()); + if (getMasterKeyResponse != null && getMasterKeyResponse.isExists()) { + this.tenantMasterKeys + .put(Objects.requireNonNullElse(tenantId, DEFAULT_TENANT_ID), (String) response1.source().get(masterKeyId)); + log.info("ML encryption master key already initialized, no action needed"); + latch.countDown(); + } else { + exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); + latch.countDown(); + } + } + } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java index 415b6c72e3..f00ee3cd17 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/encryptor/EncryptorImplTest.java @@ -14,6 +14,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Collections; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -40,11 +41,14 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.bytes.BytesReference; import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.engine.VersionConflictEngineException; import org.opensearch.index.get.GetResult; import org.opensearch.ml.common.exception.MLException; import org.opensearch.ml.engine.indices.MLIndicesHandler; +import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.remote.metadata.client.impl.SdkClientFactory; import org.opensearch.threadpool.ThreadPool; import com.google.common.collect.ImmutableMap; @@ -54,6 +58,7 @@ public class EncryptorImplTest { public ExpectedException exceptionRule = ExpectedException.none(); @Mock Client client; + SdkClient sdkClient; @Mock ClusterService clusterService; @@ -79,6 +84,7 @@ public void setUp() { MockitoAnnotations.openMocks(this); masterKey = new ConcurrentHashMap<>(); masterKey.put(DEFAULT_TENANT_ID, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); + sdkClient = SdkClientFactory.createSdkClient(client, NamedXContentRegistry.EMPTY, Collections.emptyMap()); doAnswer(invocation -> { ActionListener listener = invocation.getArgument(1); @@ -119,7 +125,7 @@ public void setUp() { threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, USER_STRING); when(client.threadPool()).thenReturn(threadPool); when(threadPool.getThreadContext()).thenReturn(threadContext); - encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); } @Test @@ -137,7 +143,7 @@ public void encrypt_ExistingMasterKey() throws IOException { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); @@ -151,25 +157,26 @@ public void encrypt_NonExistingMasterKey() { actionListener.onResponse(true); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); + IndexResponse indexResponse = prepareIndexResponse(); + doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); + actionListener.onResponse(null); return null; }).when(client).get(any(), any()); + doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - IndexResponse response = mock(IndexResponse.class); - actionListener.onResponse(response); + + actionListener.onResponse(indexResponse); return null; }).when(client).index(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); - String encrypted = encryptor.encrypt("test", TENANT_ID); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey(null)); + String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); - Assert.assertNotEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(TENANT_ID)); + Assert.assertNotEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); } @Test @@ -183,9 +190,7 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() { }).when(mlIndicesHandler).initMLConfigIndex(any()); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); + actionListener.onResponse(null); return null; }).when(client).get(any(), any()); doAnswer(invocation -> { @@ -194,9 +199,9 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey() { return null; }).when(client).index(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); - encryptor.encrypt("test", TENANT_ID); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.encrypt("test", null); } @Test @@ -210,9 +215,7 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeExceptio }).when(mlIndicesHandler).initMLConfigIndex(any()); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); + actionListener.onResponse(null); return null; }).when(client).get(any(), any()); doAnswer(invocation -> { @@ -221,13 +224,17 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_NonRuntimeExceptio return null; }).when(client).index(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); - encryptor.encrypt("test", TENANT_ID); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + Assert.assertNull(encryptor.getMasterKey(null)); + encryptor.encrypt("test", null); } @Test public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict() { + /** + * The context of this unit test is if there's any version conflict then we create new key, but if that fails + * again then we throw ResourceNotFoundException exception. + */ exceptionRule.expect(ResourceNotFoundException.class); exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); doAnswer(invocation -> { @@ -237,48 +244,11 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict() }).when(mlIndicesHandler).initMLConfigIndex(any()); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); - return null; - }).doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); - return null; - }).when(client).get(any(), any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - actionListener - .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); - return null; - }).when(client).index(any(), any()); - - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); - encryptor.encrypt("test", TENANT_ID); - } - - @Test - public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_NullGetResponse() { - exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(0); - actionListener.onResponse(true); - return null; - }).when(mlIndicesHandler).initMLConfigIndex(any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); + actionListener.onResponse(null); return null; }).doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = null; - actionListener.onResponse(response); + actionListener.onFailure(new IOException("testing")); return null; }).when(client).get(any(), any()); doAnswer(invocation -> { @@ -288,46 +258,11 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Nu return null; }).when(client).index(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); encryptor.encrypt("test", TENANT_ID); } - @Test - public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_NullResponse() { - exceptionRule.expect(ResourceNotFoundException.class); - exceptionRule.expectMessage(MASTER_KEY_NOT_READY_ERROR); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(0); - actionListener.onResponse(true); - return null; - }).when(mlIndicesHandler).initMLConfigIndex(any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = null; - actionListener.onResponse(response); - return null; - }).doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); - return null; - }).when(client).get(any(), any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - actionListener - .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); - return null; - }).when(client).index(any(), any()); - - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(null)); - String encrypted = encryptor.encrypt("test", null); - Assert.assertNotNull(encrypted); - Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); - } - @Test public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_GetExistingMasterKey() throws IOException { doAnswer(invocation -> { @@ -354,41 +289,7 @@ public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_Ge return null; }).when(client).index(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(null)); - String encrypted = encryptor.encrypt("test", null); - Assert.assertNotNull(encrypted); - Assert.assertEquals(masterKey.get(DEFAULT_TENANT_ID), encryptor.getMasterKey(null)); - } - - @Test - public void encrypt_NonExistingMasterKey_FailedToCreateNewKey_VersionConflict_FailedToGetExistingMasterKey() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("random test exception"); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(0); - actionListener.onResponse(true); - return null; - }).when(mlIndicesHandler).initMLConfigIndex(any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - GetResponse response = mock(GetResponse.class); - when(response.isExists()).thenReturn(false); - actionListener.onResponse(response); - return null; - }).doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - actionListener.onFailure(new RuntimeException("random test exception")); - return null; - }).when(client).get(any(), any()); - doAnswer(invocation -> { - ActionListener actionListener = (ActionListener) invocation.getArgument(1); - actionListener - .onFailure(new VersionConflictEngineException(new ShardId(ML_CONFIG_INDEX, "index_uuid", 1), "test_id", "failed")); - return null; - }).when(client).index(any(), any()); - - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); String encrypted = encryptor.encrypt("test", null); Assert.assertNotNull(encrypted); @@ -400,20 +301,20 @@ public void encrypt_ThrowExceptionWhenInitMLConfigIndex() { exceptionRule.expect(RuntimeException.class); exceptionRule.expectMessage("test exception"); doThrow(new RuntimeException("test exception")).when(mlIndicesHandler).initMLConfigIndex(any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @Test public void encrypt_FailedToInitMLConfigIndex() { exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("random test exception"); + exceptionRule.expectMessage("No response to create ML Config index"); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(0); actionListener.onFailure(new RuntimeException("random test exception")); return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @@ -431,7 +332,7 @@ public void encrypt_FailedToGetMasterKey() { actionListener.onFailure(new RuntimeException("random test exception")); return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); encryptor.encrypt(masterKey.get(DEFAULT_TENANT_ID), null); } @@ -463,7 +364,7 @@ public void decrypt() throws IOException { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); String encrypted = encryptor.encrypt("test", null); String decrypted = encryptor.decrypt(encrypted, null); @@ -484,7 +385,7 @@ public void encrypt_NullMasterKey_NullMasterKey_MasterKeyNotExistInIndex() { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); encryptor.encrypt("test", null); } @@ -505,15 +406,13 @@ public void decrypt_NullMasterKey_GetMasterKey_Exception() { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); encryptor.decrypt("test", null); } @Test public void decrypt_NoResponseToInitConfigIndex() { - exceptionRule.expect(RuntimeException.class); - exceptionRule.expectMessage("No response to create ML Config index"); doAnswer(invocation -> { ActionListener actionListener = (ActionListener) invocation.getArgument(0); @@ -521,9 +420,18 @@ public void decrypt_NoResponseToInitConfigIndex() { return null; }).when(mlIndicesHandler).initMLConfigIndex(any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); - Assert.assertNull(encryptor.getMasterKey(null)); - encryptor.decrypt("test", null); + // Mock GetResponse to return a valid MASTER_KEY_ID for the given tenant + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + GetResponse response = prepareMLConfigResponse(TENANT_ID); // Response includes dynamic MASTER_KEY_ID + listener.onResponse(response); + return null; + }).when(client).get(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + String encrypted = encryptor.encrypt("test", TENANT_ID); + Assert.assertNotNull(encryptor.getMasterKey(TENANT_ID)); + Assert.assertEquals("test", encryptor.decrypt(encrypted, TENANT_ID)); } @Test @@ -540,7 +448,7 @@ public void decrypt_MLConfigIndexNotFound() { return null; }).when(client).get(any(), any()); - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(null)); encryptor.decrypt("test", null); } @@ -563,7 +471,7 @@ public void initMasterKey_AddTenantMasterKeys() throws IOException { }).when(client).get(any(), any()); // Initialize Encryptor and verify no master key exists initially - Encryptor encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); Assert.assertNull(encryptor.getMasterKey(TENANT_ID)); // Encrypt using the specified tenant ID @@ -579,6 +487,51 @@ public void initMasterKey_AddTenantMasterKeys() throws IOException { Assert.assertEquals("m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w=", encryptor.getMasterKey(TENANT_ID)); } + @Test + public void encrypt_SdkClientPutDataObjectFailure() { + exceptionRule.expect(RuntimeException.class); + exceptionRule.expectMessage("Failed to index ML encryption master key"); + + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(1); + listener.onResponse(null); + return null; + }).when(client).get(any(), any()); + + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(1); + actionListener.onFailure(new RuntimeException("Failed to index ML encryption master key")); + return null; + }).when(client).index(any(), any()); + + Encryptor encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + encryptor.encrypt("test", null); + } + + @Test + public void handleVersionConflictResponse_ShouldThrowException_WhenRetryFails() throws IOException { + doAnswer(invocation -> { + ActionListener actionListener = (ActionListener) invocation.getArgument(0); + actionListener.onResponse(true); + return null; + }).when(mlIndicesHandler).initMLConfigIndex(any()); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new IOException("Failed to get master key")); + return null; + }).when(client).get(any(), any()); + + exceptionRule.expect(MLException.class); + encryptor.encrypt("test", "someTenant"); + } + // Helper method to prepare a valid GetResponse private GetResponse prepareMLConfigResponse(String tenantId) throws IOException { // Compute the masterKeyId based on tenantId diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e8c3a936b3..824778dbd8 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -498,7 +498,6 @@ public Collection createComponents( Path configFile = environment.configFile(); mlIndicesHandler = new MLIndicesHandler(clusterService, client); - encryptor = new EncryptorImpl(clusterService, client, mlIndicesHandler); SdkClient sdkClient = SdkClientFactory .createSdkClient( @@ -523,6 +522,8 @@ public Collection createComponents( client.threadPool().executor(ThreadPool.Names.GENERIC) ); + encryptor = new EncryptorImpl(clusterService, client, sdkClient, mlIndicesHandler); + mlEngine = new MLEngine(dataPath, encryptor); nodeHelper = new DiscoveryNodeHelper(clusterService, settings); modelCacheHelper = new MLModelCacheHelper(clusterService, settings);