Skip to content

Commit

Permalink
Added more include overloads. (#8002)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelstaib authored Feb 6, 2025
1 parent 77b8d12 commit 6878fe9
Show file tree
Hide file tree
Showing 14 changed files with 420 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -102,18 +102,15 @@ private static IDataLoader<TKey, Page<TValue>> WithInternal<TKey, TValue>(
/// <typeparam name="TValue">
/// The value type of the DataLoader.
/// </typeparam>
/// <typeparam name="TElement">
/// The element type of the projection.
/// </typeparam>
/// <returns>
/// Returns the DataLoader with the added projection.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if the <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
public static IDataLoader<TKey, Page<TValue>> Select<TElement, TKey, TValue>(
public static IDataLoader<TKey, Page<TValue>> Select<TKey, TValue>(
this IDataLoader<TKey, Page<TValue>> dataLoader,
Expression<Func<TElement, TElement>>? selector)
Expression<Func<TValue, TValue>>? selector)
where TKey : notnull
{
if (dataLoader is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ public static IDataLoader<TKey, List<TValue>> Select<TKey, TValue>(
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TResult">
/// The selector result type.
/// </typeparam>
/// <returns>
/// Returns the DataLoader with the property included.
/// </returns>
Expand All @@ -179,10 +182,144 @@ public static IDataLoader<TKey, List<TValue>> Select<TKey, TValue>(
/// <exception cref="ArgumentException">
/// Throws if the include selector is not a property selector.
/// </exception>
public static IDataLoader<TKey, TValue> Include<TKey, TValue>(
public static IDataLoader<TKey, TValue> Include<TKey, TValue, TResult>(
this IDataLoader<TKey, TValue> dataLoader,
Expression<Func<TValue, object?>> includeSelector)
Expression<Func<TValue, TResult>> includeSelector)
where TKey : notnull
{
AssertIncludePossible(dataLoader, includeSelector);

var context = dataLoader.GetOrSetState(
DataLoaderStateKeys.Selector,
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}

/// <summary>
/// Includes a property in the query.
/// </summary>
/// <param name="dataLoader">
/// The DataLoader to include the property in.
/// </param>
/// <param name="includeSelector">
/// The property selector.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TResult">
/// The selector result type.
/// </typeparam>
/// <returns>
/// Returns the DataLoader with the property included.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
/// <exception cref="ArgumentException">
/// Throws if the include selector is not a property selector.
/// </exception>
public static IDataLoader<TKey, Page<TValue>> Include<TKey, TValue, TResult>(
this IDataLoader<TKey, Page<TValue>> dataLoader,
Expression<Func<TValue, TResult>> includeSelector)
where TKey : notnull
{
AssertIncludePossible(dataLoader, includeSelector);

var context = dataLoader.GetOrSetState(
DataLoaderStateKeys.Selector,
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}

/// <summary>
/// Includes a property in the query.
/// </summary>
/// <param name="dataLoader">
/// The DataLoader to include the property in.
/// </param>
/// <param name="includeSelector">
/// The property selector.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TResult">
/// The selector result type.
/// </typeparam>
/// <returns>
/// Returns the DataLoader with the property included.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
/// <exception cref="ArgumentException">
/// Throws if the include selector is not a property selector.
/// </exception>
public static IDataLoader<TKey, List<TValue>> Include<TKey, TValue, TResult>(
this IDataLoader<TKey, List<TValue>> dataLoader,
Expression<Func<TValue, TResult>> includeSelector)
where TKey : notnull
{
AssertIncludePossible(dataLoader, includeSelector);

var context = dataLoader.GetOrSetState(
DataLoaderStateKeys.Selector,
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}

/// <summary>
/// Includes a property in the query.
/// </summary>
/// <param name="dataLoader">
/// The DataLoader to include the property in.
/// </param>
/// <param name="includeSelector">
/// The property selector.
/// </param>
/// <typeparam name="TKey">
/// The key type.
/// </typeparam>
/// <typeparam name="TValue">
/// The value type.
/// </typeparam>
/// <typeparam name="TResult">
/// The selector result type.
/// </typeparam>
/// <returns>
/// Returns the DataLoader with the property included.
/// </returns>
/// <exception cref="ArgumentNullException">
/// Throws if <paramref name="dataLoader"/> is <c>null</c>.
/// </exception>
/// <exception cref="ArgumentException">
/// Throws if the include selector is not a property selector.
/// </exception>
public static IDataLoader<TKey, TValue[]> Include<TKey, TValue, TResult>(
this IDataLoader<TKey, TValue[]> dataLoader,
Expression<Func<TValue, TResult>> includeSelector)
where TKey : notnull
{
AssertIncludePossible(dataLoader, includeSelector);

var context = dataLoader.GetOrSetState(
DataLoaderStateKeys.Selector,
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}

private static void AssertIncludePossible(IDataLoader dataLoader, Expression includeSelector)
{
if (dataLoader is null)
{
Expand Down Expand Up @@ -214,11 +351,5 @@ public static IDataLoader<TKey, TValue> Include<TKey, TValue>(
"The include selector must be a property selector.",
nameof(includeSelector));
}

var context = dataLoader.GetOrSetState(
DataLoaderStateKeys.Selector,
_ => new DefaultSelectorBuilder());
context.Add(Rewrite(includeSelector));
return dataLoader;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ private static Expression CombineConditionalAndMemberInit(
}

public static Expression<Func<TRoot, TRoot>> Rewrite<TRoot, TKey>(
Expression<Func<TRoot, TKey?>> selector)
Expression<Func<TRoot, TKey>> selector)
{
var parameter = selector.Parameters[0];
var bindings = new List<MemberBinding>();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Linq.Expressions;
using HotChocolate.Execution.Processing;

// ReSharper disable once CheckNamespace
Expand Down
144 changes: 144 additions & 0 deletions src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DataLoaderTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
using GreenDonut;
using GreenDonut.Data;
using HotChocolate.Data.Data;
using HotChocolate.Data.Migrations;
using HotChocolate.Data.Models;
using HotChocolate.Data.Services;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using Squadron;

namespace HotChocolate.Data;

[Collection(PostgresCacheCollectionFixture.DefinitionName)]
public class DataLoaderTests(PostgreSqlResource resource)
{
[Fact]
public async Task Include_On_List_Results()
{
// arrange
using var interceptor = new TestQueryInterceptor();
using var cts = new CancellationTokenSource(2000);
await using var services = CreateServer();
await using var scope = services.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<CatalogContext>();
var seeder = scope.ServiceProvider.GetRequiredService<IDbSeeder<CatalogContext>>();
await context.Database.EnsureCreatedAsync(cts.Token);
await seeder.SeedAsync(context);

var productByBrand = scope.ServiceProvider.GetRequiredService<IProductListByBrandDataLoader>();

// act
var products = await productByBrand
.Select(t => new Product { BrandId = t.BrandId, Name = t.Name })
.Include(t => t.Price)
.LoadRequiredAsync(1, cts.Token);

// assert
Assert.Equal(10, products.Count);
interceptor.MatchSnapshot();
}

[Fact]
public async Task Include_On_Array_Results()
{
// arrange
using var interceptor = new TestQueryInterceptor();
using var cts = new CancellationTokenSource(2000);
await using var services = CreateServer();
await using var scope = services.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<CatalogContext>();
var seeder = scope.ServiceProvider.GetRequiredService<IDbSeeder<CatalogContext>>();
await context.Database.EnsureCreatedAsync(cts.Token);
await seeder.SeedAsync(context);

var productByBrand = scope.ServiceProvider.GetRequiredService<IProductArrayByBrandDataLoader>();

// act
var products = await productByBrand
.Select(t => new Product { BrandId = t.BrandId, Name = t.Name })
.Include(t => t.Price)
.LoadRequiredAsync(1, cts.Token);

// assert
Assert.Equal(10, products.Length);
interceptor.MatchSnapshot();
}

[Fact]
public async Task Include_On_Page_Results()
{
// arrange
using var interceptor = new TestQueryInterceptor();
using var cts = new CancellationTokenSource(2000);
await using var services = CreateServer();
await using var scope = services.CreateAsyncScope();
await using var context = scope.ServiceProvider.GetRequiredService<CatalogContext>();
var seeder = scope.ServiceProvider.GetRequiredService<IDbSeeder<CatalogContext>>();
await context.Database.EnsureCreatedAsync(cts.Token);
await seeder.SeedAsync(context);

var productByBrand = scope.ServiceProvider.GetRequiredService<IProductsByBrandDataLoader>();

// act
var products = await productByBrand
.With(new PagingArguments { First = 5 })
.Select(t => new Product { BrandId = t.BrandId, Name = t.Name })
.Include(t => t.Price)
.LoadRequiredAsync(1, cts.Token);

// assert
Assert.Equal(5, products.Items.Length);
interceptor.MatchSnapshot();
}

private ServiceProvider CreateServer()
{
var db = "db_" + Guid.NewGuid().ToString("N");
var connectionString = resource.GetConnectionString(db);

var services = new ServiceCollection();

services
.AddLogging()
.AddDbContext<CatalogContext>(c => c.UseNpgsql(connectionString));

services
.AddSingleton<BrandService>()
.AddSingleton<ProductService>();

services
.AddGraphQLServer()
.AddCustomTypes()
.AddGlobalObjectIdentification()
.AddPagingArguments()
.AddFiltering()
.AddSorting()
.ModifyRequestOptions(o => o.IncludeExceptionDetails = true);

services.AddSingleton<IDbSeeder<CatalogContext>, CatalogContextSeed>();

return services.BuildServiceProvider();
}
}

file static class Extensions
{
public static void MatchSnapshot(
this TestQueryInterceptor queryInterceptor)
{
#if NET9_0_OR_GREATER
var snapshot = Snapshot.Create();
#else
var snapshot = Snapshot.Create(postFix: "_net_8_0");
#endif

for (var i = 0; i < queryInterceptor.Queries.Count; i++)
{
var sql = queryInterceptor.Queries[i];
snapshot.Add(sql, $"Query {i + 1}", MarkdownLanguages.Sql);
}

snapshot.MatchMarkdown();
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using GreenDonut.Data;
using HotChocolate.Data.Data;
using HotChocolate.Data.Migrations;
using HotChocolate.Data.Services;
Expand Down Expand Up @@ -301,17 +300,4 @@ private static void MatchSnapshot(

snapshot.MatchMarkdown();
}

private class TestQueryInterceptor : PagingQueryInterceptor
{
public List<string> Queries { get; } = new();

public override void OnBeforeExecute<T>(IQueryable<T> query)
{
lock(Queries)
{
Queries.Add(query.ToQueryString());
}
}
}
}
Loading

0 comments on commit 6878fe9

Please sign in to comment.