Skip to content

Commit

Permalink
Added DataLoader group refinements. (#7529)
Browse files Browse the repository at this point in the history
(cherry picked from commit 3eac31b)
  • Loading branch information
michaelstaib committed Sep 29, 2024
1 parent b7c55f5 commit 0c9789a
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ public void WriteDataLoaderGroupClass(

_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.WriteLine();

_writer.WriteIndentedLine("public sealed class {0} : I{0}", groupClassName);
_writer.WriteIndentedLine("{");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ static void AddGroupNames(ImmutableHashSet<string>.Builder builder, IEnumerable<
{
if (IsDataLoaderGroupAttribute(attribute.AttributeClass))
{
foreach (var arg in attribute.ConstructorArguments.FirstOrDefault().Values)
var constructorArguments = attribute.ConstructorArguments;
if (constructorArguments.Length > 0)
{
if (arg.Value is string groupName)
foreach (var arg in constructorArguments[0].Values)
{
builder.Add(groupName);
if (arg.Value is string groupName)
{
builder.Add(groupName);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ public override bool Equals(SyntaxInfo obj)

private bool Equals(DataLoaderInfo other)
=> AttributeSyntax.IsEquivalentTo(other.AttributeSyntax)
&& MethodSyntax.IsEquivalentTo(other.MethodSyntax);
&& MethodSyntax.IsEquivalentTo(other.MethodSyntax)
&& Groups.SequenceEqual(other.Groups, StringComparer.Ordinal);

public override int GetHashCode()
=> HashCode.Combine(AttributeSyntax, MethodSyntax);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,66 @@ public class Entity
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_BatchDataLoader_With_Group_Only_On_Class_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate;
using GreenDonut;

namespace TestNamespace;

[DataLoaderGroup("Group1")]
internal static class TestClass
{
[DataLoader]
public static Task<IReadOnlyDictionary<int, Entity>> GetEntityByIdAsync(
IReadOnlyList<int> entityIds,
CancellationToken cancellationToken)
=> default!;
}

public class Entity
{
public int Id { get; set; }
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_BatchDataLoader_With_Group_Only_On_Method_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate;
using GreenDonut;

namespace TestNamespace;

internal static class TestClass
{
[DataLoaderGroup("Group1")]
[DataLoader]
public static Task<IReadOnlyDictionary<int, Entity>> GetEntityByIdAsync(
IReadOnlyList<int> entityIds,
CancellationToken cancellationToken)
=> default!;
}

public class Entity
{
public int Id { get; set; }
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_GroupedDataLoader_MatchesSnapshot()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ namespace TestNamespace
{
IEntityByIdDataLoader EntityById { get; }
}

public sealed class Group1 : IGroup1
{
private readonly IServiceProvider _services;
Expand Down Expand Up @@ -96,6 +97,7 @@ namespace TestNamespace
{
IEntityByIdDataLoader EntityById { get; }
}

public sealed class Group2 : IGroup2
{
private readonly IServiceProvider _services;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# GenerateSource_BatchDataLoader_With_Group_Only_On_Class_MatchesSnapshot

## GreenDonutDataLoader.735550c.g.cs

```csharp
// <auto-generated/>
#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection;
using GreenDonut;

namespace TestNamespace
{
public interface IEntityByIdDataLoader
: global::GreenDonut.IDataLoader<int, global::TestNamespace.Entity>
{
}

public sealed class EntityByIdDataLoader
: global::GreenDonut.DataLoaderBase<int, global::TestNamespace.Entity>
, IEntityByIdDataLoader
{
private readonly global::System.IServiceProvider _services;

public EntityByIdDataLoader(
global::System.IServiceProvider services,
global::GreenDonut.IBatchScheduler batchScheduler,
global::GreenDonut.DataLoaderOptions options)
: base(batchScheduler, options)
{
_services = services ??
throw new global::System.ArgumentNullException(nameof(services));
}

protected override async global::System.Threading.Tasks.ValueTask FetchAsync(
global::System.Collections.Generic.IReadOnlyList<int> keys,
global::System.Memory<GreenDonut.Result<global::TestNamespace.Entity?>> results,
global::GreenDonut.DataLoaderFetchContext<global::TestNamespace.Entity> context,
global::System.Threading.CancellationToken ct)
{
var temp = await TestNamespace.TestClass.GetEntityByIdAsync(keys, ct).ConfigureAwait(false);
CopyResults(keys, results.Span, temp);
}

private void CopyResults(
global::System.Collections.Generic.IReadOnlyList<int> keys,
global::System.Span<GreenDonut.Result<global::TestNamespace.Entity?>> results,
global::System.Collections.Generic.IReadOnlyDictionary<int, TestNamespace.Entity> resultMap)
{
for (var i = 0; i < keys.Count; i++)
{
var key = keys[i];
if (resultMap.TryGetValue(key, out var value))
{
results[i] = global::GreenDonut.Result<global::TestNamespace.Entity?>.Resolve(value);
}
else
{
results[i] = global::GreenDonut.Result<global::TestNamespace.Entity?>.Resolve(default(global::TestNamespace.Entity));
}
}
}
}
public interface IGroup1
{
IEntityByIdDataLoader EntityById { get; }
}

public sealed class Group1 : IGroup1
{
private readonly IServiceProvider _services;
private IEntityByIdDataLoader? _entityById;

public Group1(IServiceProvider services)
{
_services = services
?? throw new ArgumentNullException(nameof(services));
}
public IEntityByIdDataLoader EntityById
{
get
{
if (_entityById is null)
{
_entityById = _services.GetRequiredService<IEntityByIdDataLoader>();
}

return _entityById!;
}
}
}
}


```

## HotChocolateTypeModule.735550c.g.cs

```csharp
// <auto-generated/>
#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using HotChocolate;
using HotChocolate.Types;
using HotChocolate.Execution.Configuration;

namespace Microsoft.Extensions.DependencyInjection
{
public static partial class TestsTypesRequestExecutorBuilderExtensions
{
public static IRequestExecutorBuilder AddTestsTypes(this IRequestExecutorBuilder builder)
{
builder.AddDataLoader<global::TestNamespace.IEntityByIdDataLoader, global::TestNamespace.EntityByIdDataLoader>();
builder.Services.AddScoped<global::TestNamespace.IGroup1, global::TestNamespace.Group1>();
return builder;
}
}
}

```

Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# GenerateSource_BatchDataLoader_With_Group_Only_On_Method_MatchesSnapshot

## GreenDonutDataLoader.735550c.g.cs

```csharp
// <auto-generated/>
#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using Microsoft.Extensions.DependencyInjection;
using GreenDonut;

namespace TestNamespace
{
public interface IEntityByIdDataLoader
: global::GreenDonut.IDataLoader<int, global::TestNamespace.Entity>
{
}

public sealed class EntityByIdDataLoader
: global::GreenDonut.DataLoaderBase<int, global::TestNamespace.Entity>
, IEntityByIdDataLoader
{
private readonly global::System.IServiceProvider _services;

public EntityByIdDataLoader(
global::System.IServiceProvider services,
global::GreenDonut.IBatchScheduler batchScheduler,
global::GreenDonut.DataLoaderOptions options)
: base(batchScheduler, options)
{
_services = services ??
throw new global::System.ArgumentNullException(nameof(services));
}

protected override async global::System.Threading.Tasks.ValueTask FetchAsync(
global::System.Collections.Generic.IReadOnlyList<int> keys,
global::System.Memory<GreenDonut.Result<global::TestNamespace.Entity?>> results,
global::GreenDonut.DataLoaderFetchContext<global::TestNamespace.Entity> context,
global::System.Threading.CancellationToken ct)
{
var temp = await TestNamespace.TestClass.GetEntityByIdAsync(keys, ct).ConfigureAwait(false);
CopyResults(keys, results.Span, temp);
}

private void CopyResults(
global::System.Collections.Generic.IReadOnlyList<int> keys,
global::System.Span<GreenDonut.Result<global::TestNamespace.Entity?>> results,
global::System.Collections.Generic.IReadOnlyDictionary<int, TestNamespace.Entity> resultMap)
{
for (var i = 0; i < keys.Count; i++)
{
var key = keys[i];
if (resultMap.TryGetValue(key, out var value))
{
results[i] = global::GreenDonut.Result<global::TestNamespace.Entity?>.Resolve(value);
}
else
{
results[i] = global::GreenDonut.Result<global::TestNamespace.Entity?>.Resolve(default(global::TestNamespace.Entity));
}
}
}
}
public interface IGroup1
{
IEntityByIdDataLoader EntityById { get; }
}

public sealed class Group1 : IGroup1
{
private readonly IServiceProvider _services;
private IEntityByIdDataLoader? _entityById;

public Group1(IServiceProvider services)
{
_services = services
?? throw new ArgumentNullException(nameof(services));
}
public IEntityByIdDataLoader EntityById
{
get
{
if (_entityById is null)
{
_entityById = _services.GetRequiredService<IEntityByIdDataLoader>();
}

return _entityById!;
}
}
}
}


```

## HotChocolateTypeModule.735550c.g.cs

```csharp
// <auto-generated/>
#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using HotChocolate;
using HotChocolate.Types;
using HotChocolate.Execution.Configuration;

namespace Microsoft.Extensions.DependencyInjection
{
public static partial class TestsTypesRequestExecutorBuilderExtensions
{
public static IRequestExecutorBuilder AddTestsTypes(this IRequestExecutorBuilder builder)
{
builder.AddDataLoader<global::TestNamespace.IEntityByIdDataLoader, global::TestNamespace.EntityByIdDataLoader>();
builder.Services.AddScoped<global::TestNamespace.IGroup1, global::TestNamespace.Group1>();
return builder;
}
}
}

```

0 comments on commit 0c9789a

Please sign in to comment.