Skip to content

Commit

Permalink
Added support for error filters to Fusion (#7190)
Browse files Browse the repository at this point in the history
  • Loading branch information
tobias-tengler authored Aug 2, 2024
1 parent d29f583 commit 044aabe
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using HotChocolate;
using HotChocolate.Execution;
using HotChocolate.Execution.Configuration;
using HotChocolate.Execution.Errors;
using HotChocolate.Execution.Options;
using HotChocolate.Execution.Pipeline;
using HotChocolate.Fusion;
Expand Down Expand Up @@ -343,6 +344,66 @@ public static FusionGatewayBuilder ModifyFusionOptions(
return builder;
}

public static FusionGatewayBuilder AddErrorFilter(
this FusionGatewayBuilder builder,
Func<IError, IError> errorFilter)
{
if (builder is null)
{
throw new ArgumentNullException(nameof(builder));
}

if (errorFilter is null)
{
throw new ArgumentNullException(nameof(errorFilter));
}

builder.CoreBuilder.ConfigureSchemaServices(
s => s.AddSingleton<IErrorFilter>(
new FuncErrorFilterWrapper(errorFilter)));

return builder;
}

public static FusionGatewayBuilder AddErrorFilter<T>(
this FusionGatewayBuilder builder,
Func<IServiceProvider, T> factory)
where T : class, IErrorFilter
{
if (builder is null)
{
throw new ArgumentNullException(nameof(builder));
}

if (factory is null)
{
throw new ArgumentNullException(nameof(factory));
}

builder.CoreBuilder.ConfigureSchemaServices(
s => s.AddSingleton<IErrorFilter, T>(
sp => factory(sp.GetCombinedServices())));

return builder;
}

public static FusionGatewayBuilder AddErrorFilter<T>(
this FusionGatewayBuilder builder)
where T : class, IErrorFilter
{
if (builder is null)
{
throw new ArgumentNullException(nameof(builder));
}

builder.Services.TryAddSingleton<T>();
builder.CoreBuilder.ConfigureSchemaServices(
s => s.AddSingleton<IErrorFilter, T>(
sp => sp.GetApplicationService<T>()));

return builder;
}

/// <summary>
/// Uses the default fusion gateway pipeline.
/// </summary>
Expand Down
8 changes: 6 additions & 2 deletions src/HotChocolate/Fusion/src/Core/Execution/ExecutionUtils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ public static void ExtractErrors(
DocumentNode document,
OperationDefinitionNode operation,
ResultBuilder resultBuilder,
IErrorHandler errorHandler,
JsonElement errors,
ObjectResult selectionSetResult,
int pathDepth,
Expand All @@ -505,14 +506,15 @@ public static void ExtractErrors(
var path = PathHelper.CreatePathFromContext(selectionSetResult);
foreach (var error in errors.EnumerateArray())
{
ExtractError(document, operation, resultBuilder, error, path, pathDepth, addDebugInfo);
ExtractError(document, operation, resultBuilder, errorHandler, error, path, pathDepth, addDebugInfo);
}
}

private static void ExtractError(
DocumentNode document,
OperationDefinitionNode operation,
ResultBuilder resultBuilder,
IErrorHandler errorHandler,
JsonElement error,
Path parentPath,
int pathDepth,
Expand Down Expand Up @@ -577,7 +579,9 @@ private static void ExtractError(
errorBuilder.AddLocation(field);
}

resultBuilder.AddError(errorBuilder.Build());
var handledError = errorHandler.Handle(errorBuilder.Build());

resultBuilder.AddError(handledError);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,17 @@ public FusionExecutionContext(
INodeIdSerializer idSerializer,
NodeIdParser nodeIdParser,
FusionOptions options,
IFusionDiagnosticEvents diagnosticEvents)
IFusionDiagnosticEvents diagnosticEvents,
IErrorHandler errorHandler)
{
Configuration = configuration ??
throw new ArgumentNullException(nameof(configuration));
QueryPlan = queryPlan ??
throw new ArgumentNullException(nameof(queryPlan));
DiagnosticEvents = diagnosticEvents ??
throw new ArgumentNullException(nameof(diagnosticEvents));
ErrorHandler = errorHandler ??
throw new ArgumentNullException(nameof(errorHandler));
_operationContextOwner = operationContextOwner ??
throw new ArgumentNullException(nameof(operationContextOwner));
_clientFactory = clientFactory ??
Expand All @@ -48,6 +51,8 @@ public FusionExecutionContext(
throw new ArgumentNullException(nameof(options));
}

public IErrorHandler ErrorHandler { get; }

/// <summary>
/// Gets the schema that is being executed on.
/// </summary>
Expand Down Expand Up @@ -177,5 +182,6 @@ public static FusionExecutionContext CreateFrom(
context._idSerializer,
context._nodeIdParser,
context._options,
context.DiagnosticEvents);
context.DiagnosticEvents,
context.ErrorHandler);
}
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ public async Task<IOperationResult> ExecuteAsync(

if (context.Result.Errors.Count == 0)
{
var errorHandler = context.OperationContext.ErrorHandler;
var errorHandler = context.ErrorHandler;
var error = errorHandler.CreateUnexpectedError(ex).Build();
error = errorHandler.Handle(error);
context.Result.AddError(error);
Expand Down
8 changes: 6 additions & 2 deletions src/HotChocolate/Fusion/src/Core/Execution/Nodes/Resolve.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,11 @@ protected override async Task OnExecuteAsync(
catch (Exception ex)
{
context.DiagnosticEvents.ResolveError(ex);
var error = context.OperationContext.ErrorHandler.CreateUnexpectedError(ex);
context.Result.AddError(error.Build());

var errorHandler = context.ErrorHandler;
var error = errorHandler.CreateUnexpectedError(ex).Build();
error = errorHandler.Handle(error);
context.Result.AddError(error);
}
}

Expand Down Expand Up @@ -149,6 +152,7 @@ private void ProcessResponses(
context.Operation.Document,
context.Operation.Definition,
context.Result,
context.ErrorHandler,
response.Errors,
selectionSetResult,
pathLength,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,11 @@ protected override async Task OnExecuteAsync(
catch (Exception ex)
{
context.DiagnosticEvents.ResolveByKeyBatchError(ex);
var error = context.OperationContext.ErrorHandler.CreateUnexpectedError(ex);
context.Result.AddError(error.Build());

var errorHandler = context.ErrorHandler;
var error = errorHandler.CreateUnexpectedError(ex).Build();
error = errorHandler.Handle(error);
context.Result.AddError(error);
}
}

Expand Down Expand Up @@ -155,6 +158,7 @@ private void ProcessResult(
context.Operation.Document,
context.Operation.Definition,
context.Result,
context.ErrorHandler,
response.Errors,
batchState.SelectionSetResult,
pathLength + 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ internal sealed class DistributedOperationExecutionMiddleware(
[SchemaService] GraphQLClientFactory clientFactory,
[SchemaService] NodeIdParser nodeIdParser,
[SchemaService] FusionOptions options,
[SchemaService] IFusionDiagnosticEvents diagnosticEvents)
[SchemaService] IFusionDiagnosticEvents diagnosticEvents,
[SchemaService] IErrorHandler errorHandler)
{
private static readonly object _queryRoot = new();
private static readonly object _mutationRoot = new();
Expand All @@ -44,6 +45,8 @@ internal sealed class DistributedOperationExecutionMiddleware(
?? throw new ArgumentNullException(nameof(options));
private readonly IFusionDiagnosticEvents _diagnosticEvents = diagnosticEvents
?? throw new ArgumentNullException(nameof(diagnosticEvents));
private readonly IErrorHandler _errorHandler = errorHandler
?? throw new ArgumentNullException(nameof(errorHandler));

public async ValueTask InvokeAsync(
IRequestContext context,
Expand Down Expand Up @@ -78,7 +81,8 @@ context.Variables is not null &&
_nodeIdSerializer,
_nodeIdParser,
_fusionOptionsAccessor,
diagnosticEvents);
_diagnosticEvents,
_errorHandler);

using (federatedQueryContext.DiagnosticEvents.ExecuteFederatedQuery(context))
{
Expand Down Expand Up @@ -119,6 +123,7 @@ public static RequestCoreMiddleware Create()
var nodeIdParser = core.SchemaServices.GetRequiredService<NodeIdParser>();
var fusionOptionsAccessor = core.SchemaServices.GetRequiredService<FusionOptions>();
var diagnosticEvents = core.SchemaServices.GetRequiredService<IFusionDiagnosticEvents>();
var errorHandler = core.SchemaServices.GetRequiredService<IErrorHandler>();
var middleware = new DistributedOperationExecutionMiddleware(
next,
contextFactory,
Expand All @@ -127,7 +132,8 @@ public static RequestCoreMiddleware Create()
clientFactory,
nodeIdParser,
fusionOptionsAccessor,
diagnosticEvents);
diagnosticEvents,
errorHandler);
return async context =>
{
var batchDispatcher = context.Services.GetRequiredService<IBatchDispatcher>();
Expand Down
53 changes: 53 additions & 0 deletions src/HotChocolate/Fusion/test/Core.Tests/SubgraphErrorTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using CookieCrumble;
using HotChocolate.Fusion.Shared;
using HotChocolate.Execution;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -2918,4 +2919,56 @@ type Product {
}

#endregion

[Fact]
public async Task ErrorFilter_Is_Applied()
{
// arrange
var subgraph = await TestSubgraph.CreateAsync(
"""
type Query {
field: String @error
}
"""
);

using var subgraphs = new TestSubgraphCollection(output, [subgraph]);
var executor = await subgraphs.GetExecutorAsync(
configureBuilder: builder =>
builder.AddErrorFilter(error => error.WithMessage("REPLACED MESSAGE").WithCode("CUSTOM_CODE")));
var request = """
query {
field
}
""";

// act
var result = await executor.ExecuteAsync(request);

// assert
result.MatchInlineSnapshot("""
{
"errors": [
{
"message": "REPLACED MESSAGE",
"locations": [
{
"line": 2,
"column": 3
}
],
"path": [
"field"
],
"extensions": {
"code": "CUSTOM_CODE"
}
}
],
"data": {
"field": null
}
}
""");
}
}
37 changes: 22 additions & 15 deletions src/HotChocolate/Fusion/test/Shared/TestSubgraphCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,40 @@ public IHttpClientFactory GetHttpClientFactory()
return new TestSubgraphCollectionHttpClientFactory(subgraphsDictionary);
}

public async Task<IRequestExecutor> GetExecutorAsync(FusionFeatureCollection? features = null)
public async Task<IRequestExecutor> GetExecutorAsync(
FusionFeatureCollection? features = null,
Action<FusionGatewayBuilder>? configureBuilder = null)
{
var fusionGraph = await ComposeFusionGraphAsync(features);

return await GetExecutorAsync(fusionGraph);
return await GetExecutorAsync(fusionGraph, configureBuilder);
}

public async Task<IRequestExecutor> GetExecutorAsync(Skimmed.SchemaDefinition fusionGraph)
public void Dispose()
{
foreach (var subgraph in subgraphs)
{
subgraph.TestServer.Dispose();
}
}

private async Task<IRequestExecutor> GetExecutorAsync(
Skimmed.SchemaDefinition fusionGraph,
Action<FusionGatewayBuilder>? configureBuilder = null)
{
var httpClientFactory = GetHttpClientFactory();

return await new ServiceCollection()
var builder = new ServiceCollection()
.AddSingleton(httpClientFactory)
.AddFusionGatewayServer()
.ConfigureFromDocument(SchemaFormatter.FormatAsDocument(fusionGraph))
.BuildRequestExecutorAsync();
.ConfigureFromDocument(SchemaFormatter.FormatAsDocument(fusionGraph));

configureBuilder?.Invoke(builder);

return await builder.BuildRequestExecutorAsync();
}

public async Task<Skimmed.SchemaDefinition> ComposeFusionGraphAsync(FusionFeatureCollection? features = null)
private async Task<Skimmed.SchemaDefinition> ComposeFusionGraphAsync(FusionFeatureCollection? features = null)
{
features ??= new FusionFeatureCollection(FusionFeatures.NodeField);

Expand All @@ -59,14 +74,6 @@ public async Task<IRequestExecutor> GetExecutorAsync(Skimmed.SchemaDefinition fu
.ComposeAsync(configurations, features);
}

public void Dispose()
{
foreach (var subgraph in subgraphs)
{
subgraph.TestServer.Dispose();
}
}

private IEnumerable<(string SubgraphName, TestSubgraph Subgraph)> GetSubgraphs()
=> subgraphs.Select((s, i) => ($"Subgraph_{++i}", s));

Expand Down

0 comments on commit 044aabe

Please sign in to comment.