Skip to content

Commit

Permalink
Added support for n:m projections in DataLoader. (#7577)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib committed Oct 8, 2024
1 parent c964eb5 commit f5fc2f1
Show file tree
Hide file tree
Showing 7 changed files with 196 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/GreenDonut/src/Core/Projections/ExpressionHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ private static Expression CombineConditionalAndMemberInit(
return Expression.Condition(condition.Test, ifTrue, ifFalse);
}

public static Expression<Func<TRoot, TRoot>> Rewrite<TRoot>(
Expression<Func<TRoot, object?>> selector)
public static Expression<Func<TRoot, TRoot>> Rewrite<TRoot, TKey>(
Expression<Func<TRoot, TKey?>> selector)
{
var parameter = selector.Parameters[0];
var bindings = new List<MemberBinding>();
Expand Down
11 changes: 11 additions & 0 deletions src/GreenDonut/src/Core/Projections/KeyValueResult.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
namespace GreenDonut.Projections;

/// <summary>
/// This class is a helper that is used to project a key value pair.
/// </summary>
public sealed class KeyValueResult<TKey, TValue>
{
public TKey Key { get; set; } = default!;

public TValue Value { get; set; } = default!;
}
117 changes: 108 additions & 9 deletions src/GreenDonut/src/Core/Projections/SelectionDataLoaderExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ namespace GreenDonut.Projections;
#endif
public static class SelectionDataLoaderExtensions
{
private static readonly MethodInfo _selectMethod =
typeof(Enumerable)
.GetMethods()
.Where(m => m.Name == nameof(Enumerable.Select) && m.GetParameters().Length == 2)
.First(m => m.GetParameters()[1].ParameterType.GetGenericTypeDefinition() == typeof(Func<,>));

/// <summary>
/// Branches a DataLoader and applies a selector to load the data.
/// </summary>
Expand Down Expand Up @@ -65,7 +71,7 @@ static IDataLoader CreateBranch(
IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, TValue>> selector)
{
var branch = new SelectionDataLoader<TKey, TValue>(
var branch = new SelectionDataLoader<TKey, TValue>(
(DataLoaderBase<TKey, TValue>)dataLoader,
key);
var context = new DefaultSelectorBuilder<TValue>();
Expand Down Expand Up @@ -140,14 +146,14 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(
/// Applies the selector from the DataLoader state to a queryable.
/// </summary>
/// <param name="query">
/// The queryable to apply the selector to.
/// </param>
/// <param name="builder">
/// The selector builder.
/// The queryable to apply the selector to.
/// </param>
/// <param name="key">
/// The DataLoader key.
/// </param>
/// <param name="builder">
/// The selector builder.
/// </param>
/// <typeparam name="T">
/// The queryable type.
/// </typeparam>
Expand All @@ -157,10 +163,9 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="query"/> is <c>null</c>.
/// </exception>
public static IQueryable<T> Select<T>(
this IQueryable<T> query,
ISelectorBuilder builder,
Expression<Func<T, object?>> key)
public static IQueryable<T> Select<T>(this IQueryable<T> query,
Expression<Func<T, object?>> key,
ISelectorBuilder builder)
{
if (query is null)
{
Expand All @@ -181,5 +186,99 @@ public static IQueryable<T> Select<T>(

return query;
}

/// <summary>
/// Applies the selector from the DataLoader state to a queryable.
/// </summary>
/// <param name="query">
/// The queryable to apply the selector to.
/// </param>
/// <param name="key">
/// The DataLoader key.
/// </param>
/// <param name="list">
/// The list selector.
/// </param>
/// <param name="elementSelector">
/// The element selector.
/// </param>
/// <typeparam name="T">
/// The queryable type.
/// </typeparam>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <returns>
/// Returns a selector query on which a key must be applied to fetch the data.
/// </returns>
public static IQueryable<KeyValueResult<TKey, IEnumerable<TValue>>> Select<T, TKey, TValue>(
this IQueryable<T> query,
Expression<Func<T, TKey?>> key,
Expression<Func<T, IEnumerable<TValue>>> list,
ISelectorBuilder elementSelector)
{
// we first create a new parameter expression for the root as we need
// a unified parameter for both expressions (key and list)
var parameter = Expression.Parameter(typeof(T), "root");

// next we replace the parameter within the key and list selectors with the
// unified parameter.
var rewrittenKey = ReplaceParameter(key, key.Parameters[0], parameter);
var rewrittenList = ReplaceParameter(list, list.Parameters[0], parameter);

// next we try to compile an element selector expression.
var elementSelectorExpr = elementSelector.TryCompile<TValue>();

// if we have a element selector to project properties on the list expression
// we will need to combine this into the list expression.
if (elementSelectorExpr is not null)
{
var selectMethod = _selectMethod.MakeGenericMethod(typeof(TValue), typeof(TValue));

list = Expression.Lambda<Func<T, IEnumerable<TValue>>>(
Expression.Call(
selectMethod,
rewrittenList.Body,
elementSelectorExpr),
parameter);
}

// finally we combine key and list expression into a single selector expression
var keyValueSelectorExpr = Expression.Lambda<Func<T, KeyValueResult<TKey, IEnumerable<TValue>>>>(
Expression.MemberInit(
Expression.New(typeof(KeyValueResult<TKey, IEnumerable<TValue>>)),
Expression.Bind(
typeof(KeyValueResult<TKey, IEnumerable<TValue>>).GetProperty(
nameof(KeyValueResult<TKey, IEnumerable<TValue>>.Key))!,
rewrittenKey.Body),
Expression.Bind(
typeof(KeyValueResult<TKey, IEnumerable<TValue>>).GetProperty(
nameof(KeyValueResult<TKey, IEnumerable<TValue>>.Value))!,
list.Body)),
parameter);

// lastly we apply the selector expression to the queryable.
return query.Select(keyValueSelectorExpr);
}

private static Expression<T> ReplaceParameter<T>(
Expression<T> expression,
ParameterExpression oldParameter,
ParameterExpression newParameter)
=> (Expression<T>)new ReplaceParameterVisitor(oldParameter, newParameter).Visit(expression);
}

file sealed class ReplaceParameterVisitor(
ParameterExpression oldParameter,
ParameterExpression newParameter)
: ExpressionVisitor
{
protected override Expression VisitParameter(ParameterExpression node)
=> node == oldParameter
? newParameter
: base.VisitParameter(node);
}
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,36 @@ public async Task Brand_With_Default_Field_Over_Node()
.MatchMarkdownSnapshot();
}

[Fact]
public async Task Project_Key_To_Collection_Expression()
{
// Arrange
var queries = new List<string>();
var connectionString = CreateConnectionString();
await CatalogContext.SeedAsync(connectionString);

var services = new ServiceCollection()
.AddScoped(_ => queries)
.AddTransient(_ => new CatalogContext(connectionString))
.AddDataLoader(
sp => new ProductByBrandIdDataLoader(
sp,
sp.GetRequiredService<List<string>>(),
sp.GetRequiredService<IBatchScheduler>(),
sp.GetRequiredService<DataLoaderOptions>()))
.BuildServiceProvider();

// Act
await using var scope = services.CreateAsyncScope();
var dataLoader = scope.ServiceProvider.GetRequiredService<ProductByBrandIdDataLoader>();
await dataLoader.LoadAsync(1);

// Assert
Snapshot.Create()
.AddSql(queries)
.MatchMarkdownSnapshot();
}

public class Query
{
public async Task<Brand?> GetBrandByIdAsync(
Expand Down Expand Up @@ -930,7 +960,7 @@ protected override async Task<IReadOnlyDictionary<int, Brand>> LoadBatchAsync(

var query = catalogContext.Brands
.Where(t => keys.Contains(t.Id))
.Select(context.GetSelector(), b => b.Id);
.Select(b => b.Id, context.GetSelector());

lock (_queries)
{
Expand Down Expand Up @@ -959,7 +989,7 @@ protected override async Task<IReadOnlyDictionary<int, Product>> LoadBatchAsync(

var query = catalogContext.Products
.Where(t => keys.Contains(t.Id))
.Select(context.GetSelector(), b => b.Id);
.Select(b => b.Id, context.GetSelector());

lock (queries)
{
Expand All @@ -971,6 +1001,37 @@ protected override async Task<IReadOnlyDictionary<int, Product>> LoadBatchAsync(
return x;
}
}

public class ProductByBrandIdDataLoader(
IServiceProvider services,
List<string> queries,
IBatchScheduler batchScheduler,
DataLoaderOptions options)
: StatefulBatchDataLoader<int, Product[]>(batchScheduler, options)
{
protected override async Task<IReadOnlyDictionary<int, Product[]>> LoadBatchAsync(
IReadOnlyList<int> keys,
DataLoaderFetchContext<Product[]> context,
CancellationToken cancellationToken)
{
var catalogContext = services.GetRequiredService<CatalogContext>();
var selector = new DefaultSelectorBuilder<Product>();
selector.Add<Product>(t => new Product { Name = t.Name });

var query = catalogContext.Brands
.Where(t => keys.Contains(t.Id))
.Select(t => t.Id, t => t.Products, selector);

lock (queries)
{
queries.Add(query.ToQueryString());
}

var x = await query.ToDictionaryAsync(t => t.Key, t => t.Value.ToArray(), cancellationToken);

return x;
}
}
}

file static class Extensions
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Project_Key_To_Collection_Expression

```text
-- @__keys_0={ '1' } (DbType = Object)
SELECT b."Id", p."Name", p."Id"
FROM "Brands" AS b
LEFT JOIN "Products" AS p ON b."Id" = p."BrandId"
WHERE b."Id" = ANY (@__keys_0)
ORDER BY b."Id"
```
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public static async Task<IDictionary<int, Author>> GetAuthorById(
IQueryable<Author> query,
ISelectorBuilder selector,
CancellationToken ct)
=> await Task.FromResult(query.Select(selector, t => t.Id).ToDictionary(t => t.Id));
=> await Task.FromResult(query.Select(t => t.Id, selector).ToDictionary(t => t.Id));

[DataLoader]
public static async Task<IDictionary<int, Author>> GetAuthorWithPagingById(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1511,7 +1511,7 @@ protected override async Task<IReadOnlyDictionary<int, Page<Product>>> LoadBatch

return await catalogContext.Products
.Where(t => keys.Contains(t.BrandId))
.Select(context.GetSelector(), b => b.BrandId)
.Select(b => b.BrandId, context.GetSelector())
.OrderBy(t => t.Name).ThenBy(t => t.Id)
.ToBatchPageAsync(t => t.BrandId, pagingArgs, cancellationToken);
}
Expand Down

0 comments on commit f5fc2f1

Please sign in to comment.