Skip to content

Fix for issue #54500: Middleware keyed dependency injection #55722

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
using System.Runtime.CompilerServices;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.Internal;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Builder;

Expand All @@ -21,6 +21,7 @@ public static class UseMiddlewareExtensions
internal const string InvokeAsyncMethodName = "InvokeAsync";

private static readonly MethodInfo GetServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetService), BindingFlags.NonPublic | BindingFlags.Static)!;
private static readonly MethodInfo GetKeyedServiceInfo = typeof(UseMiddlewareExtensions).GetMethod(nameof(GetKeyedService), BindingFlags.NonPublic | BindingFlags.Static)!;

// We're going to keep all public constructors and public methods on middleware
private const DynamicallyAccessedMemberTypes MiddlewareAccessibility =
Expand Down Expand Up @@ -215,13 +216,55 @@ private static Func<T, HttpContext, IServiceProvider, Task> ReflectionFallback<T
methodArguments[0] = context;
for (var i = 1; i < parameters.Length; i++)
{
methodArguments[i] = GetService(serviceProvider, parameters[i].ParameterType, methodInfo.DeclaringType!);
var parameter = parameters[i];

var hasServiceKey = TryGetServiceKey(parameter, out object? key);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to happen on every single request, probably worth moving outside of the factory method since the data is the same for every single call.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Brennan,

Thanks for your valuable feedback.

I've made some changes to reflect the feedback. I now precompute and cache the results for each parameter.

Copy link
Contributor Author

@NicoBrabers NicoBrabers May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BrennanConroy, I could eliminate the condition check (hasServiceKey ? GetKeyedService(...) : GetService(...)) as wel by compiling the method calls into delegates, and calling to the delegate directly within the factory. We'd probably be talking about a microsecond though. The code might become slightly less readable though.

What's your opinion on the matter?

var parameterType = parameter.ParameterType;
var declaringType = methodInfo.DeclaringType;

methodArguments[i] = hasServiceKey ? GetKeyedService(serviceProvider, key!, parameterType, declaringType!) : GetService(serviceProvider, parameterType, declaringType!);
}

return (Task)methodInfo.Invoke(middleware, BindingFlags.DoNotWrapExceptions, binder: null, methodArguments, culture: null)!;
};
}

private static bool TryGetServiceKey(ParameterInfo parameterInfo, out object? key)
{
if (parameterInfo.CustomAttributes != null)
{
foreach (var attribute in parameterInfo.GetCustomAttributes(true))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parameterInfo.OfType<FromKeyedServicesAttribute>().FirstOrDefault() might be more straightforward here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@captainsafia, @BrennanConroy, the changes within this pull-request target main, and thus the .NET 9.0 version. If I'd want the changes to be available in the latest .NET 8.x.x version as well, what would the appropriate steps for me to take?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@captainsafia, @BrennanConroy, any follow-up on my question above?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this is a backport candidate. This change is a feature, not fixing a major bug/regression introduced in 8.0.

{
if (attribute is FromKeyedServicesAttribute keyed)
{
key = keyed.Key;
return true;
}
}
}
key = null;
return false;
}

private static UnaryExpression GetMethodArgument(ParameterInfo parameter, ParameterExpression providerArg, Type parameterType, Type? declaringType)
{
var parameterTypeExpression = new List<Expression>() { providerArg };
var hasServiceKey = TryGetServiceKey(parameter, out object? key);

if (hasServiceKey)
{
parameterTypeExpression.Add(Expression.Constant(key, typeof(object)));
}

parameterTypeExpression.Add(Expression.Constant(parameterType, typeof(Type)));
parameterTypeExpression.Add(Expression.Constant(declaringType, typeof(Type)));

var getServiceCall = Expression.Call(hasServiceKey ? GetKeyedServiceInfo : GetServiceInfo, parameterTypeExpression);
var methodArgument = Expression.Convert(getServiceCall, parameterType);

return methodArgument;
}

private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>(MethodInfo methodInfo, ParameterInfo[] parameters)
{
Debug.Assert(RuntimeFeature.IsDynamicCodeSupported, "Use compiled expression when dynamic code is supported.");
Expand Down Expand Up @@ -262,21 +305,14 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>
methodArguments[0] = httpContextArg;
for (var i = 1; i < parameters.Length; i++)
{
var parameterType = parameters[i].ParameterType;
var parameter = parameters[i];
var parameterType = parameter.ParameterType;
if (parameterType.IsByRef)
{
throw new NotSupportedException(Resources.FormatException_InvokeDoesNotSupportRefOrOutParams(InvokeMethodName));
}

var parameterTypeExpression = new Expression[]
{
providerArg,
Expression.Constant(parameterType, typeof(Type)),
Expression.Constant(methodInfo.DeclaringType, typeof(Type))
};

var getServiceCall = Expression.Call(GetServiceInfo, parameterTypeExpression);
methodArguments[i] = Expression.Convert(getServiceCall, parameterType);
methodArguments[i] = GetMethodArgument(parameter, providerArg, parameterType, methodInfo.DeclaringType);
}

Expression middlewareInstanceArg = instanceArg;
Expand All @@ -294,12 +330,20 @@ private static Func<T, HttpContext, IServiceProvider, Task> CompileExpression<T>

private static object GetService(IServiceProvider sp, Type type, Type middleware)
{
var service = sp.GetService(type);
if (service == null)
var service = sp.GetService(type) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

private static object GetKeyedService(IServiceProvider sp, object key, Type type, Type middleware)
{
if (sp is IKeyedServiceProvider ksp)
{
throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));
var service = ksp.GetKeyedService(type, key) ?? throw new InvalidOperationException(Resources.FormatException_InvokeMiddlewareNoService(type, middleware));

return service;
}

return service;
throw new InvalidOperationException(Resources.Exception_KeyedServicesNotSupported);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ Microsoft.AspNetCore.Http.HttpResponse</Description>
<Reference Include="Microsoft.AspNetCore.Http.Features" />
<Reference Include="Microsoft.Net.Http.Headers" />
<Reference Include="Microsoft.Extensions.Logging.Abstractions" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />

<Compile Include="$(SharedSourceRoot)ActivatorUtilities\*.cs" />
<Compile Include="$(SharedSourceRoot)ParameterDefaultValue\*.cs" />
<Compile Include="$(SharedSourceRoot)PropertyHelper\**\*.cs" />
<Compile Include="$(SharedSourceRoot)\UrlDecoder\UrlDecoder.cs" Link="UrlDecoder.cs" />
Expand Down
6 changes: 6 additions & 0 deletions src/Http/Http.Abstractions/src/Resources.resx
Original file line number Diff line number Diff line change
Expand Up @@ -162,4 +162,10 @@
<data name="RouteValueDictionary_DuplicatePropertyName" xml:space="preserve">
<value>The type '{0}' defines properties '{1}' and '{2}' which differ only by casing. This is not supported by {3} which uses case-insensitive comparisons.</value>
</data>
<data name="Exception_KeyedServicesNotSupported" xml:space="preserve">
<value>This service provider doesn't support keyed services.</value>
</data>
<data name="Exception_NoServiceRegistered" xml:space="preserve">
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, there is this one too 😢

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hehe, don't sweat it, should've checked it. Will run by all resources 👍

<value>No service for type '{0}' has been registered.</value>
</data>
</root>
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
<Reference Include="Microsoft.AspNetCore.Routing" />
<Reference Include="Microsoft.AspNetCore.TestHost" />
<Reference Include="Mono.TextTemplating" />
<Reference Include="Microsoft.Extensions.DependencyInjection.Abstractions" />
</ItemGroup>

<ItemGroup>
Expand Down
100 changes: 100 additions & 0 deletions src/Http/Http.Abstractions/test/UseMiddlewareTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http.Abstractions;
using Microsoft.Extensions.DependencyInjection;

namespace Microsoft.AspNetCore.Http;

Expand Down Expand Up @@ -130,6 +131,32 @@ public async Task UseMiddleware_ThrowsIfArgCantBeResolvedFromContainer()
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfKeyedArgCantBeResolvedFromContainer()
{
var builder = new ApplicationBuilder(new DummyKeyedServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.FormatException_InvokeMiddlewareNoService(
typeof(object),
typeof(MiddlewareKeyedInjectInvokeNoService)),
exception.Message);
}

[Fact]
public async Task UseMiddleware_ThrowsIfServiceProviderIsNotAIKeyedServiceProvider()
{
var builder = new ApplicationBuilder(new DummyServiceProvider());
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvokeNoService));
var app = builder.Build();
var exception = await Assert.ThrowsAsync<InvalidOperationException>(() => app(new DefaultHttpContext()));
Assert.Equal(
Resources.Exception_KeyedServicesNotSupported,
exception.Message);
}

[Fact]
public void UseMiddlewareWithInvokeArg()
{
Expand All @@ -139,6 +166,17 @@ public void UseMiddlewareWithInvokeArg()
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeKeyedArg()
{
var keyedServiceProvider = new DummyKeyedServiceProvider();
keyedServiceProvider.AddKeyedService("test", typeof(DummyKeyedServiceProvider), keyedServiceProvider);
var builder = new ApplicationBuilder(keyedServiceProvider);
builder.UseMiddleware(typeof(MiddlewareKeyedInjectInvoke));
var app = builder.Build();
app(new DefaultHttpContext());
}

[Fact]
public void UseMiddlewareWithInvokeWithOutAndRefThrows()
{
Expand Down Expand Up @@ -274,6 +312,54 @@ private class DummyServiceProvider : IServiceProvider
}
}

private class DummyKeyedServiceProvider : IKeyedServiceProvider
{
private readonly Dictionary<object, Tuple<Type, object>> _services = new Dictionary<object, Tuple<Type, object>>();

public DummyKeyedServiceProvider()
{

}

public void AddKeyedService(object key, Type type, object value) => _services[key] = new Tuple<Type, object>(type, value);

public object? GetKeyedService(Type serviceType, object? serviceKey)
{
if (_services.TryGetValue(serviceKey!, out var value))
{
return value.Item2;
}

return null;
}

public object GetRequiredKeyedService(Type serviceType, object? serviceKey)
{
var service = GetKeyedService(serviceType, serviceKey);

if (service == null)
{
throw new InvalidOperationException(Resources.FormatException_NoServiceRegistered(serviceType));
}

return service;
}

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}

public class MiddlewareInjectWithOutAndRefParams
{
public MiddlewareInjectWithOutAndRefParams(RequestDelegate next) { }
Expand All @@ -293,13 +379,27 @@ public MiddlewareInjectInvokeNoService(RequestDelegate next) { }
public Task Invoke(HttpContext context, object value) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvokeNoService
{
public MiddlewareKeyedInjectInvokeNoService(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] object value) => Task.CompletedTask;
}

private class MiddlewareInjectInvoke
{
public MiddlewareInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, IServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareKeyedInjectInvoke
{
public MiddlewareKeyedInjectInvoke(RequestDelegate next) { }

public Task Invoke(HttpContext context, [FromKeyedServices("test")] IKeyedServiceProvider provider) => Task.CompletedTask;
}

private class MiddlewareNoParametersStub
{
public MiddlewareNoParametersStub(RequestDelegate next) { }
Expand Down
24 changes: 22 additions & 2 deletions src/Http/Http.Abstractions/test/UsePathBaseExtensionsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ public IServiceProvider ApplicationServices
public IFeatureCollection ServerFeatures => _wrappedBuilder.ServerFeatures;
public RequestDelegate Build() => _wrappedBuilder.Build();
public IApplicationBuilder New() => _wrappedBuilder.New();

}

[Theory]
Expand Down Expand Up @@ -238,6 +237,27 @@ private static HttpContext CreateRequest(string pathBase, string requestPath)

private static ApplicationBuilder CreateBuilder()
{
return new ApplicationBuilder(serviceProvider: null!);
return new ApplicationBuilder(new DummyServiceProvider());
}

private class DummyServiceProvider : IServiceProvider
{
private readonly Dictionary<Type, object> _services = new Dictionary<Type, object>();

public void AddService(Type type, object value) => _services[type] = value;

public object? GetService(Type serviceType)
{
if (serviceType == typeof(IServiceProvider))
{
return this;
}

if (_services.TryGetValue(serviceType, out var value))
{
return value;
}
return null;
}
}
}