Skip to content

Commit 264932e

Browse files
authored
Merge pull request #20 from fretje/parallel
Add "UseTaskWhenAll" option
2 parents dca93c5 + a7b31e8 commit 264932e

File tree

4 files changed

+99
-2
lines changed

4 files changed

+99
-2
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Added documentation comments to `ICourier`, `CourierInjector` and `CourierOptions`
99
- Use Invoke iso DynamicInvoke for better performance
1010
- Fix bug where RemoveWeak would sometimes remove too many handlers
11+
- Add `UseTaskWhenAll` option for parallel notification handling
1112

1213
# 6.0.0
1314
- Update MediatR to version 12

MediatR.Courier.Tests/CourierTests.cs

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,68 @@ async Task NotificationAction(TestNotification _, CancellationToken cancellation
119119
Assert.Equal(1, receivedMessageCount);
120120
}
121121

122+
[Fact]
123+
public async Task Handlers_RunInParallel_WhenUseTaskWhenAllIsTrue()
124+
{
125+
var options = new CourierOptions { UseTaskWhenAll = true };
126+
var mediatRCourier = new MediatRCourier(options);
127+
128+
var delays = new[] { 100, 200, 300 };
129+
var started = new bool[delays.Length];
130+
var completed = new bool[delays.Length];
131+
132+
for (int i = 0; i < delays.Length; i++)
133+
{
134+
int idx = i;
135+
mediatRCourier.Subscribe<TestNotification>(async (_, ct) =>
136+
{
137+
started[idx] = true;
138+
await Task.Delay(delays[idx], ct);
139+
completed[idx] = true;
140+
});
141+
}
142+
143+
var sw = System.Diagnostics.Stopwatch.StartNew();
144+
await mediatRCourier.Handle(new TestNotification(), CancellationToken.None);
145+
sw.Stop();
146+
147+
Assert.All(started, Assert.True);
148+
Assert.All(completed, Assert.True);
149+
// Should complete in just over the max delay (parallel)
150+
Assert.InRange(sw.ElapsedMilliseconds, delays.Max(), delays.Max() + 150);
151+
}
152+
153+
[Fact]
154+
public async Task Handlers_RunSequentially_WhenUseTaskWhenAllIsFalse()
155+
{
156+
var options = new CourierOptions { CaptureThreadContext = false, UseTaskWhenAll = false };
157+
var mediatRCourier = new MediatRCourier(options);
158+
159+
var delays = new[] { 100, 200, 300 };
160+
var started = new bool[delays.Length];
161+
var completed = new bool[delays.Length];
162+
163+
for (int i = 0; i < delays.Length; i++)
164+
{
165+
int idx = i;
166+
mediatRCourier.Subscribe<TestNotification>(async (_, ct) =>
167+
{
168+
started[idx] = true;
169+
await Task.Delay(delays[idx], ct);
170+
completed[idx] = true;
171+
});
172+
}
173+
174+
var sw = System.Diagnostics.Stopwatch.StartNew();
175+
await mediatRCourier.Handle(new TestNotification(), CancellationToken.None);
176+
sw.Stop();
177+
178+
Assert.All(started, s => Assert.True(s));
179+
Assert.All(completed, c => Assert.True(c));
180+
// Should complete in just over the sum of delays (sequential)
181+
Assert.InRange(sw.ElapsedMilliseconds, delays.Sum(), delays.Sum() + 150);
182+
}
183+
122184
private sealed class AsyncTestData : IEnumerable<object[]>
123185
{
124186
public IEnumerator<object[]> GetEnumerator()

MediatR.Courier/CourierOptions.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,12 @@ public sealed class CourierOptions
1111
/// Defaults to false.
1212
/// </summary>
1313
public bool CaptureThreadContext { get; set; }
14+
15+
/// <summary>
16+
/// Gets or sets a value indicating whether to use Task.WhenAll to await handler tasks concurrently.
17+
/// When set to true, all notification handlers are collected and awaited using Task.WhenAll.
18+
/// When set to false, notification handlers are awaited sequentially.
19+
/// Defaults to false.
20+
/// </summary>
21+
public bool UseTaskWhenAll { get; set; }
1422
}

MediatR.Courier/MediatRCourier.cs

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ async Task HandleLocal(INotification n, CancellationToken c)
2222
if (!_weakActions.TryGetValue(notificationType, out var weakSubscribers)) weakSubscribers = new();
2323

2424
var remainingSubscribers = new ConcurrentBag<(WeakReference<object> target, MethodInfo methodInfo, bool needsToken)>();
25+
var tasks = new List<Task>();
2526

2627
foreach (var (target, methodInfo, needsToken) in weakSubscribers)
2728
{
@@ -32,7 +33,17 @@ async Task HandleLocal(INotification n, CancellationToken c)
3233
: new object[] { n };
3334

3435
var result = methodInfo.Invoke(handler, parameters);
35-
if (result is Task task) await task.ConfigureAwait(_options.CaptureThreadContext);
36+
if (result is Task task)
37+
{
38+
if (_options.UseTaskWhenAll)
39+
{
40+
tasks.Add(task);
41+
}
42+
else
43+
{
44+
await task.ConfigureAwait(_options.CaptureThreadContext);
45+
}
46+
}
3647

3748
remainingSubscribers.Add((target, methodInfo, needsToken));
3849
}
@@ -48,7 +59,22 @@ async Task HandleLocal(INotification n, CancellationToken c)
4859
: new object[] { n };
4960

5061
var result = action.Method.Invoke(action.Target, parameters);
51-
if (result is Task task) await task.ConfigureAwait(_options.CaptureThreadContext);
62+
if (result is Task task)
63+
{
64+
if (_options.UseTaskWhenAll)
65+
{
66+
tasks.Add(task);
67+
}
68+
else
69+
{
70+
await task.ConfigureAwait(_options.CaptureThreadContext);
71+
}
72+
}
73+
}
74+
75+
if (_options.UseTaskWhenAll && tasks.Count > 0)
76+
{
77+
await Task.WhenAll(tasks).ConfigureAwait(_options.CaptureThreadContext);
5278
}
5379
}
5480

0 commit comments

Comments
 (0)