Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: avoid thread blocking in ParallelSink #4333

Merged
merged 1 commit into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,46 +52,61 @@
/**
* Provides core services for the Data Plane Framework.
*/
@Provides({ DataPlaneManager.class, DataTransferExecutorServiceContainer.class, TransferServiceRegistry.class })
@Provides({ DataPlaneManager.class, TransferServiceRegistry.class })
@Extension(value = DataPlaneFrameworkExtension.NAME)
public class DataPlaneFrameworkExtension implements ServiceExtension {

public static final String NAME = "Data Plane Framework";
private static final int DEFAULT_TRANSFER_THREADS = 20;

@Setting(value = "the iteration wait time in milliseconds in the data plane state machine. Default value " + DEFAULT_ITERATION_WAIT, type = "long")
@Setting(
value = "the iteration wait time in milliseconds in the data plane state machine.",
defaultValue = DEFAULT_ITERATION_WAIT + "",
type = "long")
private static final String DATAPLANE_MACHINE_ITERATION_WAIT_MILLIS = "edc.dataplane.state-machine.iteration-wait-millis";

@Setting(value = "the batch size in the data plane state machine. Default value " + DEFAULT_BATCH_SIZE, type = "int")
@Setting(
value = "the batch size in the data plane state machine.",
defaultValue = DEFAULT_BATCH_SIZE + "",
type = "int"
)
private static final String DATAPLANE_MACHINE_BATCH_SIZE = "edc.dataplane.state-machine.batch-size";

@Setting(value = "how many times a specific operation must be tried before terminating the dataplane with error", type = "int", defaultValue = DEFAULT_SEND_RETRY_LIMIT + "")
@Setting(
value = "how many times a specific operation must be tried before terminating the dataplane with error",
defaultValue = DEFAULT_SEND_RETRY_LIMIT + "",
type = "int"
)
private static final String DATAPLANE_SEND_RETRY_LIMIT = "edc.dataplane.send.retry.limit";

@Setting(value = "The base delay for the dataplane retry mechanism in millisecond", type = "long", defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "")
@Setting(
value = "The base delay for the dataplane retry mechanism in millisecond",
defaultValue = DEFAULT_SEND_RETRY_BASE_DELAY + "",
type = "long"
)
private static final String DATAPLANE_SEND_RETRY_BASE_DELAY_MS = "edc.dataplane.send.retry.base-delay.ms";

@Setting
@Setting(
value = "Size of the transfer thread pool. It is advisable to set it bigger than the state machine batch size",
defaultValue = DEFAULT_TRANSFER_THREADS + "",
type = "int"
)
private static final String TRANSFER_THREADS = "edc.dataplane.transfer.threads";
private static final int DEFAULT_TRANSFER_THREADS = 10;

private DataPlaneManagerImpl dataPlaneManager;

@Inject
private TransferServiceSelectionStrategy transferServiceSelectionStrategy;

@Inject
private DataPlaneStore store;

@Inject
private TransferProcessApiClient transferProcessApiClient;

@Inject
private ExecutorInstrumentation executorInstrumentation;

@Inject
private Telemetry telemetry;

@Inject
private Clock clock;

@Inject
private PipelineService pipelineService;
@Inject
Expand All @@ -112,12 +127,6 @@ public String name() {
public void initialize(ServiceExtensionContext context) {
var monitor = context.getMonitor();

var numThreads = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS);
var executorService = Executors.newFixedThreadPool(numThreads);
var executorContainer = new DataTransferExecutorServiceContainer(
executorInstrumentation.instrument(executorService, "Data plane transfers"));
context.registerService(DataTransferExecutorServiceContainer.class, executorContainer);

var transferServiceRegistry = new TransferServiceRegistryImpl(transferServiceSelectionStrategy);
transferServiceRegistry.registerTransferService(pipelineService);
context.registerService(TransferServiceRegistry.class, transferServiceRegistry);
Expand All @@ -131,6 +140,7 @@ public void initialize(ServiceExtensionContext context) {
.clock(clock)
.entityRetryProcessConfiguration(getEntityRetryProcessConfiguration(context))
.executorInstrumentation(executorInstrumentation)
.authorizationService(authorizationService)
.transferServiceRegistry(transferServiceRegistry)
.store(store)
.transferProcessClient(transferProcessApiClient)
Expand All @@ -154,6 +164,14 @@ public void shutdown() {
}
}

/**
 * Supplies the container that wraps the executor used for data plane transfers.
 * <p>
 * A fixed-size thread pool is created whose size is read from the
 * {@code edc.dataplane.transfer.threads} setting (falling back to {@code DEFAULT_TRANSFER_THREADS}),
 * and the pool is instrumented under the name "Data plane transfers" before being wrapped.
 *
 * @param context the extension context the thread pool size is read from.
 * @return a container holding the instrumented transfer executor.
 */
@Provider
public DataTransferExecutorServiceContainer dataTransferExecutorServiceContainer(ServiceExtensionContext context) {
    var poolSize = context.getSetting(TRANSFER_THREADS, DEFAULT_TRANSFER_THREADS);
    var instrumentedPool = executorInstrumentation.instrument(Executors.newFixedThreadPool(poolSize), "Data plane transfers");
    return new DataTransferExecutorServiceContainer(instrumentedPool);
}

@Provider
public DataPlaneAuthorizationService authorizationService(ServiceExtensionContext context) {
if (authorizationService == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import io.opentelemetry.instrumentation.annotations.WithSpan;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSink;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult;
import org.eclipse.edc.spi.EdcException;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.result.AbstractResult;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.eclipse.edc.util.stream.PartitionIterator;
import org.jetbrains.annotations.NotNull;
Expand All @@ -30,10 +31,7 @@
import java.util.concurrent.ExecutorService;
import java.util.function.Supplier;

import static java.lang.String.format;
import static java.util.concurrent.CompletableFuture.completedFuture;
import static java.util.concurrent.CompletableFuture.supplyAsync;
import static org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult.failure;
import static org.eclipse.edc.util.async.AsyncUtils.asyncAllOf;

/**
Expand All @@ -49,28 +47,26 @@ public abstract class ParallelSink implements DataSink {
/**
 * Transfers the parts of the given source to this sink without blocking the calling thread.
 * <p>
 * The part stream is opened on {@code executorService}; the parts are then split into partitions of
 * {@code partitionSize}, each partition is handed to {@code processPartsAsync}, and once every partition
 * has completed the first failed result (if any) is returned, otherwise the result of {@link #complete()}.
 *
 * @param source the data source whose parts are transferred.
 * @return a future that completes normally with the transfer outcome: a failure carrying the original
 *         {@code StreamFailure} when the part stream could not be opened, a generic error mentioning the
 *         request id when an exception was raised, or the result of {@link #complete()} on success.
 */
@WithSpan
@Override
public CompletableFuture<StreamResult<Object>> transfer(DataSource source) {
    return supplyAsync(() -> source.openPartStream().orElseThrow(StreamException::new), executorService)
            .thenCompose(parts -> {
                try (parts) {
                    return PartitionIterator.streamOf(parts, partitionSize)
                            .map(this::processPartsAsync)
                            .collect(asyncAllOf())
                            .thenApply(results -> results.stream()
                                    .filter(StreamResult::failed)
                                    .findFirst()
                                    .map(r -> StreamResult.failure(r.getFailure()))
                                    .orElseGet(this::complete));
                }
            })
            .exceptionally(throwable -> {
                // An exception raised in an earlier stage reaches this handler wrapped in a
                // CompletionException, so the StreamException has to be looked up in the cause as well.
                var candidate = throwable instanceof StreamException ? throwable : throwable.getCause();
                if (candidate instanceof StreamException streamException) {
                    return StreamResult.failure(streamException.failure);
                }
                return StreamResult.error("Error processing data transfer request - Request ID: %s. Message: %s".formatted(requestId, throwable.getMessage()));
            });
}

@NotNull
Expand All @@ -95,6 +91,16 @@ protected StreamResult<Object> complete() {
return StreamResult.success();
}

/**
 * Exception used inside the asynchronous transfer chain to carry a {@link StreamFailure}
 * so that it can be turned back into a failed {@code StreamResult} by the error handler.
 * The exception message is the failure detail of the wrapped failure.
 */
private static class StreamException extends EdcException {

    private final StreamFailure failure;

    /**
     * @param streamFailure the failure to carry; its detail becomes the exception message.
     */
    StreamException(StreamFailure streamFailure) {
        super(streamFailure.getFailureDetail());
        failure = streamFailure;
    }
}

protected abstract static class Builder<B extends Builder<B, T>, T extends ParallelSink> {
protected T sink;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,65 +14,54 @@

package org.eclipse.edc.connector.dataplane.util.sink;

import org.assertj.core.api.Assertions;
import org.eclipse.edc.connector.dataplane.spi.pipeline.DataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.InputStreamDataSource;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamFailure;
import org.eclipse.edc.connector.dataplane.spi.pipeline.StreamResult;
import org.eclipse.edc.spi.monitor.Monitor;
import org.eclipse.edc.spi.telemetry.Telemetry;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import java.util.stream.IntStream;

import static java.lang.String.format;
import static java.time.temporal.ChronoUnit.MILLIS;
import static java.util.UUID.randomUUID;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

class ParallelSinkTest {

private final Monitor monitor = mock(Monitor.class);
private final ExecutorService executor = Executors.newFixedThreadPool(2);
private final String dataSourceName = "test-datasource-name";
private final String dataSourceContent = "test-content";
private final Duration timeout = Duration.of(500, MILLIS);
private final String errorMessage = "test-errormessage";
private final InputStreamDataSource dataSource = new InputStreamDataSource(
dataSourceName,
new ByteArrayInputStream(dataSourceContent.getBytes()));
private final String dataFlowRequestId = randomUUID().toString();
FakeParallelSink fakeSink;

@BeforeEach
void setup() {
fakeSink = new FakeParallelSink();
fakeSink.monitor = monitor;
fakeSink.telemetry = new Telemetry(); // default noop implementation
fakeSink.executorService = executor;
fakeSink.requestId = dataFlowRequestId;
}
private final FakeParallelSink fakeSink = new FakeParallelSink.Builder().monitor(mock())
.executorService(Executors.newFixedThreadPool(2))
.requestId(dataFlowRequestId).build();

@Test
void transfer_succeeds() {
assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
.satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue());
var dataSource = dataSource();

Assertions.assertThat(fakeSink.parts).containsExactly(dataSource);
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.succeeded()).isTrue());
assertThat(fakeSink.parts).containsExactly(dataSource);
assertThat(fakeSink.complete).isEqualTo(1);
}

@Test
void transfer_whenCompleteFails_fails() {
var dataSource = dataSource();
fakeSink.completeResponse = StreamResult.error("General error");
assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
.isEqualTo(fakeSink.completeResponse);

var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout).isEqualTo(fakeSink.completeResponse);
}

@Test
Expand All @@ -81,17 +70,23 @@ void transfer_whenExceptionOpeningPartStream_fails() {

when(dataSourceMock.openPartStream()).thenThrow(new RuntimeException(errorMessage));

assertThat(fakeSink.transfer(dataSourceMock)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSourceMock);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(format("Error processing data transfer request - Request ID: %s", dataFlowRequestId)));
.satisfies(transferResult -> assertThat(transferResult.getFailureDetail())
.contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage));
assertThat(fakeSink.complete).isEqualTo(0);
}

@Test
void transfer_whenFailureDuringTransfer_fails() {
var dataSource = dataSource();
fakeSink.transferResultSupplier = () -> StreamResult.error(errorMessage);

assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR))
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages()).containsExactly(errorMessage));
Expand All @@ -102,20 +97,40 @@ void transfer_whenFailureDuringTransfer_fails() {

@Test
void transfer_whenExceptionDuringTransfer_fails() {
var dataSource = dataSource();
fakeSink.transferResultSupplier = () -> {
throw new RuntimeException(errorMessage);
};

assertThat(fakeSink.transfer(dataSource)).succeedsWithin(500, TimeUnit.MILLISECONDS)
var future = fakeSink.transfer(dataSource);

assertThat(future).succeedsWithin(timeout)
.satisfies(transferResult -> assertThat(transferResult.failed()).isTrue())
.satisfies(transferResult -> assertThat(transferResult.getFailure().getReason()).isEqualTo(StreamFailure.Reason.GENERAL_ERROR))
.satisfies(transferResult -> assertThat(transferResult.getFailureMessages())
.containsExactly("Unhandled exception raised when transferring data: java.lang.RuntimeException: " + errorMessage));
.satisfies(transferResult -> assertThat(transferResult.getFailureDetail())
.contains("Error processing data transfer request").contains(dataFlowRequestId).contains(errorMessage));

assertThat(fakeSink.parts).containsExactly(dataSource);
assertThat(fakeSink.complete).isEqualTo(0);
}

/**
 * Calls {@code transfer} with a source whose part stream never ends and checks that a future is returned.
 * The call returning at all is the relevant signal here: the assertion only verifies the future is not null.
 */
@Test
void shouldNotBlock_whenDataSourceIsIndefinite() {
    var endlessSource = mock(DataSource.class);
    var endlessParts = IntStream.iterate(0, index -> index + 1)
            .mapToObj(index -> mock(DataSource.Part.class));
    when(endlessSource.openPartStream()).thenReturn(StreamResult.success(endlessParts));

    var transferFuture = fakeSink.transfer(endlessSource);

    assertThat(transferFuture).isNotNull();
}

/**
 * Creates a new in-memory data source named "test-datasource-name" whose content is the bytes of "test-content".
 *
 * @return a fresh data source instance backed by its own input stream.
 */
private InputStreamDataSource dataSource() {
    var content = new ByteArrayInputStream("test-content".getBytes());
    return new InputStreamDataSource("test-datasource-name", content);
}

private static class FakeParallelSink extends ParallelSink {

List<DataSource.Part> parts;
Expand All @@ -134,5 +149,17 @@ protected StreamResult<Object> complete() {
complete++;
return completeResponse;
}

/**
 * Builder for {@link FakeParallelSink}, reusing the configuration methods of {@code ParallelSink.Builder}.
 */
public static class Builder extends ParallelSink.Builder<Builder, FakeParallelSink> {

    /** Starts from a new, unconfigured {@link FakeParallelSink}. */
    protected Builder() {
        super(new FakeParallelSink());
    }

    /** Performs no validation for the fake sink. */
    @Override
    protected void validate() {
        // intentionally left empty
    }
}
}
}
Loading