Skip to content

Commit

Permalink
✨ Improve and simplify output writer thread handling
Browse files Browse the repository at this point in the history
Expose a straight PipeWriter over the websocket by directly adapting it to Stream so we can use the biult-in PipeWriter.Create over a websocket. Note that we depend on knowledge of how the default pipe writer does its writing, which is by just invoking `WriteAsync` and `FlushAsync` on it. So we leave all other `Stream` members unimplemented.

The `StreamPipeWriter` invokes `FlushAsync` automatically whenever there are bytes written, so it's impossible to distinguish writing partial messages from full messages. This seems like an acceptable trade-off anyway, so we just assume the same (`WriteAsync` == full EndOfMessage for WebSocket message).

Now the `RunAsync` task deals exclusively with the incoming side of the websocket, and it's much easier to reason about it (including whether it's worth creating another thread to keep it running or not).

Fixes #5
  • Loading branch information
kzu committed Sep 23, 2021
1 parent 8404e26 commit 3956e9e
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 82 deletions.
6 changes: 3 additions & 3 deletions src/Tests/SimpleWebSocketPipeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ public record SimpleWebSocketPipeTests(ITestOutputHelper Output)
public async Task WhenWebSocketNotOpen_ThenThrowsAsync()
{
IWebSocketPipe pipe = WebSocketPipe.Create(new ClientWebSocket());
await Assert.ThrowsAsync<InvalidOperationException>(() => pipe.RunAsync().AsTask());
await Assert.ThrowsAsync<InvalidOperationException>(() => pipe.RunAsync());
}

[Fact]
Expand All @@ -26,7 +26,7 @@ public async Task WhenConnected_ThenRuns()
using var pipe = WebSocketPipe.Create(socket);

await Task.WhenAll(
pipe.RunAsync(server.Cancellation.Token).AsTask(),
pipe.RunAsync(server.Cancellation.Token),
Task.Delay(100).ContinueWith(_ => server.Cancellation.Cancel()));
}

Expand All @@ -42,7 +42,7 @@ public async Task WhenServerClosesWebSocket_ThenClientCompletesGracefully()
await server.DisposeAsync();

Task.WaitAny(
run.AsTask(),
run,
Task.Delay(100).ContinueWith(_ => throw new TimeoutException()));
}

Expand Down
2 changes: 1 addition & 1 deletion src/Tests/WebSocketServer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ static WebSocketServer Create(Func<IWebSocketPipe, Task>? pipeBehavior, Func<Web
if (pipeBehavior != null)
{
using var pipe = WebSocketPipe.Create(websocket, options);
await Task.WhenAll(pipeBehavior(pipe), pipe.RunAsync(cts.Token).AsTask());
await Task.WhenAll(pipe.RunAsync(cts.Token), pipeBehavior(pipe));
}
else if (socketBehavior != null)
{
Expand Down
4 changes: 2 additions & 2 deletions src/WebSocketPipe/IWebSocketPipe.cs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public interface IWebSocketPipe : IDuplexPipe, IDisposable
/// <param name="closeStatusDescription">Optional close status description to use if the underlying
/// <see cref="WebSocket"/> is closed.</param>
/// <returns></returns>
public ValueTask CompleteAsync(WebSocketCloseStatus? closeStatus = null, string? closeStatusDescription = null);
public Task CompleteAsync(WebSocketCloseStatus? closeStatus = null, string? closeStatusDescription = null);

/// <summary>
/// Starts populating the <see cref="IDuplexPipe.Input"/> with incoming data from the underlying
Expand All @@ -54,5 +54,5 @@ public interface IWebSocketPipe : IDuplexPipe, IDisposable
/// <see cref="IDuplexPipe.Output"/> are completed, or an explicit invocation of <see cref="CompleteAsync"/>
/// is executed.
/// </returns>
public ValueTask RunAsync(CancellationToken cancellation = default);
public Task RunAsync(CancellationToken cancellation = default);
}
109 changes: 40 additions & 69 deletions src/WebSocketPipe/SimpleWebSocketPipe.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Net.WebSockets;
using System.Threading;
Expand All @@ -15,23 +16,28 @@ class SimpleWebSocketPipe : IWebSocketPipe
// Wait 250 ms before giving up on a Close, same as SignalR WebSocketHandler
static readonly TimeSpan closeTimeout = TimeSpan.FromMilliseconds(250);

readonly CancellationTokenSource disposeCancellation = new CancellationTokenSource();
readonly Pipe inputPipe;
readonly Pipe outputPipe;
readonly PipeWriter outputWriter;

readonly WebSocket webSocket;
readonly WebSocketPipeOptions options;

bool completed;

public SimpleWebSocketPipe(WebSocket webSocket, WebSocketPipeOptions options)
=> (this.webSocket, this.options, inputPipe, outputPipe)
= (webSocket, options, new Pipe(options.InputPipeOptions), new Pipe(options.OutputPipeOptions));
{
this.webSocket = webSocket;
this.options = options;
inputPipe = new Pipe(options.InputPipeOptions);
outputWriter = PipeWriter.Create(new WebSocketStream(webSocket));
}

bool IsClient => webSocket is ClientWebSocket;

public PipeReader Input => inputPipe.Reader;

public PipeWriter Output => outputPipe.Writer;
public PipeWriter Output => outputWriter;

public WebSocketCloseStatus? CloseStatus => webSocket.CloseStatus;

Expand All @@ -41,23 +47,16 @@ public SimpleWebSocketPipe(WebSocket webSocket, WebSocketPipeOptions options)

public string? SubProtocol => webSocket.SubProtocol;

public async ValueTask RunAsync(CancellationToken cancellation = default)
public Task RunAsync(CancellationToken cancellation = default)
{
if (webSocket.State != WebSocketState.Open)
throw new InvalidOperationException($"WebSocket must be opened. State was {webSocket.State}");

var writing = FillInputAsync(cancellation);
var reading = SendOutputAsync(cancellation);

// NOTE: when both are completed, the CompleteAsync will be called automatically
// by both writing and reading, so we ensure CloseWhenCompleted is performed.

// TODO: replace with ValueTask.WhenAll if/when it ships.
// See https://github.com/dotnet/runtime/issues/23625
await Task.WhenAll(reading.AsTask(), writing.AsTask());
var combined = CancellationTokenSource.CreateLinkedTokenSource(cancellation, disposeCancellation.Token);
return ReadInputAsync(combined.Token);
}

public async ValueTask CompleteAsync(WebSocketCloseStatus? closeStatus = null, string? closeStatusDescription = null)
public async Task CompleteAsync(WebSocketCloseStatus? closeStatus = null, string? closeStatusDescription = null)
{
if (completed)
return;
Expand All @@ -68,14 +67,11 @@ public async ValueTask CompleteAsync(WebSocketCloseStatus? closeStatus = null, s
await inputPipe.Writer.CompleteAsync();
await inputPipe.Reader.CompleteAsync();

await outputPipe.Writer.CompleteAsync();
await outputPipe.Reader.CompleteAsync();

if (options.CloseWhenCompleted || closeStatus != null)
await CloseAsync(closeStatus ?? WebSocketCloseStatus.NormalClosure, closeStatusDescription ?? "");
}

async ValueTask CloseAsync(WebSocketCloseStatus closeStatus, string closeStatusDescription)
async Task CloseAsync(WebSocketCloseStatus closeStatus, string closeStatusDescription)
{
var state = State;
if (state == WebSocketState.Closed || state == WebSocketState.CloseSent || state == WebSocketState.Aborted)
Expand All @@ -90,7 +86,7 @@ async ValueTask CloseAsync(WebSocketCloseStatus closeStatus, string closeStatusD
await Task.WhenAny(closeTask, Task.Delay(closeTimeout));
}

async ValueTask FillInputAsync(CancellationToken cancellation)
async Task ReadInputAsync(CancellationToken cancellation)
{
while (webSocket.State == WebSocketState.Open && !cancellation.IsCancellationRequested)
{
Expand Down Expand Up @@ -129,56 +125,31 @@ ex is WebSocketException ||
await CompleteAsync(webSocket.CloseStatus, webSocket.CloseStatusDescription);
}

async ValueTask SendOutputAsync(CancellationToken cancellation)
public void Dispose()
{
while (webSocket.State == WebSocketState.Open && !cancellation.IsCancellationRequested)
{
try
{
var result = await outputPipe.Reader.ReadAsync(cancellation);
if (result.IsCompleted || result.IsCanceled)
break;

if (result.Buffer.IsSingleSegment)
{
await webSocket.SendAsync(result.Buffer.First, WebSocketMessageType.Binary, true, cancellation);
}
else
{
var enumerator = result.Buffer.GetEnumerator();
if (enumerator.MoveNext())
{
// NOTE: we don't use the cancellation here because we don't want to send
// partial messages from an already completely read buffer.
while (true)
{
var current = enumerator.Current;
if (default(ReadOnlyMemory<byte>).Equals(current))
break;

// Peek next to see if we should send an end of message
if (enumerator.MoveNext())
await webSocket.SendAsync(current, WebSocketMessageType.Binary, false, cancellation);
else
await webSocket.SendAsync(current, WebSocketMessageType.Binary, true, cancellation);
}
}
}

outputPipe.Reader.AdvanceTo(result.Buffer.End);

}
catch (Exception ex) when (ex is OperationCanceledException ||
ex is WebSocketException ||
ex is InvalidOperationException)
{
break;
}
}

// Preserve the close status since it might be triggered by a received Close message containing the status and description.
await CompleteAsync(webSocket.CloseStatus, webSocket.CloseStatusDescription);
disposeCancellation.Cancel();
webSocket.Dispose();
}

public void Dispose() => webSocket.Dispose();
class WebSocketStream : Stream
{
readonly WebSocket webSocket;

public WebSocketStream(WebSocket webSocket) => this.webSocket = webSocket;

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
=> webSocket.SendAsync(buffer, WebSocketMessageType.Binary, true, cancellationToken);

public override Task FlushAsync(CancellationToken cancellationToken) => Task.CompletedTask;
public override bool CanRead => throw new NotImplementedException();
public override bool CanSeek => throw new NotImplementedException();
public override bool CanWrite => throw new NotImplementedException();
public override long Length => throw new NotImplementedException();
public override long Position { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
public override void Flush() => throw new NotImplementedException();
public override int Read(byte[] buffer, int offset, int count) => throw new NotImplementedException();
public override long Seek(long offset, SeekOrigin origin) => throw new NotImplementedException();
public override void SetLength(long value) => throw new NotImplementedException();
public override void Write(byte[] buffer, int offset, int count) => throw new NotImplementedException();
}
}
8 changes: 1 addition & 7 deletions src/WebSocketPipe/WebSocketPipeOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,5 @@ public class WebSocketPipeOptions
/// Allows fine-grained configuration options for the incoming side of the
/// websocket pipe. Defaults to <see cref="PipeOptions.Default"/>.
/// </summary>
public PipeOptions InputPipeOptions { get; set; } = PipeOptions.Default;

/// <summary>
/// Allows fine-grained configuration options for the outgoing side of the
/// websocket pipe. Defaults to <see cref="PipeOptions.Default"/>.
/// </summary>
public PipeOptions OutputPipeOptions { get; set; } = PipeOptions.Default;
public PipeOptions InputPipeOptions { get; set; } = new PipeOptions(useSynchronizationContext: false);
}

0 comments on commit 3956e9e

Please sign in to comment.