From a921397e480d09062d214d3e1e15b135419e4992 Mon Sep 17 00:00:00 2001 From: Hafiz Ismail Date: Thu, 10 Mar 2016 22:17:51 +0800 Subject: [PATCH] [Validation] Parallelize validation rules. This provides a performance improvement and simplification to the validator by providing two new generic visitor utilities. One for tracking a TypeInfo instance alongside a visitor instance, and another for stepping through multiple visitors in parallel. The two can be composed together. Rather than 23 passes of AST visitation with one rule each, this now performs one pass of AST visitation with 23 rules. Since visitation is costly but rules are inexpensive, this nets out to a much faster overall validation, especially noticable for very large queries. Commit: 957704188b0a103c5f2fe0ab99479267d5d1ae43 [9577041] Parents: 439a3e2f4f Author: Lee Byron Date: 17 November 2015 at 12:33:54 PM SGT ---- [Validation] Memoize collecting variable usage. During multiple validation passes we need to know about variable usage within a de-fragmented operation. Memoizing this ensures each pass is O(N) - each fragment is no longer visited per operation, but once total. In doing so, `visitSpreadFragments` is no longer used, which will be cleaned up in a later PR Commit: 2afbff79bfd2b89f03ca7913577556b73980f974 [2afbff7] Parents: 88acc01b99 Author: Lee Byron Date: 17 November 2015 at 9:54:30 AM SGT --- language/ast/selections.go | 12 + language/type_info/type_info.go | 14 ++ language/visitor/visitor.go | 177 ++++++++++++++- rules.go | 318 ++++++++++++--------------- rules_no_undefined_variables_test.go | 10 +- type_info.go | 28 ++- validator.go | 178 ++++++++++++++- 7 files changed, 536 insertions(+), 201 deletions(-) create mode 100644 language/type_info/type_info.go diff --git a/language/ast/selections.go b/language/ast/selections.go index 1b7e60d2..dd36cf26 100644 --- a/language/ast/selections.go +++ b/language/ast/selections.go @@ -46,6 +46,10 @@ func (f *Field) GetLoc() *Location { return f.Loc } +func (f *Field) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // FragmentSpread implements Node, Selection type FragmentSpread struct { Kind string @@ -74,6 +78,10 @@ func (fs *FragmentSpread) GetLoc() *Location { return fs.Loc } +func (fs *FragmentSpread) GetSelectionSet() *SelectionSet { + return nil +} + // InlineFragment implements Node, Selection type InlineFragment struct { Kind string @@ -104,6 +112,10 @@ func (f *InlineFragment) GetLoc() *Location { return f.Loc } +func (f *InlineFragment) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // SelectionSet implements Node type SelectionSet struct { Kind string diff --git a/language/type_info/type_info.go b/language/type_info/type_info.go new file mode 100644 index 00000000..02b7b04f --- /dev/null +++ b/language/type_info/type_info.go @@ -0,0 +1,14 @@ +package type_info + +import ( + "github.com/graphql-go/graphql/language/ast" +) + +/** + * TypeInfoI defines the interface for TypeInfo + * Implementation + */ +type TypeInfoI interface { + Enter(node ast.Node) + Leave(node ast.Node) +} diff --git a/language/visitor/visitor.go b/language/visitor/visitor.go index 83edbd9b..3188efec 100644 --- a/language/visitor/visitor.go +++ b/language/visitor/visitor.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/type_info" "reflect" ) @@ -380,7 +381,7 @@ Loop: kind = node.GetKind() } - visitFn := GetVisitFn(visitorOpts, isLeaving, kind) + visitFn := GetVisitFn(visitorOpts, kind, isLeaving) if visitFn != nil { p := VisitFuncParams{ Node: nodeIn, @@ -709,7 +710,144 @@ func isNilNode(node interface{}) bool { return val.Interface() == nil } -func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitFunc { +/** + * Creates a new visitor instance which delegates to many visitors to run in + * parallel. Each visitor will be visited for each node before moving on. + * + * Visitors must not directly modify the AST nodes and only returning false to + * skip sub-branches is supported. + */ +func VisitInParallel(visitorOptsSlice []*VisitorOptions) *VisitorOptions { + skipping := map[int]interface{}{} + + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + if _, ok := skipping[i]; !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, false) + if fn != nil { + action, _ := fn(p) + if action == ActionSkip { + skipping[i] = node + } + } + } + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + if _, ok := skipping[i]; !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, true) + if fn != nil { + fn(p) + } + } + } else { + delete(skipping, i) + } + } + return ActionNoChange, nil + }, + } +} + +/** + * Creates a new visitor instance which maintains a provided TypeInfo instance + * along with visiting visitor. + * + * Visitors must not directly modify the AST nodes and only returning false to + * skip sub-branches is supported. + */ +func VisitWithTypeInfo(typeInfo type_info.TypeInfoI, visitorOpts *VisitorOptions) *VisitorOptions { + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(ast.Node); ok { + typeInfo.Enter(node) + fn := GetVisitFn(visitorOpts, node.GetKind(), false) + if fn != nil { + action, _ := fn(p) + if action == ActionSkip { + typeInfo.Leave(node) + return ActionSkip, nil + } + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(ast.Node); ok { + fn := GetVisitFn(visitorOpts, node.GetKind(), true) + if fn != nil { + fn(p) + } + typeInfo.Leave(node) + } + return ActionNoChange, nil + }, + } +} + +/** + * Given a visitor instance, if it is leaving or not, and a node kind, return + * the function the visitor runtime should call. + */ +func GetVisitFn(visitorOpts *VisitorOptions, kind string, isLeaving bool) VisitFunc { + if visitorOpts == nil { + return nil + } + kindVisitor, ok := visitorOpts.KindFuncMap[kind] + if ok { + if !isLeaving && kindVisitor.Kind != nil { + // { Kind() {} } + return kindVisitor.Kind + } + if isLeaving { + // { Kind: { leave() {} } } + return kindVisitor.Leave + } else { + // { Kind: { enter() {} } } + return kindVisitor.Enter + } + } else { + + if isLeaving { + // { enter() {} } + specificVisitor := visitorOpts.Leave + if specificVisitor != nil { + return specificVisitor + } + if specificKindVisitor, ok := visitorOpts.LeaveKindMap[kind]; ok { + // { leave: { Kind() {} } } + return specificKindVisitor + } + + } else { + // { leave() {} } + specificVisitor := visitorOpts.Enter + if specificVisitor != nil { + return specificVisitor + } + if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { + // { enter: { Kind() {} } } + return specificKindVisitor + } + } + } + + return nil +} + +///// DELETE //// + +func GetVisitFnOld(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitFunc { if visitorOpts == nil { return nil } @@ -753,3 +891,38 @@ func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitF return nil } + +/* + + +export function getVisitFn(visitor, isLeaving, kind) { + var kindVisitor = visitor[kind]; + if (kindVisitor) { + if (!isLeaving && typeof kindVisitor === 'function') { + // { Kind() {} } + return kindVisitor; + } + var kindSpecificVisitor = isLeaving ? kindVisitor.leave : kindVisitor.enter; + if (typeof kindSpecificVisitor === 'function') { + // { Kind: { enter() {}, leave() {} } } + return kindSpecificVisitor; + } + return; + } + var specificVisitor = isLeaving ? visitor.leave : visitor.enter; + if (specificVisitor) { + if (typeof specificVisitor === 'function') { + // { enter() {}, leave() {} } + return specificVisitor; + } + var specificKindVisitor = specificVisitor[kind]; + if (typeof specificKindVisitor === 'function') { + // { enter: { Kind() {} }, leave: { Kind() {} } } + return specificKindVisitor; + } + } +} + + + +*/ diff --git a/rules.go b/rules.go index c1fd9a8f..d46520a8 100644 --- a/rules.go +++ b/rules.go @@ -57,7 +57,7 @@ func newValidationError(message string, nodes []ast.Node) *gqlerrors.Error { ) } -func reportErrorAndReturn(context *ValidationContext, message string, nodes []ast.Node) (string, interface{}) { +func reportError(context *ValidationContext, message string, nodes []ast.Node) (string, interface{}) { context.ReportError(newValidationError(message, nodes)) return visitor.ActionNoChange, nil } @@ -84,7 +84,7 @@ func ArgumentsOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInsta if argAST.Name != nil { argNameValue = argAST.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Argument "%v" expected type "%v" but got: %v.`, argNameValue, argDef.Type, printer.Print(value)), @@ -125,7 +125,7 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI ttype := context.InputType() if ttype, ok := ttype.(*NonNull); ok && defaultValue != nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" of type "%v" is required and will not use the default value. Perhaps you meant to use type "%v".`, name, ttype, ttype.OfType), @@ -133,7 +133,7 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI ) } if ttype != nil && defaultValue != nil && !isValidLiteralValue(ttype, defaultValue) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" of type "%v" has invalid default value: %v.`, name, ttype, printer.Print(defaultValue)), @@ -175,7 +175,7 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance if node.Name != nil { nodeName = node.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Cannot query field "%v" on "%v".`, nodeName, ttype.Name()), @@ -210,7 +210,7 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn if node, ok := p.Node.(*ast.InlineFragment); ok { ttype := context.Type() if ttype != nil && !IsCompositeType(ttype) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment cannot condition on non composite type "%v".`, ttype), []ast.Node{node.TypeCondition}, @@ -229,7 +229,7 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn if node.Name != nil { nodeName = node.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment "%v" cannot condition on non composite type "%v".`, nodeName, printer.Print(node.TypeCondition)), []ast.Node{node.TypeCondition}, @@ -289,7 +289,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if parentType != nil { parentTypeName = parentType.Name() } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown argument "%v" on field "%v" of type "%v".`, nodeName, fieldDef.Name, parentTypeName), []ast.Node{node}, @@ -311,7 +311,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } } if directiveArgDef == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown argument "%v" on directive "@%v".`, nodeName, directive.Name), []ast.Node{node}, @@ -357,7 +357,7 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } } if directiveDef == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown directive "%v".`, nodeName), []ast.Node{node}, @@ -373,14 +373,14 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } if appliedTo.GetKind() == kinds.OperationDefinition && directiveDef.OnOperation == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "operation"), []ast.Node{node}, ) } if appliedTo.GetKind() == kinds.Field && directiveDef.OnField == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "field"), []ast.Node{node}, @@ -389,7 +389,7 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { if (appliedTo.GetKind() == kinds.FragmentSpread || appliedTo.GetKind() == kinds.InlineFragment || appliedTo.GetKind() == kinds.FragmentDefinition) && directiveDef.OnFragment == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "fragment"), []ast.Node{node}, @@ -430,7 +430,7 @@ func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragment := context.Fragment(fragmentName) if fragment == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown fragment "%v".`, fragmentName), []ast.Node{node.Name}, @@ -467,7 +467,7 @@ func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { } ttype := context.Schema().Type(typeNameValue) if ttype == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown type "%v".`, typeNameValue), []ast.Node{node}, @@ -512,7 +512,7 @@ func LoneAnonymousOperationRule(context *ValidationContext) *ValidationRuleInsta Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.OperationDefinition); ok { if node.Name == nil && operationCount > 1 { - return reportErrorAndReturn( + return reportError( context, `This anonymous operation must be the only defined operation.`, []ast.Node{node}, @@ -613,11 +613,11 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { if len(spreadNames) > 0 { via = " via " + strings.Join(spreadNames, ", ") } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Cannot spread fragment "%v" within itself%v.`, initialName, via), cyclePath, ) - context.ReportError(err) continue } spreadPathHasCurrentNode := false @@ -654,77 +654,64 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { * and via fragment spreads, are defined by that operation. */ func NoUndefinedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var operation *ast.OperationDefinition - var visitedFragmentNames = map[string]bool{} - var definedVariableNames = map[string]bool{} + var variableNameDefined = map[string]bool{} + visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.OperationDefinition); ok && node != nil { - operation = node - visitedFragmentNames = map[string]bool{} - definedVariableNames = map[string]bool{} - } - return visitor.ActionNoChange, nil - }, - }, - kinds.VariableDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { - variableName := "" - if node.Variable != nil && node.Variable.Name != nil { - variableName = node.Variable.Name.Value - } - definedVariableNames[variableName] = true - } + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + variableNameDefined = map[string]bool{} return visitor.ActionNoChange, nil }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - variableName := "" - if variable.Name != nil { - variableName = variable.Name.Value - } - if val, _ := definedVariableNames[variableName]; !val { - withinFragment := false - for _, node := range p.Ancestors { - if node.GetKind() == kinds.FragmentDefinition { - withinFragment = true - break - } + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + if usage == nil { + continue } - if withinFragment == true && operation != nil && operation.Name != nil { - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, variableName, operation.Name.Value), - []ast.Node{variable, operation}, - ) + if usage.Node == nil { + continue + } + varName := "" + if usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + opName := "" + if operation.Name != nil { + opName = operation.Name.Value + } + if res, ok := variableNameDefined[varName]; !ok || !res { + if opName != "" { + reportError( + context, + fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, varName, opName), + []ast.Node{usage.Node, operation}, + ) + } else { + + reportError( + context, + fmt.Sprintf(`Variable "$%v" is not defined.`, varName), + []ast.Node{usage.Node, operation}, + ) + } } - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" is not defined.`, variableName), - []ast.Node{variable}, - ) } } return visitor.ActionNoChange, nil }, }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.FragmentSpread); ok && node != nil { - // Only visit fragments of a particular name once per operation - fragmentName := "" - if node.Name != nil { - fragmentName = node.Name.Value - } - if val, ok := visitedFragmentNames[fragmentName]; ok && val == true { - return visitor.ActionSkip, nil + if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { + variableName := "" + if node.Variable != nil && node.Variable.Name != nil { + variableName = node.Variable.Name.Value } - visitedFragmentNames[fragmentName] = true + // definedVariableNames[variableName] = true + variableNameDefined[variableName] = true } return visitor.ActionNoChange, nil }, @@ -817,11 +804,11 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { isFragNameUsed, ok := fragmentNameUsed[defName] if !ok || isFragNameUsed != true { - err := newValidationError( + reportError( + context, fmt.Sprintf(`Fragment "%v" is never used.`, defName), []ast.Node{def}, ) - context.ReportError(err) } } return visitor.ActionNoChange, nil @@ -843,33 +830,45 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { */ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var visitedFragmentNames = map[string]bool{} var variableDefs = []*ast.VariableDefinition{} - var variableNameUsed = map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ Enter: func(p visitor.VisitFuncParams) (string, interface{}) { - visitedFragmentNames = map[string]bool{} variableDefs = []*ast.VariableDefinition{} - variableNameUsed = map[string]bool{} return visitor.ActionNoChange, nil }, Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - for _, def := range variableDefs { - variableName := "" - if def.Variable != nil && def.Variable.Name != nil { - variableName = def.Variable.Name.Value + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + variableNameUsed := map[string]bool{} + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + if varName != "" { + variableNameUsed[varName] = true + } } - if isVariableNameUsed, _ := variableNameUsed[variableName]; isVariableNameUsed != true { - err := newValidationError( - fmt.Sprintf(`Variable "$%v" is never used.`, variableName), - []ast.Node{def}, - ) - context.ReportError(err) + for _, variableDef := range variableDefs { + variableName := "" + if variableDef != nil && variableDef.Variable != nil && variableDef.Variable.Name != nil { + variableName = variableDef.Variable.Name.Value + } + if res, ok := variableNameUsed[variableName]; !ok || !res { + reportError( + context, + fmt.Sprintf(`Variable "$%v" is never used.`, variableName), + []ast.Node{variableDef}, + ) + } } + } + return visitor.ActionNoChange, nil }, }, @@ -878,33 +877,6 @@ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { if def, ok := p.Node.(*ast.VariableDefinition); ok && def != nil { variableDefs = append(variableDefs, def) } - // Do not visit deeper, or else the defined variable name will be visited. - return visitor.ActionSkip, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - if variable.Name != nil { - variableNameUsed[variable.Name.Value] = true - } - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok && spreadAST != nil { - // Only visit fragments of a particular name once per operation - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisitedFragmentNames, _ := visitedFragmentNames[spreadName]; hasVisitedFragmentNames == true { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } return visitor.ActionNoChange, nil }, }, @@ -1301,7 +1273,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul for _, c := range conflicts { responseName := c.Reason.Name reason := c.Reason - err := newValidationError( + reportError( + context, fmt.Sprintf( `Fields "%v" conflict because %v.`, responseName, @@ -1309,7 +1282,6 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul ), c.Fields, ) - context.ReportError(err) } return visitor.ActionNoChange, nil } @@ -1394,7 +1366,7 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, parentType, fragType), @@ -1415,7 +1387,7 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst fragType := getFragmentType(context, fragName) parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment "%v" cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, fragName, parentType, fragType), @@ -1471,12 +1443,12 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if fieldAST.Name != nil { fieldName = fieldAST.Name.Value } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Field "%v" argument "%v" of type "%v" `+ `is required but not provided.`, fieldName, argDef.Name(), argDefType), []ast.Node{fieldAST}, ) - context.ReportError(err) } } } @@ -1512,12 +1484,12 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if directiveAST.Name != nil { directiveName = directiveAST.Name.Value } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Directive "@%v" argument "%v" of type `+ `"%v" is required but not provided.`, directiveName, argDef.Name(), argDefType), []ast.Node{directiveAST}, ) - context.ReportError(err) } } } @@ -1554,14 +1526,14 @@ func ScalarLeafsRule(context *ValidationContext) *ValidationRuleInstance { if ttype != nil { if IsLeafType(ttype) { if node.SelectionSet != nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Field "%v" of type "%v" must not have a sub selection.`, nodeName, ttype), []ast.Node{node.SelectionSet}, ) } } else if node.SelectionSet == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Field "%v" of type "%v" must have a sub selection.`, nodeName, ttype), []ast.Node{node}, @@ -1611,7 +1583,7 @@ func UniqueArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance argName = node.Name.Value } if nameAST, ok := knownArgNames[argName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can be only one argument named "%v".`, argName), []ast.Node{nameAST, node.Name}, @@ -1648,7 +1620,7 @@ func UniqueFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragmentName = node.Name.Value } if nameAST, ok := knownFragmentNames[fragmentName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can only be one fragment named "%v".`, fragmentName), []ast.Node{nameAST, node.Name}, @@ -1700,7 +1672,7 @@ func UniqueInputFieldNamesRule(context *ValidationContext) *ValidationRuleInstan fieldName = node.Name.Value } if knownNameAST, ok := knownNames[fieldName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can be only one input field named "%v".`, fieldName), []ast.Node{knownNameAST, node.Name}, @@ -1739,7 +1711,7 @@ func UniqueOperationNamesRule(context *ValidationContext) *ValidationRuleInstanc operationName = node.Name.Value } if nameAST, ok := knownOperationNames[operationName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can only be one operation named "%v".`, operationName), []ast.Node{nameAST, node.Name}, @@ -1779,7 +1751,7 @@ func VariablesAreInputTypesRule(context *ValidationContext) *ValidationRuleInsta if node.Variable != nil && node.Variable.Name != nil { variableName = node.Variable.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" cannot be non-input type "%v".`, variableName, printer.Print(node.Type)), @@ -1837,14 +1809,45 @@ func varTypeAllowedForType(varType Type, expectedType Type) bool { func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleInstance { varDefMap := map[string]*ast.VariableDefinition{} - visitedFragmentNames := map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { varDefMap = map[string]*ast.VariableDefinition{} - visitedFragmentNames = map[string]bool{} + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok { + + usages := context.RecursiveVariableUsages(operation) + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + var varType Type + varDef, ok := varDefMap[varName] + if ok { + var err error + varType, err = typeFromAST(*context.Schema(), varDef.Type) + if err != nil { + varType = nil + } + } + if varType != nil && + usage.Type != nil && + !varTypeAllowedForType(effectiveType(varType, varDef), usage.Type) { + reportError( + context, + fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ + `expecting type "%v".`, varName, varType, usage.Type), + []ast.Node{usage.Node}, + ) + } + } + + } return visitor.ActionNoChange, nil }, }, @@ -1855,47 +1858,8 @@ func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleI if varDefAST.Variable != nil && varDefAST.Variable.Name != nil { defName = varDefAST.Variable.Name.Value } - varDefMap[defName] = varDefAST - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - // Only visit fragments of a particular name once per operation - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok { - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisited, _ := visitedFragmentNames[spreadName]; hasVisited { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } - return visitor.ActionNoChange, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variableAST, ok := p.Node.(*ast.Variable); ok && variableAST != nil { - varName := "" - if variableAST.Name != nil { - varName = variableAST.Name.Value - } - varDef, _ := varDefMap[varName] - var varType Type - if varDef != nil { - varType, _ = typeFromAST(*context.Schema(), varDef.Type) - } - inputType := context.InputType() - if varType != nil && inputType != nil && !varTypeAllowedForType(effectiveType(varType, varDef), inputType) { - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ - `expecting type "%v".`, varName, varType, inputType), - []ast.Node{variableAST}, - ) + if defName != "" { + varDefMap[defName] = varDefAST } } return visitor.ActionNoChange, nil diff --git a/rules_no_undefined_variables_test.go b/rules_no_undefined_variables_test.go index 64449842..0b253715 100644 --- a/rules_no_undefined_variables_test.go +++ b/rules_no_undefined_variables_test.go @@ -108,7 +108,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefined(t *testing.T) { field(a: $a, b: $b, c: $c, d: $d) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$d" is not defined.`, 3, 39), + testutil.RuleError(`Variable "$d" is not defined by operation "Foo".`, 3, 39, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testing.T) { @@ -117,7 +117,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testi field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 3, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) { @@ -126,8 +126,8 @@ func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) field(a: $a, b: $b, c: $c) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), - testutil.RuleError(`Variable "$c" is not defined.`, 3, 32), + testutil.RuleError(`Variable "$a" is not defined by operation "Foo".`, 3, 18, 2, 7), + testutil.RuleError(`Variable "$c" is not defined by operation "Foo".`, 3, 32, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuery(t *testing.T) { @@ -139,7 +139,7 @@ func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuer field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 6, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 6, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByOperation(t *testing.T) { diff --git a/type_info.go b/type_info.go index a26825e3..3c6dd2e2 100644 --- a/type_info.go +++ b/type_info.go @@ -173,12 +173,18 @@ func (ti *TypeInfo) Leave(node ast.Node) { switch kind { case kinds.SelectionSet: // pop ti.parentTypeStack - _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + if len(ti.parentTypeStack) > 0 { + _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + } case kinds.Field: // pop ti.fieldDefStack - _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + if len(ti.fieldDefStack) > 0 { + _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + } // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.Directive: ti.directive = nil case kinds.OperationDefinition: @@ -187,19 +193,27 @@ func (ti *TypeInfo) Leave(node ast.Node) { fallthrough case kinds.FragmentDefinition: // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.VariableDefinition: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.Argument: ti.argument = nil // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.ListValue: fallthrough case kinds.ObjectField: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } } } diff --git a/validator.go b/validator.go index 4bb0790f..b3714f09 100644 --- a/validator.go +++ b/validator.go @@ -34,6 +34,22 @@ func ValidateDocument(schema *Schema, astDoc *ast.Document, rules []ValidationRu } func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) []gqlerrors.FormattedError { + + typeInfo := NewTypeInfo(schema) + context := NewValidationContext(schema, astDoc, typeInfo) + visitors := []*visitor.VisitorOptions{} + + for _, rule := range rules { + instance := rule(context) + visitors = append(visitors, instance.VisitorOpts) + } + + // Visit the whole document with each instance of all provided rules. + visitor.Visit(astDoc, visitor.VisitWithTypeInfo(typeInfo, visitor.VisitInParallel(visitors)), nil) + return context.Errors() +} + +func visitUsingRulesOld(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) []gqlerrors.FormattedError { typeInfo := NewTypeInfo(schema) context := NewValidationContext(schema, astDoc, typeInfo) @@ -61,7 +77,7 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul // Get the visitor function from the validation instance, and if it // exists, call it with the visitor arguments. - enterFn := visitor.GetVisitFn(instance.VisitorOpts, false, kind) + enterFn := visitor.GetVisitFn(instance.VisitorOpts, kind, false) if enterFn != nil { action, result = enterFn(p) } @@ -102,7 +118,7 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul // Get the visitor function from the validation instance, and if it // exists, call it with the visitor arguments. - leaveFn := visitor.GetVisitFn(instance.VisitorOpts, true, kind) + leaveFn := visitor.GetVisitFn(instance.VisitorOpts, kind, true) if leaveFn != nil { action, result = leaveFn(p) } @@ -126,19 +142,42 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul return context.Errors() } +type HasSelectionSet interface { + GetKind() string + GetLoc() *ast.Location + GetSelectionSet() *ast.SelectionSet +} + +var _ HasSelectionSet = (*ast.OperationDefinition)(nil) +var _ HasSelectionSet = (*ast.FragmentDefinition)(nil) + +type VariableUsage struct { + Node *ast.Variable + Type Input +} + type ValidationContext struct { - schema *Schema - astDoc *ast.Document - typeInfo *TypeInfo - fragments map[string]*ast.FragmentDefinition - errors []gqlerrors.FormattedError + schema *Schema + astDoc *ast.Document + typeInfo *TypeInfo + errors []gqlerrors.FormattedError + fragments map[string]*ast.FragmentDefinition + variableUsages map[HasSelectionSet][]*VariableUsage + recursiveVariableUsages map[*ast.OperationDefinition][]*VariableUsage + recursivelyReferencedFragments map[*ast.OperationDefinition][]*ast.FragmentDefinition + fragmentSpreads map[HasSelectionSet][]*ast.FragmentSpread } func NewValidationContext(schema *Schema, astDoc *ast.Document, typeInfo *TypeInfo) *ValidationContext { return &ValidationContext{ - schema: schema, - astDoc: astDoc, - typeInfo: typeInfo, + schema: schema, + astDoc: astDoc, + typeInfo: typeInfo, + fragments: map[string]*ast.FragmentDefinition{}, + variableUsages: map[HasSelectionSet][]*VariableUsage{}, + recursiveVariableUsages: map[*ast.OperationDefinition][]*VariableUsage{}, + recursivelyReferencedFragments: map[*ast.OperationDefinition][]*ast.FragmentDefinition{}, + fragmentSpreads: map[HasSelectionSet][]*ast.FragmentSpread{}, } } @@ -177,7 +216,126 @@ func (ctx *ValidationContext) Fragment(name string) *ast.FragmentDefinition { f, _ := ctx.fragments[name] return f } +func (ctx *ValidationContext) FragmentSpreads(node HasSelectionSet) []*ast.FragmentSpread { + if spreads, ok := ctx.fragmentSpreads[node]; ok && spreads != nil { + return spreads + } + + spreads := []*ast.FragmentSpread{} + setsToVisit := []*ast.SelectionSet{node.GetSelectionSet()} + + for { + if len(setsToVisit) == 0 { + break + } + var set *ast.SelectionSet + // pop + set, setsToVisit = setsToVisit[len(setsToVisit)-1], setsToVisit[:len(setsToVisit)-1] + if set.Selections != nil { + for _, selection := range set.Selections { + switch selection := selection.(type) { + case *ast.FragmentSpread: + spreads = append(spreads, selection) + case *ast.Field: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + case *ast.InlineFragment: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + } + } + } + ctx.fragmentSpreads[node] = spreads + } + return spreads +} +func (ctx *ValidationContext) RecursivelyReferencedFragments(operation *ast.OperationDefinition) []*ast.FragmentDefinition { + if fragments, ok := ctx.recursivelyReferencedFragments[operation]; ok && fragments != nil { + return fragments + } + + fragments := []*ast.FragmentDefinition{} + collectedNames := map[string]bool{} + nodesToVisit := []HasSelectionSet{operation} + + for { + if len(nodesToVisit) == 0 { + break + } + + var node HasSelectionSet + + node, nodesToVisit = nodesToVisit[len(nodesToVisit)-1], nodesToVisit[:len(nodesToVisit)-1] + spreads := ctx.FragmentSpreads(node) + for _, spread := range spreads { + fragName := "" + if spread.Name != nil { + fragName = spread.Name.Value + } + if res, ok := collectedNames[fragName]; !ok || !res { + collectedNames[fragName] = true + fragment := ctx.Fragment(fragName) + if fragment != nil { + fragments = append(fragments, fragment) + nodesToVisit = append(nodesToVisit, fragment) + } + } + + } + } + + ctx.recursivelyReferencedFragments[operation] = fragments + return fragments +} +func (ctx *ValidationContext) VariableUsages(node HasSelectionSet) []*VariableUsage { + if usages, ok := ctx.variableUsages[node]; ok && usages != nil { + return usages + } + usages := []*VariableUsage{} + typeInfo := NewTypeInfo(ctx.schema) + + visitor.Visit(node, visitor.VisitWithTypeInfo(typeInfo, &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.Variable: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.Variable); ok && node != nil { + usages = append(usages, &VariableUsage{ + Node: node, + Type: typeInfo.InputType(), + }) + } + return visitor.ActionNoChange, nil + }, + }, + }, + }), nil) + + ctx.variableUsages[node] = usages + return usages +} +func (ctx *ValidationContext) RecursiveVariableUsages(operation *ast.OperationDefinition) []*VariableUsage { + if usages, ok := ctx.recursiveVariableUsages[operation]; ok && usages != nil { + return usages + } + usages := ctx.VariableUsages(operation) + + fragments := ctx.RecursivelyReferencedFragments(operation) + for _, fragment := range fragments { + fragmentUsages := ctx.VariableUsages(fragment) + usages = append(usages, fragmentUsages...) + } + + ctx.recursiveVariableUsages[operation] = usages + return usages +} func (ctx *ValidationContext) Type() Output { return ctx.typeInfo.Type() }