Skip to content

Fix for issue #54500: Middleware keyed dependency injection #55722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Builder;

Expand All @@ -21,6 +21,7 @@ public static class UseMiddlewareExtensions
internal const string InvokeAsyncMethodName = "InvokeAsync";

private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetKeyedServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetKeyedService), BindingFlags.NonPublic | BindingFlags.Static)!;

// We're going to keep all public constructors and public methods on middleware
private const DynamicallyAccessedMemberTypes MiddlewareAccessibility =
Expand Down Expand Up @@ -209,19 +210,58 @@ private static Func<T, HttpContext, IServiceProvider, Task> ReflectionFallback<T
}
}

// Performance optimization: Precompute and cache the key results for each parameter
var precomputedKeys = new object?[parameters.Length];
for (var i = 1; i < parameters.Length; i++)
{
_ = TryGetServiceKey(parameters[i], out object? key);

precomputedKeys[i] = key;
}

return (middleware, context, serviceProvider) =>
{
var methodArguments = new object[parameters.Length];
methodArguments[0] = context;
for (var i = 1; i < parameters.Length; i++)
{
methodArguments[i] = GetService(serviceProvider, parameters[i].ParameterType, methodInfo.DeclaringType!);
var key = precomputedKeys[i];
var parameterType = parameters[i].ParameterType;
var declaringType = methodInfo.DeclaringType!;

methodArguments[i] = key == null ? GetService(serviceProvider, parameterType, declaringType) : GetKeyedService(serviceProvider, key, parameterType, declaringType);
}

return (Task)methodInfo.Invoke(middleware, BindingFlags.DoNotWrapExceptions, binder: null, methodArguments, culture: null)!;
};
}

private static bool TryGetServiceKey(ParameterInfo parameterInfo, [NotNullWhen(true)] out object? key)
{
key = parameterInfo.GetCustomAttribute<FromKeyedServicesAttribute>(false)?.Key;

return key != null;
}

private static UnaryExpression GetMethodArgument(ParameterInfo parameter, ParameterExpression providerArg, Type parameterType, Type? declaringType)
{
var parameterTypeExpression = new List<Expression>() { providerArg };
var hasServiceKey = TryGetServiceKey(parameter, out object? key);

if (hasServiceKey)
{
parameterTypeExpression.Add(Expression.Constant(key, typeof(object)));
}

parameterTypeExpression.Add(Expression.Constant(parameterType, typeof(Type)));
parameterTypeExpression.Add(Expression.Constant(declaringType, typeof(Type)));

var getServiceCall = Expression.Call(hasServiceKey ? GetKeyedServiceInfo : GetServiceInfo, parameterTypeExpression);
var methodArgument = Expression.Convert(getServiceCall, parameterType);

return methodArgument;
}

private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
{
Debug.Assert(RuntimeFeature.IsDynamicCodeSupported, "Use compiled expression when dynamic code is supported.");
Expand Down Expand Up @@ -262,21 +302,14 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>
methodArguments[0] = httpContextArg;
for (var i = 1; i < parameters.Length; i++)
{
var parameterType = parameters[i].ParameterType;
var parameter = parameters[i];
var parameterType = parameter.ParameterType;
if (parameterType.IsByRef)
{
throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(InvokeMethodName));
}

var parameterTypeExpression = new Expression[]
{
providerArg,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(methodInfo.DeclaringType, typeof(Type))
};

var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
methodArguments[i] = GetMethodArgument(parameter, providerArg, parameterType, methodInfo.DeclaringType);
}

Expression middlewareInstanceArg = instanceArg;
Expand All @@ -294,12 +327,20 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>

private static object GetService(IServiceProvider sp, Type type, Type middleware)
{
var service = sp.GetService(type);
if (service == null)
var service = sp.GetService(type) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

private static object GetKeyedService(IServiceProvider sp, object key, Type type, Type middleware)
{
if (sp is IKeyedServiceProvider ksp)
{
throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
var service = ksp.GetKeyedService(type, key) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

return service;
throw new InvalidOperationException(Resources.Exception_KeyedServicesNotSupported);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ Microsoft.AspNetCore.Http.HttpResponse</Description>
<Reference Include="Microsoft.AspNetCore.Http.Features" />
<Reference Include="Microsoft.Net.Http.Headers" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />

<Compile Include="$(SharedSourceRoot)ActivatorUtilities\*.cs" />
<Compile Include="$(SharedSourceRoot)ParameterDefaultValue\*.cs" />
<Compile Include="$(SharedSourceRoot)PropertyHelper\**\*.cs" />
<Compile Include="$(SharedSourceRoot)\UrlDecoder\UrlDecoder.cs" Link="UrlDecoder.cs" />
Expand Down
3 changes: 3 additions & 0 deletions src/Http/Http.Abstractions/src/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,7 @@
<data name="RouteValueDictionary_DuplicatePropertyName" xml:space="preserve">
<value>The type '{0}' defines properties '{1}' and '{2}' which differ only by casing. This is not supported by {3} which uses case-insensitive comparisons.</value>
</data>
<data name="Exception_KeyedServicesNotSupported" xml:space="preserve">
<value>This service provider doesn't support keyed services.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Mono.TextTemplating" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
</ItemGroup>

<ItemGroup>
Expand Down
122 changes: 122 additions & 0 deletions src/Http/Http.Abstractions/test/UseMiddlewareTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Http;

Expand Down Expand Up @@ -130,6 +131,43 @@ public async Task UseMiddleware_ThrowsIfArgCantBeResolvedFromContainer()
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfKeyedArgCantBeResolvedFromContainer()
{
var builder = new ApplicationBuilder(new DummyKeyedServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.FormatException_InvokeMiddlewareNoService(
typeof(IKeyedServiceProvider),
typeof(MiddlewareKeyedInjectInvoke)),
exception.Message);
}

[Fact]
public void UseMiddleware_ThrowsIfKeyedConstructorArgCantBeResolvedFromContainer()
{
var builder = new ApplicationBuilder(new DummyKeyedServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedConstructorInjectInvoke));
var exception = Assert.Throws<InvalidOperationException>(builder.Build);
Assert.Equal(
$"Unable to resolve service for type '{typeof(IKeyedServiceProvider)}' while attempting to activate '{typeof(MiddlewareKeyedConstructorInjectInvoke)}'.",
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfServiceProviderIsNotAIKeyedServiceProvider()
{
var builder = new ApplicationBuilder(new DummyServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.Exception_KeyedServicesNotSupported,
exception.Message);
}

[Fact]
public void UseMiddlewareWithInvokeArg()
{
Expand All @@ -139,6 +177,28 @@ public void UseMiddlewareWithInvokeArg()
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeKeyedArg()
{
var keyedServiceProvider = new DummyKeyedServiceProvider();
keyedServiceProvider.AddKeyedService("test", typeof(DummyKeyedServiceProvider), keyedServiceProvider);
var builder = new ApplicationBuilder(keyedServiceProvider);
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithConstructorKeyedArg()
{
var keyedServiceProvider = new DummyKeyedServiceProvider();
keyedServiceProvider.AddKeyedService("test", typeof(DummyKeyedServiceProvider), keyedServiceProvider);
var builder = new ApplicationBuilder(keyedServiceProvider);
builder.UseMiddleware(typeof(MiddlewareKeyedConstructorInjectInvoke));
var app = builder.Build();
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeWithOutAndRefThrows()
{
Expand Down Expand Up @@ -274,6 +334,54 @@ private class DummyServiceProvider : IServiceProvider
}
}

private class DummyKeyedServiceProvider : IKeyedServiceProvider
{
private readonly Dictionary<object, Tuple<Type, object>> _services = new Dictionary<object, Tuple<Type, object>>();

public DummyKeyedServiceProvider()
{

}

public void AddKeyedService(object key, Type type, object value) => _services[key] = new Tuple<Type, object>(type, value);

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_services.TryGetValue(serviceKey!, out var value))
{
return value.Item2;
}

return null;
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
var service = GetKeyedService(serviceType, serviceKey);

if (service == null)
{
throw new InvalidOperationException($"No service for type '{serviceType}' has been registered.");
}

return service;
}

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}

public class MiddlewareInjectWithOutAndRefParams
{
public MiddlewareInjectWithOutAndRefParams(RequestDelegate next) { }
Expand All @@ -300,6 +408,20 @@ public MiddlewareInjectInvoke(RequestDelegate next) { }
public Task Invoke(HttpContext context, IServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvoke
{
public MiddlewareKeyedInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] IKeyedServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareKeyedConstructorInjectInvoke
{
public MiddlewareKeyedConstructorInjectInvoke(RequestDelegate next, [FromKeyedServices("test")] IKeyedServiceProvider provider) { }

public Task Invoke(HttpContext context) => Task.CompletedTask;
}

private class MiddlewareNoParametersStub
{
public MiddlewareNoParametersStub(RequestDelegate next) { }
Expand Down
25 changes: 23 additions & 2 deletions src/Http/Http.Abstractions/test/UsePathBaseExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public IServiceProvider ApplicationServices
public IFeatureCollection ServerFeatures => _wrappedBuilder.ServerFeatures;
public RequestDelegate Build() => _wrappedBuilder.Build();
public IApplicationBuilder New() => _wrappedBuilder.New();

}

[Theory]
Expand Down Expand Up @@ -238,6 +237,28 @@ private static HttpContext CreateRequest(string pathBase, string requestPath)

private static ApplicationBuilder CreateBuilder()
{
return new ApplicationBuilder(serviceProvider: null!);
return new ApplicationBuilder(new DummyServiceProvider());
}

private class DummyServiceProvider : IServiceProvider
{
private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

public void AddService(Type type, object value) => _services[type] = value;

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}

return null;
}
}
}
Loading