Skip to content

Commit c6d55e2

Browse files
concurrent-api: cleanup AsyncContext operations (#3181)
Motivation: Perhaps the main goal of the AsyncContext is to allow a way to bundle up thread-local (ish) state and restore it when an async operation occurs. However, restore operation is essentially inlined in all the context operations which makes the mechanism more difficult to understand than it needs to be. Modifications: Add a few new methods to the AsyncContextProvider - captureContext() the goal of this method is to package up context for propagation to the next continuation. This stands in contrast to the existing context() method which is used for reading the async ContextMap. - attachContext(ContextMap) this method is used to restore the local state for temporary use. It returns a new Scope type, which then reverts this local environment to what it was before the restore. Result: Cleaner async context operations.
1 parent a7e83b4 commit c6d55e2

32 files changed

+366
-586
lines changed
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
/*
2+
* Copyright © 2025 Apple Inc. and the ServiceTalk project authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package io.servicetalk.concurrent.api;
17+
18+
import io.servicetalk.context.api.ContextMap;
19+
20+
import org.openjdk.jmh.annotations.Benchmark;
21+
import org.openjdk.jmh.annotations.BenchmarkMode;
22+
import org.openjdk.jmh.annotations.Fork;
23+
import org.openjdk.jmh.annotations.Measurement;
24+
import org.openjdk.jmh.annotations.Mode;
25+
import org.openjdk.jmh.annotations.OutputTimeUnit;
26+
import org.openjdk.jmh.annotations.Scope;
27+
import org.openjdk.jmh.annotations.Setup;
28+
import org.openjdk.jmh.annotations.State;
29+
import org.openjdk.jmh.annotations.Warmup;
30+
31+
import java.util.concurrent.TimeUnit;
32+
import java.util.function.Function;
33+
34+
@Fork(1)
35+
@State(Scope.Benchmark)
36+
@Warmup(iterations = 5, time = 3)
37+
@Measurement(iterations = 5, time = 3)
38+
@OutputTimeUnit(TimeUnit.NANOSECONDS)
39+
@BenchmarkMode(Mode.AverageTime)
40+
public class AsyncContextProviderBenchmark {
41+
42+
/**
43+
* gc profiling of the DefaultAsyncContextProvider shows that the Scope based detachment can be stack allocated
44+
* at least under some conditions.
45+
*
46+
* Benchmark Mode Cnt Score Error Units
47+
* AsyncContextProviderBenchmark.contextRestoreCost avgt 5 3.932 ± 0.022 ns/op
48+
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
49+
* AsyncContextProviderBenchmark.contextRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁶ B/op
50+
* AsyncContextProviderBenchmark.contextRestoreCost:gc.count avgt 5 ≈ 0 counts
51+
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost avgt 5 1.712 ± 0.005 ns/op
52+
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate avgt 5 ≈ 10⁻⁴ MB/sec
53+
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.alloc.rate.norm avgt 5 ≈ 10⁻⁷ B/op
54+
* AsyncContextProviderBenchmark.contextSaveAndRestoreCost:gc.count avgt 5 ≈ 0 counts
55+
*/
56+
57+
private static final ContextMap.Key<String> KEY = ContextMap.Key.newKey("test-key", String.class);
58+
private static final String EXPECTED = "hello, world!";
59+
60+
private Function<String, String> wrappedFunction;
61+
62+
@Setup
63+
public void setup() {
64+
// This will capture the current context
65+
wrappedFunction = AsyncContext.wrapFunction(ignored -> AsyncContext.context().get(KEY));
66+
AsyncContext.context().put(KEY, EXPECTED);
67+
}
68+
69+
@Benchmark
70+
public String contextRestoreCost() {
71+
return wrappedFunction.apply("ignored");
72+
}
73+
74+
@Benchmark
75+
public String contextSaveAndRestoreCost() {
76+
return AsyncContext.wrapFunction(Function.<String>identity()).apply("ignored");
77+
}
78+
}

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContext.java

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ public final class AsyncContext {
5151
private static final int STATE_INIT = 0;
5252
private static final int STATE_AUTO_ENABLED = 1;
5353
private static final int STATE_ENABLED = 2;
54+
5455
/**
5556
* Note this mechanism is racy. Currently only the {@link #disable()} method is exposed publicly and
5657
* {@link #STATE_DISABLED} is a terminal state. Because we favor going to the disabled state we don't have to worry
@@ -438,7 +439,7 @@ public static ScheduledExecutorService wrapJdkScheduledExecutorService(final Sch
438439
*/
439440
public static Runnable wrapRunnable(final Runnable runnable) {
440441
AsyncContextProvider provider = provider();
441-
return provider.wrapRunnable(runnable, provider.context());
442+
return provider.wrapRunnable(runnable, provider.captureContext());
442443
}
443444

444445
/**
@@ -449,7 +450,7 @@ public static Runnable wrapRunnable(final Runnable runnable) {
449450
*/
450451
public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
451452
AsyncContextProvider provider = provider();
452-
return provider.wrapCallable(callable, provider.context());
453+
return provider.wrapCallable(callable, provider.captureContext());
453454
}
454455

455456
/**
@@ -460,7 +461,7 @@ public static <V> Callable<V> wrapCallable(final Callable<V> callable) {
460461
*/
461462
public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
462463
AsyncContextProvider provider = provider();
463-
return provider.wrapConsumer(consumer, provider.context());
464+
return provider.wrapConsumer(consumer, provider.captureContext());
464465
}
465466

466467
/**
@@ -472,7 +473,7 @@ public static <T> Consumer<T> wrapConsumer(final Consumer<T> consumer) {
472473
*/
473474
public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
474475
AsyncContextProvider provider = provider();
475-
return provider.wrapFunction(func, provider.context());
476+
return provider.wrapFunction(func, provider.captureContext());
476477
}
477478

478479
/**
@@ -484,7 +485,7 @@ public static <T, U> Function<T, U> wrapFunction(final Function<T, U> func) {
484485
*/
485486
public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consumer) {
486487
AsyncContextProvider provider = provider();
487-
return provider.wrapBiConsumer(consumer, provider.context());
488+
return provider.wrapBiConsumer(consumer, provider.captureContext());
488489
}
489490

490491
/**
@@ -497,7 +498,7 @@ public static <T, U> BiConsumer<T, U> wrapBiConsume(final BiConsumer<T, U> consu
497498
*/
498499
public static <T, U, V> BiFunction<T, U, V> wrapBiFunction(BiFunction<T, U, V> func) {
499500
AsyncContextProvider provider = provider();
500-
return provider.wrapBiFunction(func, provider.context());
501+
return provider.wrapBiFunction(func, provider.captureContext());
501502
}
502503

503504
/**

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContextMapThreadLocal.java

Lines changed: 0 additions & 44 deletions
This file was deleted.

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/AsyncContextProvider.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,30 @@ interface AsyncContextProvider {
4040
/**
4141
* Get the current context.
4242
*
43+
* Note that this method is for getting the {@link ContextMap} for use by the application code. For saving the
44+
* current state for crossing an async boundary see the {@link AsyncContextProvider#captureContext()} method.
45+
*
4346
* @return The current context.
4447
*/
4548
ContextMap context();
4649

50+
/**
51+
* Capture existing context in preparation for an asynchronous thread jump.
52+
*
53+
* Note that this can do more than just package up the ServiceTalk {@link AsyncContext} and could be enhanced or
54+
* wrapped to bundle up additional contexts such as the OpenTelemetry or grpc contexts.
55+
* @return the saved context state that may be restored later.
56+
*/
57+
ContextMap captureContext();
58+
59+
/**
60+
* Restore the previously saved {@link ContextMap} to the local state.
61+
* @param contextMap representing the state previously saved via {@link AsyncContextProvider#captureContext()} and
62+
* that is intended to be restored.
63+
* @return a {@link Scope} that must be closed at the end of the attachment.
64+
*/
65+
Scope attachContext(ContextMap contextMap);
66+
4767
/**
4868
* Wrap the {@link Cancellable} to ensure it is able to track {@link AsyncContext} correctly.
4969
*

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/Completable.java

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1730,7 +1730,7 @@ public final Future<Void> toFuture() {
17301730
*/
17311731
ContextMap contextForSubscribe(AsyncContextProvider provider) {
17321732
// the default behavior is to copy the map. Some operators may want to use shared map
1733-
return provider.context().copy();
1733+
return provider.captureContext().copy();
17341734
}
17351735

17361736
/**
@@ -1740,9 +1740,19 @@ ContextMap contextForSubscribe(AsyncContextProvider provider) {
17401740
* @param subscriber {@link Subscriber} to subscribe for the result.
17411741
*/
17421742
protected final void subscribeInternal(Subscriber subscriber) {
1743+
requireNonNull(subscriber);
17431744
AsyncContextProvider contextProvider = AsyncContext.provider();
17441745
ContextMap contextMap = contextForSubscribe(contextProvider);
1745-
subscribeWithContext(subscriber, contextProvider, contextMap);
1746+
Subscriber wrapped = contextProvider.wrapCancellable(subscriber, contextMap);
1747+
if (contextProvider.context() == contextMap) {
1748+
// No need to wrap as we are sharing the AsyncContext
1749+
handleSubscribe(wrapped, contextMap, contextProvider);
1750+
} else {
1751+
// Ensure that AsyncContext used for handleSubscribe() is the contextMap for the subscribe()
1752+
try (Scope unused = contextProvider.attachContext(contextMap)) {
1753+
handleSubscribe(wrapped, contextMap, contextProvider);
1754+
}
1755+
}
17461756
}
17471757

17481758
/**
@@ -2262,19 +2272,6 @@ final void delegateSubscribe(Subscriber subscriber,
22622272
handleSubscribe(subscriber, contextMap, contextProvider);
22632273
}
22642274

2265-
private void subscribeWithContext(Subscriber subscriber,
2266-
AsyncContextProvider contextProvider, ContextMap contextMap) {
2267-
requireNonNull(subscriber);
2268-
Subscriber wrapped = contextProvider.wrapCancellable(subscriber, contextMap);
2269-
if (contextProvider.context() == contextMap) {
2270-
// No need to wrap as we are sharing the AsyncContext
2271-
handleSubscribe(wrapped, contextMap, contextProvider);
2272-
} else {
2273-
// Ensure that AsyncContext used for handleSubscribe() is the contextMap for the subscribe()
2274-
contextProvider.wrapRunnable(() -> handleSubscribe(wrapped, contextMap, contextProvider), contextMap).run();
2275-
}
2276-
}
2277-
22782275
/**
22792276
* Override for {@link #handleSubscribe(CompletableSource.Subscriber)}.
22802277
* <p>

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/CompletableShareContextOnSubscribe.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ final class CompletableShareContextOnSubscribe extends AbstractNoHandleSubscribe
2626

2727
@Override
2828
ContextMap contextForSubscribe(AsyncContextProvider provider) {
29-
return provider.context();
29+
return provider.captureContext();
3030
}
3131

3232
@Override

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ContextAwareExecutorUtils.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
import java.util.List;
2323
import java.util.concurrent.Callable;
2424

25-
import static io.servicetalk.concurrent.api.DefaultAsyncContextProvider.INSTANCE;
26-
2725
final class ContextAwareExecutorUtils {
2826

2927
private ContextAwareExecutorUtils() {
@@ -32,7 +30,7 @@ private ContextAwareExecutorUtils() {
3230

3331
static <X> Collection<? extends Callable<X>> wrap(Collection<? extends Callable<X>> tasks) {
3432
List<Callable<X>> wrappedTasks = new ArrayList<>(tasks.size());
35-
ContextMap contextMap = INSTANCE.context();
33+
ContextMap contextMap = AsyncContext.provider().captureContext();
3634
for (Callable<X> task : tasks) {
3735
wrappedTasks.add(new ContextPreservingCallable<>(task, contextMap));
3836
}

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ContextPreservingBiConsumer.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
package io.servicetalk.concurrent.api;
1717

1818
import io.servicetalk.context.api.ContextMap;
19-
import io.servicetalk.context.api.ContextMapHolder;
2019

2120
import java.util.function.BiConsumer;
2221

23-
import static io.servicetalk.concurrent.api.AsyncContextMapThreadLocal.CONTEXT_THREAD_LOCAL;
2422
import static java.util.Objects.requireNonNull;
2523

2624
final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {
@@ -34,28 +32,8 @@ final class ContextPreservingBiConsumer<T, U> implements BiConsumer<T, U> {
3432

3533
@Override
3634
public void accept(T t, U u) {
37-
final Thread currentThread = Thread.currentThread();
38-
if (currentThread instanceof ContextMapHolder) {
39-
final ContextMapHolder asyncContextMapHolder = (ContextMapHolder) currentThread;
40-
ContextMap prev = asyncContextMapHolder.context();
41-
try {
42-
asyncContextMapHolder.context(saved);
43-
delegate.accept(t, u);
44-
} finally {
45-
asyncContextMapHolder.context(prev);
46-
}
47-
} else {
48-
slowPath(t, u);
49-
}
50-
}
51-
52-
private void slowPath(T t, U u) {
53-
ContextMap prev = CONTEXT_THREAD_LOCAL.get();
54-
try {
55-
CONTEXT_THREAD_LOCAL.set(saved);
35+
try (Scope ignored = AsyncContext.provider().attachContext(saved)) {
5636
delegate.accept(t, u);
57-
} finally {
58-
CONTEXT_THREAD_LOCAL.set(prev);
5937
}
6038
}
6139
}

servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ContextPreservingBiFunction.java

Lines changed: 1 addition & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,9 @@
1616
package io.servicetalk.concurrent.api;
1717

1818
import io.servicetalk.context.api.ContextMap;
19-
import io.servicetalk.context.api.ContextMapHolder;
2019

2120
import java.util.function.BiFunction;
2221

23-
import static io.servicetalk.concurrent.api.AsyncContextMapThreadLocal.CONTEXT_THREAD_LOCAL;
2422
import static java.util.Objects.requireNonNull;
2523

2624
final class ContextPreservingBiFunction<T, U, V> implements BiFunction<T, U, V> {
@@ -34,28 +32,8 @@ final class ContextPreservingBiFunction<T, U, V> implements BiFunction<T, U, V>
3432

3533
@Override
3634
public V apply(T t, U u) {
37-
final Thread currentThread = Thread.currentThread();
38-
if (currentThread instanceof ContextMapHolder) {
39-
final ContextMapHolder asyncContextMapHolder = (ContextMapHolder) currentThread;
40-
ContextMap prev = asyncContextMapHolder.context();
41-
try {
42-
asyncContextMapHolder.context(saved);
43-
return delegate.apply(t, u);
44-
} finally {
45-
asyncContextMapHolder.context(prev);
46-
}
47-
} else {
48-
return slowPath(t, u);
49-
}
50-
}
51-
52-
private V slowPath(T t, U u) {
53-
ContextMap prev = CONTEXT_THREAD_LOCAL.get();
54-
try {
55-
CONTEXT_THREAD_LOCAL.set(saved);
35+
try (Scope ignored = AsyncContext.provider().attachContext(saved)) {
5636
return delegate.apply(t, u);
57-
} finally {
58-
CONTEXT_THREAD_LOCAL.set(prev);
5937
}
6038
}
6139
}

0 commit comments

Comments
 (0)