Skip to content

Commit

Permalink
New event system works
Browse files Browse the repository at this point in the history
  • Loading branch information
OoLunar committed Dec 12, 2024
1 parent 5f44d14 commit 9179ce5
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 87 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -483,4 +483,5 @@ $RECYCLE.BIN/
# Vim temporary swap files
*.swp

logs/
logs/
header.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using Moonlight.Api.Net;
using Moonlight.Protocol.Net;

namespace Moonlight.Api.Events.EventArgs
{
public sealed class PacketReceivedAsyncServerEventArgs : AsyncServerEventArgs
{
public PacketReader Reader { get; init; }
public IPacket Packet { get; init; }

public PacketReceivedAsyncServerEventArgs(IPacket packet, PacketReader reader)
{
Packet = packet;
Reader = reader;
}
}
}
73 changes: 23 additions & 50 deletions src/Moonlight.Api/Net/PacketReader.cs
Original file line number Diff line number Diff line change
@@ -1,68 +1,30 @@
using System;
using System.Buffers;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Moonlight.Protocol.Net;
using Moonlight.Protocol.VariableTypes;

namespace Moonlight.Api.Net
{
public sealed class PacketReader
public sealed class PacketReader : IDisposable
{
private delegate IPacket DeserializeDelegate(ref SequenceReader<byte> reader);
private static readonly FrozenDictionary<int, DeserializeDelegate> _packetDeserializers;

public readonly PipeReader _pipeReader;
private readonly PacketReaderFactory _factory;
private readonly Stream _stream;
private readonly ILogger<PacketReader> _logger;
private readonly PipeReader _pipeReader;
private object? _disposed;

static PacketReader()
{
Dictionary<int, DeserializeDelegate> packetDeserializers = [];

// Iterate through the assembly and find all classes that implement IServerPacket<>
foreach (Type type in typeof(IServerPacket<>).Assembly.GetExportedTypes())
{
// Ensure we grab a fully implemented packet, not IServerPacket<> or an abstract class that implements it
if (!type.IsClass || type.IsAbstract)
{
continue;
}

// Grab the generic argument of IServerPacket<>
Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0];
if (packetType is null)
{
continue;
}

// Grab the static Id property from Packet<>
VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!;

// Grab the deserialize method
MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader<byte>).MakeByRefType())])!;

// Convert the method into a delegate
DeserializeDelegate deserializeDelegate = (DeserializeDelegate)Delegate.CreateDelegate(typeof(DeserializeDelegate), null, deserializeMethod);

// Now we store the pointer in a dictionary for later use
packetDeserializers[packetId.Value] = deserializeDelegate;
}

_packetDeserializers = packetDeserializers.ToFrozenDictionary();
}

public PacketReader(Stream stream, ILogger<PacketReader>? logger = null)
public PacketReader(PacketReaderFactory factory, Stream stream, ILogger<PacketReader> logger)
{
_pipeReader = PipeReader.Create(stream);
_logger = logger ?? NullLogger<PacketReader>.Instance;
_factory = factory;
_stream = stream;
_pipeReader = PipeReader.Create(_stream);
_logger = logger;
}

public async ValueTask<T> ReadPacketAsync<T>(CancellationToken cancellationToken = default) where T : IPacket<T>
Expand Down Expand Up @@ -139,10 +101,10 @@ public async ValueTask<IPacket> ReadPacketAsync(CancellationToken cancellationTo
}

VarInt packetId = VarInt.Deserialize(ref reader);
if (!_packetDeserializers.TryGetValue(packetId.Value, out DeserializeDelegate? packetDeserializerPointer))
if (!_factory.PreparedPacketDeserializers.TryGetValue(packetId.Value, out DeserializerDelegate? packetDeserializerPointer))
{
// Grab the unknown packet deserializer
packetDeserializerPointer = _packetDeserializers[-1];
packetDeserializerPointer = _factory.PreparedPacketDeserializers[-1];

// Rewind so the unknown packet can store the received packet ID.
reader.Rewind(packetId.Length);
Expand All @@ -153,5 +115,16 @@ public async ValueTask<IPacket> ReadPacketAsync(CancellationToken cancellationTo
position = reader.Position;
return packet;
}

public void Dispose()
{
if (_disposed is not null)
{
return;
}

_disposed = new object();
GC.SuppressFinalize(this);
}
}
}
142 changes: 142 additions & 0 deletions src/Moonlight.Api/Net/PacketReaderFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
using System;
using System.Buffers;
using System.Collections.Frozen;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.Logging;
using Moonlight.Protocol.Net;
using Moonlight.Protocol.VariableTypes;

namespace Moonlight.Api.Net
{
public delegate IPacket DeserializerDelegate(ref SequenceReader<byte> reader);

public sealed class PacketReaderFactory
{
public Dictionary<int, DeserializerDelegate> PacketDeserializers { get; init; } = [];
public FrozenDictionary<int, DeserializerDelegate> PreparedPacketDeserializers { get; private set; } = FrozenDictionary<int, DeserializerDelegate>.Empty;

private readonly ILoggerFactory _loggerFactory;

public PacketReaderFactory(ILoggerFactory loggerFactory) => _loggerFactory = loggerFactory;

public void Prepare() => PreparedPacketDeserializers = PacketDeserializers.ToFrozenDictionary();

public PacketReader CreatePacketReader(Stream stream) => new(this, stream, _loggerFactory.CreateLogger<PacketReader>());

public void AddPacketDeserializer<T>(T serverPacket) where T : IServerPacket<T> =>
PacketDeserializers[T.Id] = (DeserializerDelegate)Delegate.CreateDelegate(typeof(T), serverPacket, ((Delegate)T.Deserialize).Method);

public void AddPacketDeserializer<T>() where T : IServerPacket<T> => PacketDeserializers[T.Id] = Unsafe.As<DeserializerDelegate>((Delegate)T.Deserialize);
public void AddPacketDeserializer(Type type)
{
if (type.IsAbstract)
{
throw new InvalidOperationException("Cannot use an abstract class as a packet deserializer.");
}
// See if the type implements IServerPacket<>
else if (type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0] is null)
{
throw new InvalidOperationException("Cannot use a class that does not implement IServerPacket<> as a packet deserializer.");
}
// Grab the IPacket<>.Id property
else if (type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null) is not VarInt packetId)
{
throw new InvalidOperationException("Cannot use a class that does not have a static Id property as a packet deserializer.");
}
else
{
// Grab the Deserialize method
MethodInfo deserializeMethod = type.GetMethod("Deserialize", BindingFlags.Public | BindingFlags.Static, null, [(typeof(SequenceReader<byte>).MakeByRefType())], null)!;

// Convert the method into a delegate
DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), null, deserializeMethod);

// Store the delegate in the dictionary
PacketDeserializers[packetId] = deserializeDelegate;
}
}

public void AddPacketDeserializer(int packetId, Type type)
{
if (type.IsAbstract)
{
throw new InvalidOperationException("Cannot use an abstract class as a packet deserializer.");
}

Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0];
if (packetType is null)
{
return;
}

MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader<byte>).MakeByRefType())])!;
DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), null, deserializeMethod);
PacketDeserializers[packetId] = deserializeDelegate;
}

public void AddDefaultPacketDeserializers() => AddPacketDeserializers(typeof(IServerPacket<>).Assembly.GetExportedTypes());

public void AddPacketDeserializers(IEnumerable<Type> types)
{
ILogger<PacketReaderFactory> logger = _loggerFactory.CreateLogger<PacketReaderFactory>();

// Iterate through the assembly and find all classes that implement IServerPacket<>
foreach (Type type in types)
{
// Ensure we grab a fully implemented packet, not IServerPacket<> or an abstract class that implements it
if (type.IsAbstract)
{
continue;
}

// Grab the generic argument of IServerPacket<>
Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0];
if (packetType is null)
{
continue;
}

// Grab the static Id property from Packet<>
VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!;

// Grab the deserialize method
MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader<byte>).MakeByRefType())]) ?? throw new InvalidOperationException($"Could not find the method 'Deserialize' in '{type.Name}'.");

// Convert the method into a delegate
DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), deserializeMethod);

// Now we store the pointer in a dictionary for later use
if (PacketDeserializers.TryGetValue(packetId.Value, out DeserializerDelegate? existingDelegate))
{
logger.LogWarning("Failed to add packet deserializer for packet {PacketId} ({PacketType}), a deserializer for this packet already exists: {ExistingDelegate}", packetId, packetType, existingDelegate);
continue;
}

PacketDeserializers[packetId.Value] = deserializeDelegate;
}
}

public void RemovePacketDeserializer(int packetId) => PacketDeserializers.Remove(packetId);
public void RemovePacketDeserializer<T>() where T : IServerPacket<T> => PacketDeserializers.Remove(T.Id);
public void RemovePacketDeserializer(Type type)
{
if (type.IsAbstract)
{
return;
}

Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0];
if (packetType is null)
{
return;
}

VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!;
PacketDeserializers.Remove(packetId);
}
}
}
72 changes: 36 additions & 36 deletions src/Moonlight.Api/Server.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
using System;
using System.IO;
using System.Net;
using System.Net.Sockets;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;
using Moonlight.Api.Events;
using Moonlight.Api.Events.EventArgs;
using Moonlight.Api.Net;
using Moonlight.Protocol.Net;

[assembly: InternalsVisibleTo("Moonlight")]
namespace Moonlight.Api
{
public sealed class Server
Expand All @@ -18,58 +17,59 @@ public sealed class Server
public CancellationToken CancellationToken => _cancellationTokenSource.Token;

private readonly ILogger<Server> _logger;
private readonly ILoggerFactory _loggerProvider;
private readonly PacketReaderFactory _packetReaderFactory;
private readonly CancellationTokenSource _cancellationTokenSource = new();
private readonly AsyncServerEvent<PacketReceivedAsyncServerEventArgs> _packetReceivedServerEvent;

public Server(ServerConfiguration configuration, ILoggerFactory? logger = null)
public Server(ServerConfiguration serverConfiguration, PacketReaderFactory packetReaderFactory, ILogger<Server> logger, AsyncServerEvent<PacketReceivedAsyncServerEventArgs> packetReceivedServerEvent)
{
Configuration = configuration;
_loggerProvider = logger ?? NullLoggerFactory.Instance;
_logger = _loggerProvider.CreateLogger<Server>();
Configuration = serverConfiguration;
_packetReaderFactory = packetReaderFactory;
_logger = logger;
_packetReceivedServerEvent = packetReceivedServerEvent;
}

public async Task StartAsync()
{
_packetReaderFactory.Prepare();

_logger.LogInformation("Starting server...");
TcpListener listener = new(IPAddress.Parse(Configuration.Host), Configuration.Port);
listener.Start();

_logger.LogInformation("Server started on {EndPoint}", listener.LocalEndpoint);
while (!_cancellationTokenSource.IsCancellationRequested)
{
TcpClient client = await listener.AcceptTcpClientAsync(CancellationToken);
_logger.LogInformation("Client connected: {EndPoint}", client.Client.RemoteEndPoint);
_ = HandleClientAsync(await listener.AcceptTcpClientAsync(CancellationToken));
}
}

// Try to read the handshake packet
PacketReader reader = new(client.GetStream(), _loggerProvider.CreateLogger<PacketReader>());
HandshakePacket handshake;
private async Task HandleClientAsync(TcpClient client)
{
_logger.LogInformation("Client connected: {EndPoint}", client.Client.RemoteEndPoint);

try
{
handshake = await reader.ReadPacketAsync<HandshakePacket>(CancellationToken);
_logger.LogInformation("Handshake received: {Handshake}", handshake);
}
catch (InvalidOperationException error)
using NetworkStream stream = client.GetStream();
PacketReader reader = _packetReaderFactory.CreatePacketReader(stream);
try
{
HandshakePacket handshake = await reader.ReadPacketAsync<HandshakePacket>(CancellationToken);
if (await _packetReceivedServerEvent.InvokePreHandlersAsync(new PacketReceivedAsyncServerEventArgs(handshake, reader)))
{
_logger.LogError(error, "Failed to read handshake packet.");
continue;
await _packetReceivedServerEvent.InvokePostHandlersAsync(new PacketReceivedAsyncServerEventArgs(handshake, reader));
}
}
catch (InvalidDataException error)
{
_logger.LogError(error, "Error handling client: {EndPoint}", client.Client.RemoteEndPoint);
return;
}

while (!_cancellationTokenSource.IsCancellationRequested)
while (!CancellationToken.IsCancellationRequested)
{
IPacket packet = await reader.ReadPacketAsync(CancellationToken);
if (await _packetReceivedServerEvent.InvokePreHandlersAsync(new PacketReceivedAsyncServerEventArgs(packet, reader)))
{
// Try to read the next packet
IPacket packet;

try
{
packet = await reader.ReadPacketAsync(CancellationToken);
_logger.LogInformation("Packet received: {Packet}", packet);
}
catch (InvalidOperationException error)
{
_logger.LogError(error, "Failed to read packet.");
break;
}
await _packetReceivedServerEvent.InvokePostHandlersAsync(new PacketReceivedAsyncServerEventArgs(packet, reader));
}
}
}
Expand Down
1 change: 1 addition & 0 deletions src/Moonlight.Protocol/VariableTypes/VarInt.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,6 @@ public static VarInt Deserialize(ref SequenceReader<byte> reader)

public static implicit operator VarInt(int value) => new(value);
public static implicit operator int(VarInt value) => value.Value;
public override string ToString() => Value.ToString();
}
}
Loading

0 comments on commit 9179ce5

Please sign in to comment.