From a5c736efd4a4a085d624f283ea59387fc7030fcf Mon Sep 17 00:00:00 2001 From: ndr_brt Date: Thu, 4 Jul 2024 13:58:32 +0200 Subject: [PATCH] fix: avoid thread blocking in ParallelSink (#4333) --- .../DataPlaneFrameworkExtension.java | 56 +++++++---- .../dataplane/util/sink/ParallelSink.java | 58 ++++++----- .../dataplane/util/sink/ParallelSinkTest.java | 97 ++++++++++++------- 3 files changed, 131 insertions(+), 80 deletions(-) diff --git a/core/data-plane/data-plane-core/src/main/java/org/eclipse/edc/connector/dataplane/framework/DataPlaneFrameworkExtension.java b/core/data-plane/data-plane-core/src/main/java/org/eclipse/edc/connector/dataplane/framework/DataPlaneFrameworkExtension.java index 01edbcaa54d..bd92fbc8e08 100644 --- a/core/data-plane/data-plane-core/src/main/java/org/eclipse/edc/connector/dataplane/framework/DataPlaneFrameworkExtension.java +++ b/core/data-plane/data-plane-core/src/main/java/org/eclipse/edc/connector/dataplane/framework/DataPlaneFrameworkExtension.java @@ -52,46 +52,61 @@ /** * Provides core services for the Data Plane Framework. */ -@Provides({ DataPlaneManager.class, DataTransferExecutorServiceContainer.class, TransferServiceRegistry.class }) +@Provides({ DataPlaneManager.class, TransferServiceRegistry.class }) @Extension(value = DataPlaneFrameworkExtension.NAME) public class DataPlaneFrameworkExtension implements ServiceExtension { + public static final String NAME = "Data Plane Framework"; + private static final int DEFAULT_TRANSFER_THREADS = 20; - @Setting(value = "the iteration wait time in milliseconds in the data plane state machine. Default value " + DEFAULT_ITERATION_WAIT, type = "long") + @Setting( + value = "the iteration wait time in milliseconds in the data plane state machine.", + defaultValue = DEFAULT_ITERATION_WAIT + "", + type = "long") private static final String DATAPLANE_MACHINE_ITERATION_WAIT_MILLIS = "edc.dataplane.state-machine.iteration-wait-millis"; - @Setting(value = "the batch size in the data plane state machine. Default value " + DEFAULT_BATCH_SIZE, type = "int") + @Setting( + value = "the batch size in the data plane state machine.", + defaultValue = DEFAULT_BATCH_SIZE + "", + type = "int" + ) private static final String DATAPLANE_MACHINE_BATCH_SIZE = "edc.dataplane.state-machine.batch-size"; - @Setting(value = "how many times a specific operation must be tried before terminating the dataplane with error", type = "int", defaultValue = DEFAULT_SEND_RETRY_LIMIT + "") + @Setting( + value = "how many times a specific operation must be tried before terminating the dataplane with error", + defaultValue = DEFAULT_SEND_RETRY_LIMIT + "", + type = "int" + ) private static final String DATAPLANE_SEND_RETRY_LIMIT = "edc.dataplane.send.retry.limit"; - @Setting(value = "The base delay for the dataplane retry mechanism in millisecond", type = "long", defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "") + @Setting( + value = "The base delay for the dataplane retry mechanism in millisecond", + defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "", + type = "long" + ) private static final String DATAPLANE_SEND_RETRY_BASE_DELAY_MS = "edc.dataplane.send.retry.base-delay.ms"; - @Setting + @Setting( + value = "Size of the transfer thread pool. It is advisable to set it bigger than the state machine batch size", + defaultValue = DEFAULT_TRANSFER_THREADS + "", + type = "int" + ) private static final String TRANSFER_THREADS = "edc.dataplane.transfer.threads"; - private static final int DEFAULT_TRANSFER_THREADS = 10; + private DataPlaneManagerImpl dataPlaneManager; @Inject private TransferServiceSelectionStrategy transferServiceSelectionStrategy; - @Inject private DataPlaneStore store; - @Inject private TransferProcessApiClient transferProcessApiClient; - @Inject private ExecutorInstrumentation executorInstrumentation; - @Inject private Telemetry telemetry; - @Inject private Clock clock; - @Inject private PipelineService pipelineService; @Inject @@ -112,12 +127,6 @@ public String name() { public void initialize(ServiceExtensionContext context) { var monitor = context.getMonitor(); - var numThreads = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS); - var executorService = Executors.newFixedThreadPool(numThreads); - var executorContainer = new DataTransferExecutorServiceContainer( - executorInstrumentation.instrument(executorService, "Data plane transfers")); - context.registerService(DataTransferExecutorServiceContainer.class, executorContainer); - var transferServiceRegistry = new TransferServiceRegistryImpl(transferServiceSelectionStrategy); transferServiceRegistry.registerTransferService(pipelineService); context.registerService(TransferServiceRegistry.class, transferServiceRegistry); @@ -131,6 +140,7 @@ public void initialize(ServiceExtensionContext context) { .clock(clock) .entityRetryProcessConfiguration(getEntityRetryProcessConfiguration(context)) .executorInstrumentation(executorInstrumentation) + .authorizationService(authorizationService) .transferServiceRegistry(transferServiceRegistry) .store(store) .transferProcessClient(transferProcessApiClient) @@ -154,6 +164,14 @@ public void shutdown() { } } + @Provider + public DataTransferExecutorServiceContainer dataTransferExecutorServiceContainer(ServiceExtensionContext context) { + var numThreads = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS); + var executorService = Executors.newFixedThreadPool(numThreads); + return new DataTransferExecutorServiceContainer( + executorInstrumentation.instrument(executorService, "Data plane transfers")); + } + @Provider public DataPlaneAuthorizationService authorizationService(ServiceExtensionContext context) { if (authorizationService == null) { diff --git a/core/data-plane/data-plane-util/src/main/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSink.java b/core/data-plane/data-plane-util/src/main/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSink.java index 8fd0b131164..63366fd3e3c 100644 --- a/core/data-plane/data-plane-util/src/main/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSink.java +++ b/core/data-plane/data-plane-util/src/main/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSink.java @@ -17,9 +17,10 @@ import io.opentelemetry.instrumentation.annotations.WithSpan; import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSink; import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource; +import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure; import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult; +import org.eclipse.edc.spi.EdcException; import org.eclipse.edc.spi.monitor.Monitor; -import org.eclipse.edc.spi.result.AbstractResult; import org.eclipse.edc.spi.telemetry.Telemetry; import org.eclipse.edc.util.stream.PartitionIterator; import org.jetbrains.annotations.NotNull; @@ -30,10 +31,7 @@ import java.util.concurrent.ExecutorService; import java.util.function.Supplier; -import static java.lang.String.format; -import static java.util.concurrent.CompletableFuture.completedFuture; import static java.util.concurrent.CompletableFuture.supplyAsync; -import static org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult.failure; import static org.eclipse.edc.util.async.AsyncUtils.asyncAllOf; /** @@ -49,28 +47,26 @@ public abstract class ParallelSink implements DataSink { @WithSpan @Override public CompletableFuture> transfer(DataSource source) { - try { - var streamResult = source.openPartStream(); - if (streamResult.failed()) { - return completedFuture(failure(streamResult.getFailure())); - } - - try (var partStream = streamResult.getContent()) { - return PartitionIterator.streamOf(partStream, partitionSize) - .map(this::processPartsAsync) - .collect(asyncAllOf()) - .thenApply(results -> results.stream() - .filter(AbstractResult::failed) - .findFirst() - .map(r -> StreamResult.error(String.join(",", r.getFailureMessages()))) - .orElseGet(this::complete)) - .exceptionally(throwable -> StreamResult.error("Unhandled exception raised when transferring data: " + throwable.getMessage())); - } - } catch (Exception e) { - var errorMessage = format("Error processing data transfer request - Request ID: %s", requestId); - monitor.severe(errorMessage, e); - return CompletableFuture.completedFuture(StreamResult.error(errorMessage)); - } + return supplyAsync(() -> source.openPartStream().orElseThrow(StreamException::new), executorService) + .thenCompose(parts -> { + try (parts) { + return PartitionIterator.streamOf(parts, partitionSize) + .map(this::processPartsAsync) + .collect(asyncAllOf()) + .thenApply(results -> results.stream() + .filter(StreamResult::failed) + .findFirst() + .map(r -> StreamResult.failure(r.getFailure())) + .orElseGet(this::complete)); + } + }) + .exceptionally(throwable -> { + if (throwable instanceof StreamException streamException) { + return StreamResult.failure(streamException.failure); + } else { + return StreamResult.error("Error processing data transfer request - Request ID: %s. Message: %s".formatted(requestId, throwable.getMessage())); + } + }); } @NotNull @@ -95,6 +91,16 @@ protected StreamResult complete() { return StreamResult.success(); } + private static class StreamException extends EdcException { + + private final StreamFailure failure; + + StreamException(StreamFailure failure) { + super(failure.getFailureDetail()); + this.failure = failure; + } + } + protected abstract static class Builder, T extends ParallelSink> { protected T sink; diff --git a/core/data-plane/data-plane-util/src/test/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSinkTest.java b/core/data-plane/data-plane-util/src/test/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSinkTest.java index 931210c76be..78c31b2c355 100644 --- a/core/data-plane/data-plane-util/src/test/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSinkTest.java +++ b/core/data-plane/data-plane-util/src/test/java/org/eclipse/edc/connector/dataplane/util/sink/ParallelSinkTest.java @@ -14,24 +14,20 @@ package org.eclipse.edc.connector.dataplane.util.sink; -import org.assertj.core.api.Assertions; import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource; import org.eclipse.edc.connector.dataplane.spi.pipeline.InputStreamDataSource; import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure; import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult; -import org.eclipse.edc.spi.monitor.Monitor; -import org.eclipse.edc.spi.telemetry.Telemetry; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.io.ByteArrayInputStream; +import java.time.Duration; import java.util.List; -import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; import java.util.function.Supplier; +import java.util.stream.IntStream; -import static java.lang.String.format; +import static java.time.temporal.ChronoUnit.MILLIS; import static java.util.UUID.randomUUID; import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.Mockito.mock; @@ -39,40 +35,33 @@ class ParallelSinkTest { - private final Monitor monitor = mock(Monitor.class); - private final ExecutorService executor = Executors.newFixedThreadPool(2); - private final String dataSourceName = "test-datasource-name"; - private final String dataSourceContent = "test-content"; + private final Duration timeout = Duration.of(500, MILLIS); private final String errorMessage = "test-errormessage"; - private final InputStreamDataSource dataSource = new InputStreamDataSource( - dataSourceName, - new ByteArrayInputStream(dataSourceContent.getBytes())); private final String dataFlowRequestId = randomUUID().toString(); - FakeParallelSink fakeSink; - - @BeforeEach - void setup() { - fakeSink = new FakeParallelSink(); - fakeSink.monitor = monitor; - fakeSink.telemetry = new Telemetry(); // default noop implementation - fakeSink.executorService = executor; - fakeSink.requestId = dataFlowRequestId; - } + private final FakeParallelSink fakeSink = new FakeParallelSink.Builder().monitor(mock()) + .executorService(Executors.newFixedThreadPool(2)) + .requestId(dataFlowRequestId).build(); @Test void transfer_succeeds() { - assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS) - .satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue()); + var dataSource = dataSource(); - Assertions.assertThat(fakeSink.parts).containsExactly(dataSource); + var future = fakeSink.transfer(dataSource); + + assertThat(future).succeedsWithin(timeout) + .satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue()); + assertThat(fakeSink.parts).containsExactly(dataSource); assertThat(fakeSink.complete).isEqualTo(1); } @Test void transfer_whenCompleteFails_fails() { + var dataSource = dataSource(); fakeSink.completeResponse = StreamResult.error("General error"); - assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS) - .isEqualTo(fakeSink.completeResponse); + + var future = fakeSink.transfer(dataSource); + + assertThat(future).succeedsWithin(timeout).isEqualTo(fakeSink.completeResponse); } @Test @@ -81,17 +70,23 @@ void transfer_whenExceptionOpeningPartStream_fails() { when(dataSourceMock.openPartStream()).thenThrow(new RuntimeException(errorMessage)); - assertThat(fakeSink.transfer(dataSourceMock)).succeedsWithin(500, TimeUnit.MILLISECONDS) + var future = fakeSink.transfer(dataSourceMock); + + assertThat(future).succeedsWithin(timeout) .satisfies(transferResult -> assertThat(transferResult.failed()).isTrue()) - .satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(format("Error processing data transfer request - Request ID: %s", dataFlowRequestId))); + .satisfies(transferResult -> assertThat(transferResult.getFailureDetail()) + .contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage)); assertThat(fakeSink.complete).isEqualTo(0); } @Test void transfer_whenFailureDuringTransfer_fails() { + var dataSource = dataSource(); fakeSink.transferResultSupplier = () -> StreamResult.error(errorMessage); - assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS) + var future = fakeSink.transfer(dataSource); + + assertThat(future).succeedsWithin(timeout) .satisfies(transferResult -> assertThat(transferResult.failed()).isTrue()) .satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR)) .satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(errorMessage)); @@ -102,20 +97,40 @@ void transfer_whenFailureDuringTransfer_fails() { @Test void transfer_whenExceptionDuringTransfer_fails() { + var dataSource = dataSource(); fakeSink.transferResultSupplier = () -> { throw new RuntimeException(errorMessage); }; - assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS) + var future = fakeSink.transfer(dataSource); + + assertThat(future).succeedsWithin(timeout) .satisfies(transferResult -> assertThat(transferResult.failed()).isTrue()) .satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR)) - .satisfies(transferResult -> assertThat(transferResult.getFailureMessages()) - .containsExactly("Unhandled exception raised when transferring data: java.lang.RuntimeException: " + errorMessage)); + .satisfies(transferResult -> assertThat(transferResult.getFailureDetail()) + .contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage)); assertThat(fakeSink.parts).containsExactly(dataSource); assertThat(fakeSink.complete).isEqualTo(0); } + @Test + void shouldNotBlock_whenDataSourceIsIndefinite() { + var infiniteStream = IntStream.iterate(0, i -> i + 1).mapToObj(i -> mock(DataSource.Part.class)); + var dataSource = mock(DataSource.class); + when(dataSource.openPartStream()).thenReturn(StreamResult.success(infiniteStream)); + + var future = fakeSink.transfer(dataSource); + + assertThat(future).isNotNull(); + } + + private InputStreamDataSource dataSource() { + return new InputStreamDataSource( + "test-datasource-name", + new ByteArrayInputStream("test-content".getBytes())); + } + private static class FakeParallelSink extends ParallelSink { List parts; @@ -134,5 +149,17 @@ protected StreamResult complete() { complete++; return completeResponse; } + + public static class Builder extends ParallelSink.Builder { + + protected Builder() { + super(new FakeParallelSink()); + } + + @Override + protected void validate() { + + } + } } }