diff --git a/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutPaginationBatchingDataLoaderExtensions.cs b/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutPaginationBatchingDataLoaderExtensions.cs index 0844e9be0db..cb6b0aec82c 100644 --- a/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutPaginationBatchingDataLoaderExtensions.cs +++ b/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutPaginationBatchingDataLoaderExtensions.cs @@ -102,18 +102,15 @@ private static IDataLoader> WithInternal( /// /// The value type of the DataLoader. /// - /// - /// The element type of the projection. - /// /// /// Returns the DataLoader with the added projection. /// /// /// Throws if the is null. /// - public static IDataLoader> Select( + public static IDataLoader> Select( this IDataLoader> dataLoader, - Expression>? selector) + Expression>? selector) where TKey : notnull { if (dataLoader is null) diff --git a/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutSelectionDataLoaderExtensions.cs b/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutSelectionDataLoaderExtensions.cs index 9f8260fe7cb..fbe35e7c976 100644 --- a/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutSelectionDataLoaderExtensions.cs +++ b/src/GreenDonut/src/GreenDonut.Data/Extensions/GreenDonutSelectionDataLoaderExtensions.cs @@ -170,6 +170,9 @@ public static IDataLoader> Select( /// /// The value type. /// + /// + /// The selector result type. + /// /// /// Returns the DataLoader with the property included. /// @@ -179,10 +182,144 @@ public static IDataLoader> Select( /// /// Throws if the include selector is not a property selector. /// - public static IDataLoader Include( + public static IDataLoader Include( this IDataLoader dataLoader, - Expression> includeSelector) + Expression> includeSelector) + where TKey : notnull + { + AssertIncludePossible(dataLoader, includeSelector); + + var context = dataLoader.GetOrSetState( + DataLoaderStateKeys.Selector, + _ => new DefaultSelectorBuilder()); + context.Add(Rewrite(includeSelector)); + return dataLoader; + } + + /// + /// Includes a property in the query. + /// + /// + /// The DataLoader to include the property in. + /// + /// + /// The property selector. + /// + /// + /// The key type. + /// + /// + /// The value type. + /// + /// + /// The selector result type. + /// + /// + /// Returns the DataLoader with the property included. + /// + /// + /// Throws if is null. + /// + /// + /// Throws if the include selector is not a property selector. + /// + public static IDataLoader> Include( + this IDataLoader> dataLoader, + Expression> includeSelector) where TKey : notnull + { + AssertIncludePossible(dataLoader, includeSelector); + + var context = dataLoader.GetOrSetState( + DataLoaderStateKeys.Selector, + _ => new DefaultSelectorBuilder()); + context.Add(Rewrite(includeSelector)); + return dataLoader; + } + + /// + /// Includes a property in the query. + /// + /// + /// The DataLoader to include the property in. + /// + /// + /// The property selector. + /// + /// + /// The key type. + /// + /// + /// The value type. + /// + /// + /// The selector result type. + /// + /// + /// Returns the DataLoader with the property included. + /// + /// + /// Throws if is null. + /// + /// + /// Throws if the include selector is not a property selector. + /// + public static IDataLoader> Include( + this IDataLoader> dataLoader, + Expression> includeSelector) + where TKey : notnull + { + AssertIncludePossible(dataLoader, includeSelector); + + var context = dataLoader.GetOrSetState( + DataLoaderStateKeys.Selector, + _ => new DefaultSelectorBuilder()); + context.Add(Rewrite(includeSelector)); + return dataLoader; + } + + /// + /// Includes a property in the query. + /// + /// + /// The DataLoader to include the property in. + /// + /// + /// The property selector. + /// + /// + /// The key type. + /// + /// + /// The value type. + /// + /// + /// The selector result type. + /// + /// + /// Returns the DataLoader with the property included. + /// + /// + /// Throws if is null. + /// + /// + /// Throws if the include selector is not a property selector. + /// + public static IDataLoader Include( + this IDataLoader dataLoader, + Expression> includeSelector) + where TKey : notnull + { + AssertIncludePossible(dataLoader, includeSelector); + + var context = dataLoader.GetOrSetState( + DataLoaderStateKeys.Selector, + _ => new DefaultSelectorBuilder()); + context.Add(Rewrite(includeSelector)); + return dataLoader; + } + + private static void AssertIncludePossible(IDataLoader dataLoader, Expression includeSelector) { if (dataLoader is null) { @@ -214,11 +351,5 @@ public static IDataLoader Include( "The include selector must be a property selector.", nameof(includeSelector)); } - - var context = dataLoader.GetOrSetState( - DataLoaderStateKeys.Selector, - _ => new DefaultSelectorBuilder()); - context.Add(Rewrite(includeSelector)); - return dataLoader; } } diff --git a/src/GreenDonut/src/GreenDonut.Data/Internal/ExpressionHelpers.cs b/src/GreenDonut/src/GreenDonut.Data/Internal/ExpressionHelpers.cs index 0a065ca869e..71838d82bb5 100644 --- a/src/GreenDonut/src/GreenDonut.Data/Internal/ExpressionHelpers.cs +++ b/src/GreenDonut/src/GreenDonut.Data/Internal/ExpressionHelpers.cs @@ -180,7 +180,7 @@ private static Expression CombineConditionalAndMemberInit( } public static Expression> Rewrite( - Expression> selector) + Expression> selector) { var parameter = selector.Parameters[0]; var bindings = new List(); diff --git a/src/HotChocolate/Core/src/Execution.Projections/Extensions/HotChocolateExecutionDataLoaderExtensions.cs b/src/HotChocolate/Core/src/Execution.Projections/Extensions/HotChocolateExecutionDataLoaderExtensions.cs index eb1caaf00f8..77321628e47 100644 --- a/src/HotChocolate/Core/src/Execution.Projections/Extensions/HotChocolateExecutionDataLoaderExtensions.cs +++ b/src/HotChocolate/Core/src/Execution.Projections/Extensions/HotChocolateExecutionDataLoaderExtensions.cs @@ -1,3 +1,4 @@ +using System.Linq.Expressions; using HotChocolate.Execution.Processing; // ReSharper disable once CheckNamespace diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DataLoaderTests.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DataLoaderTests.cs new file mode 100644 index 00000000000..22eef667bb1 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/DataLoaderTests.cs @@ -0,0 +1,144 @@ +using GreenDonut; +using GreenDonut.Data; +using HotChocolate.Data.Data; +using HotChocolate.Data.Migrations; +using HotChocolate.Data.Models; +using HotChocolate.Data.Services; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; +using Squadron; + +namespace HotChocolate.Data; + +[Collection(PostgresCacheCollectionFixture.DefinitionName)] +public class DataLoaderTests(PostgreSqlResource resource) +{ + [Fact] + public async Task Include_On_List_Results() + { + // arrange + using var interceptor = new TestQueryInterceptor(); + using var cts = new CancellationTokenSource(2000); + await using var services = CreateServer(); + await using var scope = services.CreateAsyncScope(); + await using var context = scope.ServiceProvider.GetRequiredService(); + var seeder = scope.ServiceProvider.GetRequiredService>(); + await context.Database.EnsureCreatedAsync(cts.Token); + await seeder.SeedAsync(context); + + var productByBrand = scope.ServiceProvider.GetRequiredService(); + + // act + var products = await productByBrand + .Select(t => new Product { BrandId = t.BrandId, Name = t.Name }) + .Include(t => t.Price) + .LoadRequiredAsync(1, cts.Token); + + // assert + Assert.Equal(10, products.Count); + interceptor.MatchSnapshot(); + } + + [Fact] + public async Task Include_On_Array_Results() + { + // arrange + using var interceptor = new TestQueryInterceptor(); + using var cts = new CancellationTokenSource(2000); + await using var services = CreateServer(); + await using var scope = services.CreateAsyncScope(); + await using var context = scope.ServiceProvider.GetRequiredService(); + var seeder = scope.ServiceProvider.GetRequiredService>(); + await context.Database.EnsureCreatedAsync(cts.Token); + await seeder.SeedAsync(context); + + var productByBrand = scope.ServiceProvider.GetRequiredService(); + + // act + var products = await productByBrand + .Select(t => new Product { BrandId = t.BrandId, Name = t.Name }) + .Include(t => t.Price) + .LoadRequiredAsync(1, cts.Token); + + // assert + Assert.Equal(10, products.Length); + interceptor.MatchSnapshot(); + } + + [Fact] + public async Task Include_On_Page_Results() + { + // arrange + using var interceptor = new TestQueryInterceptor(); + using var cts = new CancellationTokenSource(2000); + await using var services = CreateServer(); + await using var scope = services.CreateAsyncScope(); + await using var context = scope.ServiceProvider.GetRequiredService(); + var seeder = scope.ServiceProvider.GetRequiredService>(); + await context.Database.EnsureCreatedAsync(cts.Token); + await seeder.SeedAsync(context); + + var productByBrand = scope.ServiceProvider.GetRequiredService(); + + // act + var products = await productByBrand + .With(new PagingArguments { First = 5 }) + .Select(t => new Product { BrandId = t.BrandId, Name = t.Name }) + .Include(t => t.Price) + .LoadRequiredAsync(1, cts.Token); + + // assert + Assert.Equal(5, products.Items.Length); + interceptor.MatchSnapshot(); + } + + private ServiceProvider CreateServer() + { + var db = "db_" + Guid.NewGuid().ToString("N"); + var connectionString = resource.GetConnectionString(db); + + var services = new ServiceCollection(); + + services + .AddLogging() + .AddDbContext(c => c.UseNpgsql(connectionString)); + + services + .AddSingleton() + .AddSingleton(); + + services + .AddGraphQLServer() + .AddCustomTypes() + .AddGlobalObjectIdentification() + .AddPagingArguments() + .AddFiltering() + .AddSorting() + .ModifyRequestOptions(o => o.IncludeExceptionDetails = true); + + services.AddSingleton, CatalogContextSeed>(); + + return services.BuildServiceProvider(); + } +} + +file static class Extensions +{ + public static void MatchSnapshot( + this TestQueryInterceptor queryInterceptor) + { +#if NET9_0_OR_GREATER + var snapshot = Snapshot.Create(); +#else + var snapshot = Snapshot.Create(postFix: "_net_8_0"); +#endif + + for (var i = 0; i < queryInterceptor.Queries.Count; i++) + { + var sql = queryInterceptor.Queries[i]; + snapshot.Add(sql, $"Query {i + 1}", MarkdownLanguages.Sql); + } + + snapshot.MatchMarkdown(); + } +} diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/IntegrationTests.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/IntegrationTests.cs index a55901513e5..ba3f4d67448 100644 --- a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/IntegrationTests.cs +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/IntegrationTests.cs @@ -1,4 +1,3 @@ -using GreenDonut.Data; using HotChocolate.Data.Data; using HotChocolate.Data.Migrations; using HotChocolate.Data.Services; @@ -301,17 +300,4 @@ private static void MatchSnapshot( snapshot.MatchMarkdown(); } - - private class TestQueryInterceptor : PagingQueryInterceptor - { - public List Queries { get; } = new(); - - public override void OnBeforeExecute(IQueryable query) - { - lock(Queries) - { - Queries.Add(query.ToQueryString()); - } - } - } } diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/Services/ProductDataLoader.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/Services/ProductDataLoader.cs index 667f1df4d55..0cada9cd787 100644 --- a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/Services/ProductDataLoader.cs +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/Services/ProductDataLoader.cs @@ -38,4 +38,40 @@ public static async Task>> GetProductsByBrandAsync .With(query, s => s.AddAscending(t => t.Id)) .ToBatchPageAsync(t => t.BrandId, pagingArgs, cancellationToken); } + + [DataLoader] + public static async Task>> GetProductListByBrandAsync( + IReadOnlyList brandIds, + QueryContext query, + CatalogContext context, + CancellationToken cancellationToken) + { + brandIds = brandIds.EnsureOrdered(); + + var queryable = context.Products + .Where(t => brandIds.Contains(t.BrandId)) + .With(query, s => s.AddAscending(t => t.Id)) + .GroupBy(t => t.BrandId); + PagingQueryInterceptor.Publish(queryable); + + return await queryable.ToDictionaryAsync(t => t.Key, t => t.ToList(), cancellationToken); + } + + [DataLoader] + public static async Task> GetProductArrayByBrandAsync( + IReadOnlyList brandIds, + QueryContext query, + CatalogContext context, + CancellationToken cancellationToken) + { + brandIds = brandIds.EnsureOrdered(); + + var queryable = context.Products + .Where(t => brandIds.Contains(t.BrandId)) + .With(query, s => s.AddAscending(t => t.Id)) + .GroupBy(t => t.BrandId); + PagingQueryInterceptor.Publish(queryable); + + return await queryable.ToDictionaryAsync(t => t.Key, t => t.ToArray(), cancellationToken); + } } diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/TestQueryInterceptor.cs b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/TestQueryInterceptor.cs new file mode 100644 index 00000000000..0a544b4e048 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/TestQueryInterceptor.cs @@ -0,0 +1,17 @@ +using GreenDonut.Data; +using Microsoft.EntityFrameworkCore; + +namespace HotChocolate.Data; + +public class TestQueryInterceptor : PagingQueryInterceptor +{ + public List Queries { get; } = new(); + + public override void OnBeforeExecute(IQueryable query) + { + lock(Queries) + { + Queries.Add(query.ToQueryString()); + } + } +} diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results.md new file mode 100644 index 00000000000..33964c97293 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results.md @@ -0,0 +1,9 @@ +# Include_On_Array_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT p."BrandId", p."Name", p."Price" +FROM "Products" AS p +WHERE p."BrandId" = ANY (@__brandIds_0) +ORDER BY p."BrandId" +``` diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results__net_8_0.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results__net_8_0.md new file mode 100644 index 00000000000..33964c97293 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Array_Results__net_8_0.md @@ -0,0 +1,9 @@ +# Include_On_Array_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT p."BrandId", p."Name", p."Price" +FROM "Products" AS p +WHERE p."BrandId" = ANY (@__brandIds_0) +ORDER BY p."BrandId" +``` diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results.md new file mode 100644 index 00000000000..fbdb8d3c7c6 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results.md @@ -0,0 +1,9 @@ +# Include_On_List_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT p."BrandId", p."Name", p."Price" +FROM "Products" AS p +WHERE p."BrandId" = ANY (@__brandIds_0) +ORDER BY p."BrandId" +``` diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results__net_8_0.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results__net_8_0.md new file mode 100644 index 00000000000..fbdb8d3c7c6 --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_List_Results__net_8_0.md @@ -0,0 +1,9 @@ +# Include_On_List_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT p."BrandId", p."Name", p."Price" +FROM "Products" AS p +WHERE p."BrandId" = ANY (@__brandIds_0) +ORDER BY p."BrandId" +``` diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results.md new file mode 100644 index 00000000000..8f0a1d4626e --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results.md @@ -0,0 +1,22 @@ +# Include_On_Page_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT p1."BrandId", p3."BrandId", p3."Name", p3."Price", p3."Id" +FROM ( + SELECT p."BrandId" + FROM "Products" AS p + WHERE p."BrandId" = ANY (@__brandIds_0) + GROUP BY p."BrandId" +) AS p1 +LEFT JOIN ( + SELECT p2."BrandId", p2."Name", p2."Price", p2."Id" + FROM ( + SELECT p0."BrandId", p0."Name", p0."Price", p0."Id", ROW_NUMBER() OVER(PARTITION BY p0."BrandId" ORDER BY p0."Id") AS row + FROM "Products" AS p0 + WHERE p0."BrandId" = ANY (@__brandIds_0) + ) AS p2 + WHERE p2.row <= 6 +) AS p3 ON p1."BrandId" = p3."BrandId" +ORDER BY p1."BrandId", p3."BrandId", p3."Id" +``` diff --git a/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results__net_8_0.md b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results__net_8_0.md new file mode 100644 index 00000000000..d9aa35ea43b --- /dev/null +++ b/src/HotChocolate/Data/test/Data.PostgreSQL.Tests/__snapshots__/DataLoaderTests.Include_On_Page_Results__net_8_0.md @@ -0,0 +1,22 @@ +# Include_On_Page_Results + +```sql +-- @__brandIds_0={ '1' } (DbType = Object) +SELECT t."BrandId", t0."BrandId", t0."Name", t0."Price", t0."Id" +FROM ( + SELECT p."BrandId" + FROM "Products" AS p + WHERE p."BrandId" = ANY (@__brandIds_0) + GROUP BY p."BrandId" +) AS t +LEFT JOIN ( + SELECT t1."BrandId", t1."Name", t1."Price", t1."Id" + FROM ( + SELECT p0."BrandId", p0."Name", p0."Price", p0."Id", ROW_NUMBER() OVER(PARTITION BY p0."BrandId" ORDER BY p0."Id") AS row + FROM "Products" AS p0 + WHERE p0."BrandId" = ANY (@__brandIds_0) + ) AS t1 + WHERE t1.row <= 6 +) AS t0 ON t."BrandId" = t0."BrandId" +ORDER BY t."BrandId", t0."BrandId", t0."Id" +```