Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System.Collections;
using System.Collections.Concurrent;
using System.Diagnostics.CodeAnalysis;
using System.Linq.Expressions;
using System.Reflection;
Expand All @@ -22,14 +23,21 @@ public sealed class ProjectableExpressionReplacer : ExpressionVisitor
private IEntityType? _entityType;

// Extract MethodInfo via expression trees (trim-safe; computed once per AppDomain)
private static readonly MethodInfo _select =
private readonly static MethodInfo _select =
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
(q => q.Select(x => x))).Body).Method.GetGenericMethodDefinition();

private static readonly MethodInfo _where =
private readonly static MethodInfo _where =
((MethodCallExpression)((Expression<Func<IQueryable<object>, IQueryable<object>>>)
(q => q.Where(x => true))).Body).Method.GetGenericMethodDefinition();

// Static caches — keyed by CLR type, shared across all instances for the AppDomain lifetime.
// Safe because: type metadata is immutable, LambdaExpression trees are immutable, MethodInfo is stable.
private readonly static ConcurrentDictionary<Type, bool> _compilerGeneratedClosureCache = new();
private readonly static ConcurrentDictionary<Type, PropertyInfo[]> _projectablePropertiesCache = new();
private readonly static ConcurrentDictionary<Type, MethodInfo> _closedSelectCache = new();
private readonly static ConcurrentDictionary<Type, MethodInfo> _closedWhereCache = new();
Comment thread
PhenX marked this conversation as resolved.
Outdated

public ProjectableExpressionReplacer(IProjectionExpressionResolver projectionExpressionResolver, bool trackByDefault = false)
{
_trackingByDefault = trackByDefault;
Expand Down Expand Up @@ -84,7 +92,6 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
// // case of a first()
// return obj.MyMap(x => new Obj {});
// }


if (call.Method.ReturnType.IsAssignableTo(typeof(IQueryable)))
{
Expand All @@ -101,7 +108,9 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
// before the query become executed by EF (before the .First()), we rewrite the .First(where)
// as .Where(where).Select(x => ...).First()

var where = Expression.Call(null, _where.MakeGenericMethod(_entityType.ClrType), call.Arguments);
var whereMethod = _closedWhereCache.GetOrAdd(
_entityType.ClrType, static (t, m) => m.MakeGenericMethod(t), _where);
var where = Expression.Call(null, whereMethod, call.Arguments);
// The call instance is based on the wrong polymorphied method.
var first = call.Method.DeclaringType?.GetMethods()
.FirstOrDefault(x => x.Name == call.Method.Name && x.GetParameters().Length == 1);
Expand Down Expand Up @@ -138,18 +147,27 @@ bool TryGetReflectedExpression(MemberInfo memberInfo, [NotNullWhen(true)] out La
protected override Expression VisitMethodCall(MethodCallExpression node)
{
// Replace MethodGroup arguments with their reflected expressions.
// Note that MethodCallExpression.Update returns the original Expression if argument values have not changed.
node = node.Update(node.Object, node.Arguments.Select(arg => arg switch {
UnaryExpression {
NodeType: ExpressionType.Convert,
Operand: MethodCallExpression {
NodeType: ExpressionType.Call,
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
Object: ConstantExpression { Value: MethodInfo methodInfo }
}
} => TryGetReflectedExpression(methodInfo, out var expressionArg) ? expressionArg : arg,
_ => arg
}));
// No-alloc fast-path: scan args without allocating; only copy the array and call
// Update() when a replacement is actually found (method-group arguments are rare).
Expression[]? updatedArgs = null;
for (var i = 0; i < node.Arguments.Count; i++)
{
if (node.Arguments[i] is UnaryExpression {
NodeType: ExpressionType.Convert,
Operand: MethodCallExpression {
NodeType: ExpressionType.Call,
Method: { Name: nameof(MethodInfo.CreateDelegate), DeclaringType.Name: nameof(MethodInfo) },
Object: ConstantExpression { Value: MethodInfo capturedMethodInfo }
}
} && TryGetReflectedExpression(capturedMethodInfo, out var expressionArg))
{
(updatedArgs ??= [.. node.Arguments])[i] = expressionArg;
}
}
if (updatedArgs is not null)
{
node = node.Update(node.Object, updatedArgs);
}

// Get the overriding methodInfo based on te type of the received of this expression
var methodInfo = node.Object?.Type.GetConcreteMethod(node.Method) ?? node.Method;
Expand All @@ -172,7 +190,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)
{
for (var parameterIndex = 0; parameterIndex < reflectedExpression.Parameters.Count; parameterIndex++)
{
var parameterExpession = reflectedExpression.Parameters[parameterIndex];
var parameterExpression = reflectedExpression.Parameters[parameterIndex];
var mappedArgumentExpression = (parameterIndex, node.Object) switch {
(0, not null) => node.Object,
(_, not null) => node.Arguments[parameterIndex - 1],
Expand All @@ -181,7 +199,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node)

if (mappedArgumentExpression is not null)
{
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpession, mappedArgumentExpression);
_expressionArgumentReplacer.ParameterArgumentMapping.Add(parameterExpression, mappedArgumentExpression);
}
}

Expand Down Expand Up @@ -232,19 +250,31 @@ protected override Expression VisitMember(MemberExpression node)
{
// Evaluate captured variables in closures that contain EF queries to inline them into the main query
if (node.Expression is ConstantExpression constant &&
constant.Type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
Attribute.IsDefined(constant.Type, typeof(CompilerGeneratedAttribute), inherit: true))
IsCompilerGeneratedClosure(constant.Type))
{
try
{
var value = Expression
.Lambda<Func<object>>(Expression.Convert(node, typeof(object)))
.Compile()
.Invoke();
// Cheap type check first: only call GetValue() when the member could
// possibly hold an IQueryable. FieldType / PropertyType are free
// property reads on an already-materialised MemberInfo object.
var memberType = node.Member switch {
FieldInfo field => field.FieldType,
PropertyInfo prop => prop.PropertyType,
_ => null
};

if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
if (memberType is not null && typeof(IQueryable).IsAssignableFrom(memberType))
{
return Visit(queryable.Expression);
var value = node.Member switch {
FieldInfo field => field.GetValue(constant.Value),
PropertyInfo prop => prop.GetValue(constant.Value),
_ => null
};

if (value is IQueryable queryable && ReferenceEquals(queryable.Provider, _currentQueryProvider))
{
return Visit(queryable.Expression);
}
}
Comment thread
PhenX marked this conversation as resolved.
Outdated
}
catch
Expand Down Expand Up @@ -275,16 +305,10 @@ PropertyInfo property when nodeExpression is not null
var updatedBody = _expressionArgumentReplacer.Visit(reflectedExpression.Body);
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();

return base.Visit(
updatedBody
);
}
else
{
return base.Visit(
reflectedExpression.Body
);
return base.Visit(updatedBody);
}

return base.Visit(reflectedExpression.Body);
}

return base.VisitMember(node);
Expand All @@ -303,12 +327,13 @@ protected override Expression VisitExtension(Expression node)

private Expression _AddProjectableSelect(Expression node, IEntityType entityType)
{
var projectableProperties = entityType.ClrType.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false))
.Where(x => x.CanWrite)
.ToList();
var projectableProperties = _projectablePropertiesCache.GetOrAdd(
entityType.ClrType,
static t => t.GetProperties()
.Where(x => x.IsDefined(typeof(ProjectableAttribute), false) && x.CanWrite)
.ToArray());

if (!projectableProperties.Any())
if (projectableProperties.Length == 0)
{
return node;
}
Expand All @@ -327,7 +352,8 @@ private Expression _AddProjectableSelect(Expression node, IEntityType entityType
.Where(x => projectableProperties.All(y => x.Name != y.Name && x.Name != $"<{y.Name}>k__BackingField"));

// Replace db.Entities to db.Entities.Select(x => new Entity { Property1 = x.Property1, Rewritted = rewrittedProperty })
var select = _select.MakeGenericMethod(entityType.ClrType, entityType.ClrType);
var select = _closedSelectCache.GetOrAdd(
entityType.ClrType, static (t, m) => m.MakeGenericMethod(t, t), _select);
var xParam = Expression.Parameter(entityType.ClrType);
return Expression.Call(
null,
Expand All @@ -354,5 +380,12 @@ private Expression _GetAccessor(PropertyInfo property, ParameterExpression para)
_expressionArgumentReplacer.ParameterArgumentMapping.Clear();
return base.Visit(updatedBody);
}

private static bool IsCompilerGeneratedClosure(Type type) =>
// TypeAttributes.NestedPrivate is a cheap flag check that rules out most types before
// touching the attribute cache.
type.Attributes.HasFlag(TypeAttributes.NestedPrivate) &&
_compilerGeneratedClosureCache.GetOrAdd(type, static t =>
Attribute.IsDefined(t, typeof(CompilerGeneratedAttribute), inherit: true));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 1 AND [e0].[Id] <= 5 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT [e].[Id], (
SELECT COUNT(*)
FROM [Entity] AS [e0]
WHERE [e0].[Id] * 2 > 4) AS [SubsetCount]
FROM [Entity] AS [e]
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @lowerBound int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @lowerBound AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__lowerBound_0 int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lowerBound_0 AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__lowerBound_0 int = 3;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lowerBound_0 AND [e].[Id] <= 10
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @minCount int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @minCount AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @__minCount_0 int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @__minCount_0 AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
DECLARE @__minCount_0 int = 1;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE ([e].[Id] >= @__minCount_0 AND [e].[Id] <= 50) OR EXISTS (
SELECT 1
FROM [Entity] AS [e0]
WHERE [e0].[Id] >= 10 AND [e0].[Id] <= 100 AND [e0].[Id] = [e].[Id])
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @lower int = 2;
DECLARE @upper int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @lower AND [e].[Id] <= @upper
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @__lower_0 int = 2;
DECLARE @__upper_1 int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lower_0 AND [e].[Id] <= @__upper_1
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DECLARE @__lower_0 int = 2;
DECLARE @__upper_1 int = 8;

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Id] >= @__lower_0 AND [e].[Id] <= @__upper_1
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @targetName nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @targetName
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__targetName_0 nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @__targetName_0
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DECLARE @__targetName_0 nvarchar(4000) = N'Alice';

SELECT [e].[Id], [e].[Name]
FROM [Entity] AS [e]
WHERE [e].[Name] = @__targetName_0
Loading
Loading