diff --git a/.gitignore b/.gitignore index b844d45..b8a2832 100644 --- a/.gitignore +++ b/.gitignore @@ -483,4 +483,5 @@ $RECYCLE.BIN/ # Vim temporary swap files *.swp -logs/ \ No newline at end of file +logs/ +header.md diff --git a/src/Moonlight.Api/Events/EventArgs/PacketReceivedAsyncServerEventArgs.cs b/src/Moonlight.Api/Events/EventArgs/PacketReceivedAsyncServerEventArgs.cs new file mode 100644 index 0000000..b2b9d7e --- /dev/null +++ b/src/Moonlight.Api/Events/EventArgs/PacketReceivedAsyncServerEventArgs.cs @@ -0,0 +1,17 @@ +using Moonlight.Api.Net; +using Moonlight.Protocol.Net; + +namespace Moonlight.Api.Events.EventArgs +{ + public sealed class PacketReceivedAsyncServerEventArgs : AsyncServerEventArgs + { + public PacketReader Reader { get; init; } + public IPacket Packet { get; init; } + + public PacketReceivedAsyncServerEventArgs(IPacket packet, PacketReader reader) + { + Packet = packet; + Reader = reader; + } + } +} diff --git a/src/Moonlight.Api/Net/PacketReader.cs b/src/Moonlight.Api/Net/PacketReader.cs index 5a4e2c5..63b895a 100644 --- a/src/Moonlight.Api/Net/PacketReader.cs +++ b/src/Moonlight.Api/Net/PacketReader.cs @@ -1,68 +1,30 @@ using System; using System.Buffers; -using System.Collections.Frozen; -using System.Collections.Generic; using System.IO; using System.IO.Pipelines; using System.Linq; -using System.Reflection; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; using Moonlight.Protocol.Net; using Moonlight.Protocol.VariableTypes; namespace Moonlight.Api.Net { - public sealed class PacketReader + public sealed class PacketReader : IDisposable { - private delegate IPacket DeserializeDelegate(ref SequenceReader reader); - private static readonly FrozenDictionary _packetDeserializers; - - public readonly PipeReader _pipeReader; + private readonly PacketReaderFactory _factory; + private readonly Stream _stream; private readonly ILogger _logger; + private readonly PipeReader _pipeReader; + private object? _disposed; - static PacketReader() - { - Dictionary packetDeserializers = []; - - // Iterate through the assembly and find all classes that implement IServerPacket<> - foreach (Type type in typeof(IServerPacket<>).Assembly.GetExportedTypes()) - { - // Ensure we grab a fully implemented packet, not IServerPacket<> or an abstract class that implements it - if (!type.IsClass || type.IsAbstract) - { - continue; - } - - // Grab the generic argument of IServerPacket<> - Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0]; - if (packetType is null) - { - continue; - } - - // Grab the static Id property from Packet<> - VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!; - - // Grab the deserialize method - MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader).MakeByRefType())])!; - - // Convert the method into a delegate - DeserializeDelegate deserializeDelegate = (DeserializeDelegate)Delegate.CreateDelegate(typeof(DeserializeDelegate), null, deserializeMethod); - - // Now we store the pointer in a dictionary for later use - packetDeserializers[packetId.Value] = deserializeDelegate; - } - - _packetDeserializers = packetDeserializers.ToFrozenDictionary(); - } - - public PacketReader(Stream stream, ILogger? logger = null) + public PacketReader(PacketReaderFactory factory, Stream stream, ILogger logger) { - _pipeReader = PipeReader.Create(stream); - _logger = logger ?? NullLogger.Instance; + _factory = factory; + _stream = stream; + _pipeReader = PipeReader.Create(_stream); + _logger = logger; } public async ValueTask ReadPacketAsync(CancellationToken cancellationToken = default) where T : IPacket @@ -139,10 +101,10 @@ public async ValueTask ReadPacketAsync(CancellationToken cancellationTo } VarInt packetId = VarInt.Deserialize(ref reader); - if (!_packetDeserializers.TryGetValue(packetId.Value, out DeserializeDelegate? packetDeserializerPointer)) + if (!_factory.PreparedPacketDeserializers.TryGetValue(packetId.Value, out DeserializerDelegate? packetDeserializerPointer)) { // Grab the unknown packet deserializer - packetDeserializerPointer = _packetDeserializers[-1]; + packetDeserializerPointer = _factory.PreparedPacketDeserializers[-1]; // Rewind so the unknown packet can store the received packet ID. reader.Rewind(packetId.Length); @@ -153,5 +115,16 @@ public async ValueTask ReadPacketAsync(CancellationToken cancellationTo position = reader.Position; return packet; } + + public void Dispose() + { + if (_disposed is not null) + { + return; + } + + _disposed = new object(); + GC.SuppressFinalize(this); + } } } diff --git a/src/Moonlight.Api/Net/PacketReaderFactory.cs b/src/Moonlight.Api/Net/PacketReaderFactory.cs new file mode 100644 index 0000000..b24cb12 --- /dev/null +++ b/src/Moonlight.Api/Net/PacketReaderFactory.cs @@ -0,0 +1,142 @@ +using System; +using System.Buffers; +using System.Collections.Frozen; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Reflection; +using System.Runtime.CompilerServices; +using Microsoft.Extensions.Logging; +using Moonlight.Protocol.Net; +using Moonlight.Protocol.VariableTypes; + +namespace Moonlight.Api.Net +{ + public delegate IPacket DeserializerDelegate(ref SequenceReader reader); + + public sealed class PacketReaderFactory + { + public Dictionary PacketDeserializers { get; init; } = []; + public FrozenDictionary PreparedPacketDeserializers { get; private set; } = FrozenDictionary.Empty; + + private readonly ILoggerFactory _loggerFactory; + + public PacketReaderFactory(ILoggerFactory loggerFactory) => _loggerFactory = loggerFactory; + + public void Prepare() => PreparedPacketDeserializers = PacketDeserializers.ToFrozenDictionary(); + + public PacketReader CreatePacketReader(Stream stream) => new(this, stream, _loggerFactory.CreateLogger()); + + public void AddPacketDeserializer(T serverPacket) where T : IServerPacket => + PacketDeserializers[T.Id] = (DeserializerDelegate)Delegate.CreateDelegate(typeof(T), serverPacket, ((Delegate)T.Deserialize).Method); + + public void AddPacketDeserializer() where T : IServerPacket => PacketDeserializers[T.Id] = Unsafe.As((Delegate)T.Deserialize); + public void AddPacketDeserializer(Type type) + { + if (type.IsAbstract) + { + throw new InvalidOperationException("Cannot use an abstract class as a packet deserializer."); + } + // See if the type implements IServerPacket<> + else if (type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0] is null) + { + throw new InvalidOperationException("Cannot use a class that does not implement IServerPacket<> as a packet deserializer."); + } + // Grab the IPacket<>.Id property + else if (type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null) is not VarInt packetId) + { + throw new InvalidOperationException("Cannot use a class that does not have a static Id property as a packet deserializer."); + } + else + { + // Grab the Deserialize method + MethodInfo deserializeMethod = type.GetMethod("Deserialize", BindingFlags.Public | BindingFlags.Static, null, [(typeof(SequenceReader).MakeByRefType())], null)!; + + // Convert the method into a delegate + DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), null, deserializeMethod); + + // Store the delegate in the dictionary + PacketDeserializers[packetId] = deserializeDelegate; + } + } + + public void AddPacketDeserializer(int packetId, Type type) + { + if (type.IsAbstract) + { + throw new InvalidOperationException("Cannot use an abstract class as a packet deserializer."); + } + + Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0]; + if (packetType is null) + { + return; + } + + MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader).MakeByRefType())])!; + DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), null, deserializeMethod); + PacketDeserializers[packetId] = deserializeDelegate; + } + + public void AddDefaultPacketDeserializers() => AddPacketDeserializers(typeof(IServerPacket<>).Assembly.GetExportedTypes()); + + public void AddPacketDeserializers(IEnumerable types) + { + ILogger logger = _loggerFactory.CreateLogger(); + + // Iterate through the assembly and find all classes that implement IServerPacket<> + foreach (Type type in types) + { + // Ensure we grab a fully implemented packet, not IServerPacket<> or an abstract class that implements it + if (type.IsAbstract) + { + continue; + } + + // Grab the generic argument of IServerPacket<> + Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0]; + if (packetType is null) + { + continue; + } + + // Grab the static Id property from Packet<> + VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!; + + // Grab the deserialize method + MethodInfo deserializeMethod = type.GetMethod("Deserialize", [(typeof(SequenceReader).MakeByRefType())]) ?? throw new InvalidOperationException($"Could not find the method 'Deserialize' in '{type.Name}'."); + + // Convert the method into a delegate + DeserializerDelegate deserializeDelegate = (DeserializerDelegate)Delegate.CreateDelegate(typeof(DeserializerDelegate), deserializeMethod); + + // Now we store the pointer in a dictionary for later use + if (PacketDeserializers.TryGetValue(packetId.Value, out DeserializerDelegate? existingDelegate)) + { + logger.LogWarning("Failed to add packet deserializer for packet {PacketId} ({PacketType}), a deserializer for this packet already exists: {ExistingDelegate}", packetId, packetType, existingDelegate); + continue; + } + + PacketDeserializers[packetId.Value] = deserializeDelegate; + } + } + + public void RemovePacketDeserializer(int packetId) => PacketDeserializers.Remove(packetId); + public void RemovePacketDeserializer() where T : IServerPacket => PacketDeserializers.Remove(T.Id); + public void RemovePacketDeserializer(Type type) + { + if (type.IsAbstract) + { + return; + } + + Type? packetType = type.GetInterfaces().FirstOrDefault(i => i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IServerPacket<>))?.GetGenericArguments()[0]; + if (packetType is null) + { + return; + } + + VarInt packetId = (VarInt)type.GetProperty("Id", BindingFlags.Public | BindingFlags.Static)!.GetValue(null)!; + PacketDeserializers.Remove(packetId); + } + } +} diff --git a/src/Moonlight.Api/Server.cs b/src/Moonlight.Api/Server.cs index 9dabd18..653669d 100644 --- a/src/Moonlight.Api/Server.cs +++ b/src/Moonlight.Api/Server.cs @@ -1,15 +1,14 @@ -using System; +using System.IO; using System.Net; using System.Net.Sockets; -using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Microsoft.Extensions.Logging; -using Microsoft.Extensions.Logging.Abstractions; +using Moonlight.Api.Events; +using Moonlight.Api.Events.EventArgs; using Moonlight.Api.Net; using Moonlight.Protocol.Net; -[assembly: InternalsVisibleTo("Moonlight")] namespace Moonlight.Api { public sealed class Server @@ -18,18 +17,22 @@ public sealed class Server public CancellationToken CancellationToken => _cancellationTokenSource.Token; private readonly ILogger _logger; - private readonly ILoggerFactory _loggerProvider; + private readonly PacketReaderFactory _packetReaderFactory; private readonly CancellationTokenSource _cancellationTokenSource = new(); + private readonly AsyncServerEvent _packetReceivedServerEvent; - public Server(ServerConfiguration configuration, ILoggerFactory? logger = null) + public Server(ServerConfiguration serverConfiguration, PacketReaderFactory packetReaderFactory, ILogger logger, AsyncServerEvent packetReceivedServerEvent) { - Configuration = configuration; - _loggerProvider = logger ?? NullLoggerFactory.Instance; - _logger = _loggerProvider.CreateLogger(); + Configuration = serverConfiguration; + _packetReaderFactory = packetReaderFactory; + _logger = logger; + _packetReceivedServerEvent = packetReceivedServerEvent; } public async Task StartAsync() { + _packetReaderFactory.Prepare(); + _logger.LogInformation("Starting server..."); TcpListener listener = new(IPAddress.Parse(Configuration.Host), Configuration.Port); listener.Start(); @@ -37,39 +40,36 @@ public async Task StartAsync() _logger.LogInformation("Server started on {EndPoint}", listener.LocalEndpoint); while (!_cancellationTokenSource.IsCancellationRequested) { - TcpClient client = await listener.AcceptTcpClientAsync(CancellationToken); - _logger.LogInformation("Client connected: {EndPoint}", client.Client.RemoteEndPoint); + _ = HandleClientAsync(await listener.AcceptTcpClientAsync(CancellationToken)); + } + } - // Try to read the handshake packet - PacketReader reader = new(client.GetStream(), _loggerProvider.CreateLogger()); - HandshakePacket handshake; + private async Task HandleClientAsync(TcpClient client) + { + _logger.LogInformation("Client connected: {EndPoint}", client.Client.RemoteEndPoint); - try - { - handshake = await reader.ReadPacketAsync(CancellationToken); - _logger.LogInformation("Handshake received: {Handshake}", handshake); - } - catch (InvalidOperationException error) + using NetworkStream stream = client.GetStream(); + PacketReader reader = _packetReaderFactory.CreatePacketReader(stream); + try + { + HandshakePacket handshake = await reader.ReadPacketAsync(CancellationToken); + if (await _packetReceivedServerEvent.InvokePreHandlersAsync(new PacketReceivedAsyncServerEventArgs(handshake, reader))) { - _logger.LogError(error, "Failed to read handshake packet."); - continue; + await _packetReceivedServerEvent.InvokePostHandlersAsync(new PacketReceivedAsyncServerEventArgs(handshake, reader)); } + } + catch (InvalidDataException error) + { + _logger.LogError(error, "Error handling client: {EndPoint}", client.Client.RemoteEndPoint); + return; + } - while (!_cancellationTokenSource.IsCancellationRequested) + while (!CancellationToken.IsCancellationRequested) + { + IPacket packet = await reader.ReadPacketAsync(CancellationToken); + if (await _packetReceivedServerEvent.InvokePreHandlersAsync(new PacketReceivedAsyncServerEventArgs(packet, reader))) { - // Try to read the next packet - IPacket packet; - - try - { - packet = await reader.ReadPacketAsync(CancellationToken); - _logger.LogInformation("Packet received: {Packet}", packet); - } - catch (InvalidOperationException error) - { - _logger.LogError(error, "Failed to read packet."); - break; - } + await _packetReceivedServerEvent.InvokePostHandlersAsync(new PacketReceivedAsyncServerEventArgs(packet, reader)); } } } diff --git a/src/Moonlight.Protocol/VariableTypes/VarInt.cs b/src/Moonlight.Protocol/VariableTypes/VarInt.cs index 167dac6..c623260 100644 --- a/src/Moonlight.Protocol/VariableTypes/VarInt.cs +++ b/src/Moonlight.Protocol/VariableTypes/VarInt.cs @@ -89,5 +89,6 @@ public static VarInt Deserialize(ref SequenceReader reader) public static implicit operator VarInt(int value) => new(value); public static implicit operator int(VarInt value) => value.Value; + public override string ToString() => Value.ToString(); } } diff --git a/src/Moonlight/EventHandlers/PacketReceiverLogger.cs b/src/Moonlight/EventHandlers/PacketReceiverLogger.cs new file mode 100644 index 0000000..67de317 --- /dev/null +++ b/src/Moonlight/EventHandlers/PacketReceiverLogger.cs @@ -0,0 +1,19 @@ +using System.Threading.Tasks; +using Microsoft.Extensions.Logging; +using Moonlight.Api.Events.EventArgs; + +namespace Moonlight.EventHandlers +{ + public sealed class PacketReceiverLogger + { + private readonly ILogger _logger; + + public PacketReceiverLogger(ILogger logger) => _logger = logger; + + public ValueTask LogPacketReceivedAsync(PacketReceivedAsyncServerEventArgs eventArgs) + { + _logger.LogInformation("Packet received: {Packet}", eventArgs.Packet); + return ValueTask.FromResult(true); + } + } +} diff --git a/src/Moonlight/Program.cs b/src/Moonlight/Program.cs index 3b59c70..0d414e9 100644 --- a/src/Moonlight/Program.cs +++ b/src/Moonlight/Program.cs @@ -1,9 +1,13 @@ using System; using System.Globalization; +using System.Reflection; using System.Threading.Tasks; using Microsoft.Extensions.Configuration; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Logging; using Moonlight.Api; +using Moonlight.Api.Events; +using Moonlight.Api.Net; using Serilog; using Serilog.Events; using Serilog.Sinks.SystemConsole.Themes; @@ -64,10 +68,82 @@ public static async Task Main(string[] args) logging.AddSerilog(serilogLoggerConfiguration.CreateLogger()); }); + serviceCollection.AddSingleton((serviceProvider) => + { + PacketReaderFactory packetReaderFactory = new(serviceProvider.GetRequiredService()); + packetReaderFactory.AddDefaultPacketDeserializers(); + return packetReaderFactory; + }); + + serviceCollection.AddSingleton(typeof(AsyncServerEvent<>)); serviceCollection.AddSingleton(); serviceCollection.AddSingleton(); ServiceProvider serviceProvider = serviceCollection.BuildServiceProvider(); + + // Register all event handlers + foreach (Type type in typeof(Program).Assembly.GetTypes()) + { + foreach (MethodInfo method in type.GetMethods(BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static)) + { + ParameterInfo[] parameters = method.GetParameters(); + if (parameters.Length != 1 || !parameters[0].ParameterType.IsAssignableTo(typeof(AsyncServerEventArgs))) + { + continue; + } + + object asyncServerEvent = serviceProvider.GetRequiredService(typeof(AsyncServerEvent<>).MakeGenericType(parameters[0].ParameterType)); + MethodInfo addPreHandler = asyncServerEvent.GetType().GetMethod("AddPreHandler") ?? throw new InvalidOperationException("Could not find the method 'AddPreHandler' in 'AsyncServerEvent<>'."); + MethodInfo addPostHandler = asyncServerEvent.GetType().GetMethod("AddPostHandler") ?? throw new InvalidOperationException("Could not find the method 'AddPostHandler' in 'AsyncServerEvent<>'."); + if (method.ReturnType == typeof(ValueTask)) + { + if (method.IsStatic) + { + // Invoke void AddPreHandler(AsyncServerEventPreHandler handler, AsyncServerEventPriority priority = AsyncServerEventPriority.Normal) + addPreHandler.Invoke(asyncServerEvent, [ + // Create a delegate of the method + Delegate.CreateDelegate(typeof(AsyncServerEventPreHandler<>).MakeGenericType(parameters[0].ParameterType), method), + // Normal priority + AsyncServerEventPriority.Normal + ]); + } + else + { + // Invoke void AddPreHandler(AsyncServerEventPreHandler handler, AsyncServerEventPriority priority = AsyncServerEventPriority.Normal) + addPreHandler.Invoke(asyncServerEvent, [ + // Create a delegate of the method + Delegate.CreateDelegate(typeof(AsyncServerEventPreHandler<>).MakeGenericType(parameters[0].ParameterType), ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider, type), method), + // Normal priority + AsyncServerEventPriority.Normal + ]); + } + } + else + { + if (method.IsStatic) + { + // Invoke void AddPostHandler(AsyncServerEventHandler handler, AsyncServerEventPriority priority = AsyncServerEventPriority.Normal) + addPostHandler.Invoke(asyncServerEvent, [ + // Create a delegate of the method + Delegate.CreateDelegate(typeof(AsyncServerEventHandler<>).MakeGenericType(parameters[0].ParameterType), method), + // Normal priority + AsyncServerEventPriority.Normal + ]); + } + else + { + // Invoke void AddPostHandler(AsyncServerEventHandler handler, AsyncServerEventPriority priority = AsyncServerEventPriority.Normal) + addPostHandler.Invoke(asyncServerEvent, [ + // Create a delegate of the method + Delegate.CreateDelegate(typeof(AsyncServerEventHandler<>).MakeGenericType(parameters[0].ParameterType), ActivatorUtilities.GetServiceOrCreateInstance(serviceProvider, type), method), + // Normal priority + AsyncServerEventPriority.Normal + ]); + } + } + } + } + Server server = serviceProvider.GetRequiredService(); await server.StartAsync(); }