diff --git a/language/ast/selections.go b/language/ast/selections.go index 1b7e60d2..dd36cf26 100644 --- a/language/ast/selections.go +++ b/language/ast/selections.go @@ -46,6 +46,10 @@ func (f *Field) GetLoc() *Location { return f.Loc } +func (f *Field) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // FragmentSpread implements Node, Selection type FragmentSpread struct { Kind string @@ -74,6 +78,10 @@ func (fs *FragmentSpread) GetLoc() *Location { return fs.Loc } +func (fs *FragmentSpread) GetSelectionSet() *SelectionSet { + return nil +} + // InlineFragment implements Node, Selection type InlineFragment struct { Kind string @@ -104,6 +112,10 @@ func (f *InlineFragment) GetLoc() *Location { return f.Loc } +func (f *InlineFragment) GetSelectionSet() *SelectionSet { + return f.SelectionSet +} + // SelectionSet implements Node type SelectionSet struct { Kind string diff --git a/language/type_info/type_info.go b/language/type_info/type_info.go new file mode 100644 index 00000000..02b7b04f --- /dev/null +++ b/language/type_info/type_info.go @@ -0,0 +1,14 @@ +package type_info + +import ( + "github.com/graphql-go/graphql/language/ast" +) + +/** + * TypeInfoI defines the interface for TypeInfo + * Implementation + */ +type TypeInfoI interface { + Enter(node ast.Node) + Leave(node ast.Node) +} diff --git a/language/visitor/visitor.go b/language/visitor/visitor.go index 83edbd9b..3188efec 100644 --- a/language/visitor/visitor.go +++ b/language/visitor/visitor.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/type_info" "reflect" ) @@ -380,7 +381,7 @@ Loop: kind = node.GetKind() } - visitFn := GetVisitFn(visitorOpts, isLeaving, kind) + visitFn := GetVisitFn(visitorOpts, kind, isLeaving) if visitFn != nil { p := VisitFuncParams{ Node: nodeIn, @@ -709,7 +710,144 @@ func isNilNode(node interface{}) bool { return val.Interface() == nil } -func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitFunc { +/** + * Creates a new visitor instance which delegates to many visitors to run in + * parallel. Each visitor will be visited for each node before moving on. + * + * Visitors must not directly modify the AST nodes and only returning false to + * skip sub-branches is supported. + */ +func VisitInParallel(visitorOptsSlice []*VisitorOptions) *VisitorOptions { + skipping := map[int]interface{}{} + + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + if _, ok := skipping[i]; !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, false) + if fn != nil { + action, _ := fn(p) + if action == ActionSkip { + skipping[i] = node + } + } + } + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + for i, visitorOpts := range visitorOptsSlice { + if _, ok := skipping[i]; !ok { + switch node := p.Node.(type) { + case ast.Node: + kind := node.GetKind() + fn := GetVisitFn(visitorOpts, kind, true) + if fn != nil { + fn(p) + } + } + } else { + delete(skipping, i) + } + } + return ActionNoChange, nil + }, + } +} + +/** + * Creates a new visitor instance which maintains a provided TypeInfo instance + * along with visiting visitor. + * + * Visitors must not directly modify the AST nodes and only returning false to + * skip sub-branches is supported. + */ +func VisitWithTypeInfo(typeInfo type_info.TypeInfoI, visitorOpts *VisitorOptions) *VisitorOptions { + return &VisitorOptions{ + Enter: func(p VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(ast.Node); ok { + typeInfo.Enter(node) + fn := GetVisitFn(visitorOpts, node.GetKind(), false) + if fn != nil { + action, _ := fn(p) + if action == ActionSkip { + typeInfo.Leave(node) + return ActionSkip, nil + } + } + } + return ActionNoChange, nil + }, + Leave: func(p VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(ast.Node); ok { + fn := GetVisitFn(visitorOpts, node.GetKind(), true) + if fn != nil { + fn(p) + } + typeInfo.Leave(node) + } + return ActionNoChange, nil + }, + } +} + +/** + * Given a visitor instance, if it is leaving or not, and a node kind, return + * the function the visitor runtime should call. + */ +func GetVisitFn(visitorOpts *VisitorOptions, kind string, isLeaving bool) VisitFunc { + if visitorOpts == nil { + return nil + } + kindVisitor, ok := visitorOpts.KindFuncMap[kind] + if ok { + if !isLeaving && kindVisitor.Kind != nil { + // { Kind() {} } + return kindVisitor.Kind + } + if isLeaving { + // { Kind: { leave() {} } } + return kindVisitor.Leave + } else { + // { Kind: { enter() {} } } + return kindVisitor.Enter + } + } else { + + if isLeaving { + // { enter() {} } + specificVisitor := visitorOpts.Leave + if specificVisitor != nil { + return specificVisitor + } + if specificKindVisitor, ok := visitorOpts.LeaveKindMap[kind]; ok { + // { leave: { Kind() {} } } + return specificKindVisitor + } + + } else { + // { leave() {} } + specificVisitor := visitorOpts.Enter + if specificVisitor != nil { + return specificVisitor + } + if specificKindVisitor, ok := visitorOpts.EnterKindMap[kind]; ok { + // { enter: { Kind() {} } } + return specificKindVisitor + } + } + } + + return nil +} + +///// DELETE //// + +func GetVisitFnOld(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitFunc { if visitorOpts == nil { return nil } @@ -753,3 +891,38 @@ func GetVisitFn(visitorOpts *VisitorOptions, isLeaving bool, kind string) VisitF return nil } + +/* + + +export function getVisitFn(visitor, isLeaving, kind) { + var kindVisitor = visitor[kind]; + if (kindVisitor) { + if (!isLeaving && typeof kindVisitor === 'function') { + // { Kind() {} } + return kindVisitor; + } + var kindSpecificVisitor = isLeaving ? kindVisitor.leave : kindVisitor.enter; + if (typeof kindSpecificVisitor === 'function') { + // { Kind: { enter() {}, leave() {} } } + return kindSpecificVisitor; + } + return; + } + var specificVisitor = isLeaving ? visitor.leave : visitor.enter; + if (specificVisitor) { + if (typeof specificVisitor === 'function') { + // { enter() {}, leave() {} } + return specificVisitor; + } + var specificKindVisitor = specificVisitor[kind]; + if (typeof specificKindVisitor === 'function') { + // { enter: { Kind() {} }, leave: { Kind() {} } } + return specificKindVisitor; + } + } +} + + + +*/ diff --git a/rules.go b/rules.go index c1fd9a8f..d46520a8 100644 --- a/rules.go +++ b/rules.go @@ -57,7 +57,7 @@ func newValidationError(message string, nodes []ast.Node) *gqlerrors.Error { ) } -func reportErrorAndReturn(context *ValidationContext, message string, nodes []ast.Node) (string, interface{}) { +func reportError(context *ValidationContext, message string, nodes []ast.Node) (string, interface{}) { context.ReportError(newValidationError(message, nodes)) return visitor.ActionNoChange, nil } @@ -84,7 +84,7 @@ func ArgumentsOfCorrectTypeRule(context *ValidationContext) *ValidationRuleInsta if argAST.Name != nil { argNameValue = argAST.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Argument "%v" expected type "%v" but got: %v.`, argNameValue, argDef.Type, printer.Print(value)), @@ -125,7 +125,7 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI ttype := context.InputType() if ttype, ok := ttype.(*NonNull); ok && defaultValue != nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" of type "%v" is required and will not use the default value. Perhaps you meant to use type "%v".`, name, ttype, ttype.OfType), @@ -133,7 +133,7 @@ func DefaultValuesOfCorrectTypeRule(context *ValidationContext) *ValidationRuleI ) } if ttype != nil && defaultValue != nil && !isValidLiteralValue(ttype, defaultValue) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" of type "%v" has invalid default value: %v.`, name, ttype, printer.Print(defaultValue)), @@ -175,7 +175,7 @@ func FieldsOnCorrectTypeRule(context *ValidationContext) *ValidationRuleInstance if node.Name != nil { nodeName = node.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Cannot query field "%v" on "%v".`, nodeName, ttype.Name()), @@ -210,7 +210,7 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn if node, ok := p.Node.(*ast.InlineFragment); ok { ttype := context.Type() if ttype != nil && !IsCompositeType(ttype) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment cannot condition on non composite type "%v".`, ttype), []ast.Node{node.TypeCondition}, @@ -229,7 +229,7 @@ func FragmentsOnCompositeTypesRule(context *ValidationContext) *ValidationRuleIn if node.Name != nil { nodeName = node.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment "%v" cannot condition on non composite type "%v".`, nodeName, printer.Print(node.TypeCondition)), []ast.Node{node.TypeCondition}, @@ -289,7 +289,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance if parentType != nil { parentTypeName = parentType.Name() } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown argument "%v" on field "%v" of type "%v".`, nodeName, fieldDef.Name, parentTypeName), []ast.Node{node}, @@ -311,7 +311,7 @@ func KnownArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance } } if directiveArgDef == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown argument "%v" on directive "@%v".`, nodeName, directive.Name), []ast.Node{node}, @@ -357,7 +357,7 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } } if directiveDef == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown directive "%v".`, nodeName), []ast.Node{node}, @@ -373,14 +373,14 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { } if appliedTo.GetKind() == kinds.OperationDefinition && directiveDef.OnOperation == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "operation"), []ast.Node{node}, ) } if appliedTo.GetKind() == kinds.Field && directiveDef.OnField == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "field"), []ast.Node{node}, @@ -389,7 +389,7 @@ func KnownDirectivesRule(context *ValidationContext) *ValidationRuleInstance { if (appliedTo.GetKind() == kinds.FragmentSpread || appliedTo.GetKind() == kinds.InlineFragment || appliedTo.GetKind() == kinds.FragmentDefinition) && directiveDef.OnFragment == false { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Directive "%v" may not be used on "%v".`, nodeName, "fragment"), []ast.Node{node}, @@ -430,7 +430,7 @@ func KnownFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragment := context.Fragment(fragmentName) if fragment == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown fragment "%v".`, fragmentName), []ast.Node{node.Name}, @@ -467,7 +467,7 @@ func KnownTypeNamesRule(context *ValidationContext) *ValidationRuleInstance { } ttype := context.Schema().Type(typeNameValue) if ttype == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Unknown type "%v".`, typeNameValue), []ast.Node{node}, @@ -512,7 +512,7 @@ func LoneAnonymousOperationRule(context *ValidationContext) *ValidationRuleInsta Kind: func(p visitor.VisitFuncParams) (string, interface{}) { if node, ok := p.Node.(*ast.OperationDefinition); ok { if node.Name == nil && operationCount > 1 { - return reportErrorAndReturn( + return reportError( context, `This anonymous operation must be the only defined operation.`, []ast.Node{node}, @@ -613,11 +613,11 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { if len(spreadNames) > 0 { via = " via " + strings.Join(spreadNames, ", ") } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Cannot spread fragment "%v" within itself%v.`, initialName, via), cyclePath, ) - context.ReportError(err) continue } spreadPathHasCurrentNode := false @@ -654,77 +654,64 @@ func NoFragmentCyclesRule(context *ValidationContext) *ValidationRuleInstance { * and via fragment spreads, are defined by that operation. */ func NoUndefinedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var operation *ast.OperationDefinition - var visitedFragmentNames = map[string]bool{} - var definedVariableNames = map[string]bool{} + var variableNameDefined = map[string]bool{} + visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.OperationDefinition); ok && node != nil { - operation = node - visitedFragmentNames = map[string]bool{} - definedVariableNames = map[string]bool{} - } - return visitor.ActionNoChange, nil - }, - }, - kinds.VariableDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { - variableName := "" - if node.Variable != nil && node.Variable.Name != nil { - variableName = node.Variable.Name.Value - } - definedVariableNames[variableName] = true - } + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + variableNameDefined = map[string]bool{} return visitor.ActionNoChange, nil }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - variableName := "" - if variable.Name != nil { - variableName = variable.Name.Value - } - if val, _ := definedVariableNames[variableName]; !val { - withinFragment := false - for _, node := range p.Ancestors { - if node.GetKind() == kinds.FragmentDefinition { - withinFragment = true - break - } + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + if usage == nil { + continue } - if withinFragment == true && operation != nil && operation.Name != nil { - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, variableName, operation.Name.Value), - []ast.Node{variable, operation}, - ) + if usage.Node == nil { + continue + } + varName := "" + if usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + opName := "" + if operation.Name != nil { + opName = operation.Name.Value + } + if res, ok := variableNameDefined[varName]; !ok || !res { + if opName != "" { + reportError( + context, + fmt.Sprintf(`Variable "$%v" is not defined by operation "%v".`, varName, opName), + []ast.Node{usage.Node, operation}, + ) + } else { + + reportError( + context, + fmt.Sprintf(`Variable "$%v" is not defined.`, varName), + []ast.Node{usage.Node, operation}, + ) + } } - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" is not defined.`, variableName), - []ast.Node{variable}, - ) } } return visitor.ActionNoChange, nil }, }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if node, ok := p.Node.(*ast.FragmentSpread); ok && node != nil { - // Only visit fragments of a particular name once per operation - fragmentName := "" - if node.Name != nil { - fragmentName = node.Name.Value - } - if val, ok := visitedFragmentNames[fragmentName]; ok && val == true { - return visitor.ActionSkip, nil + if node, ok := p.Node.(*ast.VariableDefinition); ok && node != nil { + variableName := "" + if node.Variable != nil && node.Variable.Name != nil { + variableName = node.Variable.Name.Value } - visitedFragmentNames[fragmentName] = true + // definedVariableNames[variableName] = true + variableNameDefined[variableName] = true } return visitor.ActionNoChange, nil }, @@ -817,11 +804,11 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { isFragNameUsed, ok := fragmentNameUsed[defName] if !ok || isFragNameUsed != true { - err := newValidationError( + reportError( + context, fmt.Sprintf(`Fragment "%v" is never used.`, defName), []ast.Node{def}, ) - context.ReportError(err) } } return visitor.ActionNoChange, nil @@ -843,33 +830,45 @@ func NoUnusedFragmentsRule(context *ValidationContext) *ValidationRuleInstance { */ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { - var visitedFragmentNames = map[string]bool{} var variableDefs = []*ast.VariableDefinition{} - var variableNameUsed = map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ Enter: func(p visitor.VisitFuncParams) (string, interface{}) { - visitedFragmentNames = map[string]bool{} variableDefs = []*ast.VariableDefinition{} - variableNameUsed = map[string]bool{} return visitor.ActionNoChange, nil }, Leave: func(p visitor.VisitFuncParams) (string, interface{}) { - for _, def := range variableDefs { - variableName := "" - if def.Variable != nil && def.Variable.Name != nil { - variableName = def.Variable.Name.Value + if operation, ok := p.Node.(*ast.OperationDefinition); ok && operation != nil { + variableNameUsed := map[string]bool{} + usages := context.RecursiveVariableUsages(operation) + + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + if varName != "" { + variableNameUsed[varName] = true + } } - if isVariableNameUsed, _ := variableNameUsed[variableName]; isVariableNameUsed != true { - err := newValidationError( - fmt.Sprintf(`Variable "$%v" is never used.`, variableName), - []ast.Node{def}, - ) - context.ReportError(err) + for _, variableDef := range variableDefs { + variableName := "" + if variableDef != nil && variableDef.Variable != nil && variableDef.Variable.Name != nil { + variableName = variableDef.Variable.Name.Value + } + if res, ok := variableNameUsed[variableName]; !ok || !res { + reportError( + context, + fmt.Sprintf(`Variable "$%v" is never used.`, variableName), + []ast.Node{variableDef}, + ) + } } + } + return visitor.ActionNoChange, nil }, }, @@ -878,33 +877,6 @@ func NoUnusedVariablesRule(context *ValidationContext) *ValidationRuleInstance { if def, ok := p.Node.(*ast.VariableDefinition); ok && def != nil { variableDefs = append(variableDefs, def) } - // Do not visit deeper, or else the defined variable name will be visited. - return visitor.ActionSkip, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variable, ok := p.Node.(*ast.Variable); ok && variable != nil { - if variable.Name != nil { - variableNameUsed[variable.Name.Value] = true - } - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok && spreadAST != nil { - // Only visit fragments of a particular name once per operation - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisitedFragmentNames, _ := visitedFragmentNames[spreadName]; hasVisitedFragmentNames == true { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } return visitor.ActionNoChange, nil }, }, @@ -1301,7 +1273,8 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul for _, c := range conflicts { responseName := c.Reason.Name reason := c.Reason - err := newValidationError( + reportError( + context, fmt.Sprintf( `Fields "%v" conflict because %v.`, responseName, @@ -1309,7 +1282,6 @@ func OverlappingFieldsCanBeMergedRule(context *ValidationContext) *ValidationRul ), c.Fields, ) - context.ReportError(err) } return visitor.ActionNoChange, nil } @@ -1394,7 +1366,7 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, parentType, fragType), @@ -1415,7 +1387,7 @@ func PossibleFragmentSpreadsRule(context *ValidationContext) *ValidationRuleInst fragType := getFragmentType(context, fragName) parentType, _ := context.ParentType().(Type) if fragType != nil && parentType != nil && !doTypesOverlap(fragType, parentType) { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Fragment "%v" cannot be spread here as objects of `+ `type "%v" can never be of type "%v".`, fragName, parentType, fragType), @@ -1471,12 +1443,12 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if fieldAST.Name != nil { fieldName = fieldAST.Name.Value } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Field "%v" argument "%v" of type "%v" `+ `is required but not provided.`, fieldName, argDef.Name(), argDefType), []ast.Node{fieldAST}, ) - context.ReportError(err) } } } @@ -1512,12 +1484,12 @@ func ProvidedNonNullArgumentsRule(context *ValidationContext) *ValidationRuleIns if directiveAST.Name != nil { directiveName = directiveAST.Name.Value } - err := newValidationError( + reportError( + context, fmt.Sprintf(`Directive "@%v" argument "%v" of type `+ `"%v" is required but not provided.`, directiveName, argDef.Name(), argDefType), []ast.Node{directiveAST}, ) - context.ReportError(err) } } } @@ -1554,14 +1526,14 @@ func ScalarLeafsRule(context *ValidationContext) *ValidationRuleInstance { if ttype != nil { if IsLeafType(ttype) { if node.SelectionSet != nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Field "%v" of type "%v" must not have a sub selection.`, nodeName, ttype), []ast.Node{node.SelectionSet}, ) } } else if node.SelectionSet == nil { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Field "%v" of type "%v" must have a sub selection.`, nodeName, ttype), []ast.Node{node}, @@ -1611,7 +1583,7 @@ func UniqueArgumentNamesRule(context *ValidationContext) *ValidationRuleInstance argName = node.Name.Value } if nameAST, ok := knownArgNames[argName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can be only one argument named "%v".`, argName), []ast.Node{nameAST, node.Name}, @@ -1648,7 +1620,7 @@ func UniqueFragmentNamesRule(context *ValidationContext) *ValidationRuleInstance fragmentName = node.Name.Value } if nameAST, ok := knownFragmentNames[fragmentName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can only be one fragment named "%v".`, fragmentName), []ast.Node{nameAST, node.Name}, @@ -1700,7 +1672,7 @@ func UniqueInputFieldNamesRule(context *ValidationContext) *ValidationRuleInstan fieldName = node.Name.Value } if knownNameAST, ok := knownNames[fieldName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can be only one input field named "%v".`, fieldName), []ast.Node{knownNameAST, node.Name}, @@ -1739,7 +1711,7 @@ func UniqueOperationNamesRule(context *ValidationContext) *ValidationRuleInstanc operationName = node.Name.Value } if nameAST, ok := knownOperationNames[operationName]; ok { - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`There can only be one operation named "%v".`, operationName), []ast.Node{nameAST, node.Name}, @@ -1779,7 +1751,7 @@ func VariablesAreInputTypesRule(context *ValidationContext) *ValidationRuleInsta if node.Variable != nil && node.Variable.Name != nil { variableName = node.Variable.Name.Value } - return reportErrorAndReturn( + return reportError( context, fmt.Sprintf(`Variable "$%v" cannot be non-input type "%v".`, variableName, printer.Print(node.Type)), @@ -1837,14 +1809,45 @@ func varTypeAllowedForType(varType Type, expectedType Type) bool { func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleInstance { varDefMap := map[string]*ast.VariableDefinition{} - visitedFragmentNames := map[string]bool{} visitorOpts := &visitor.VisitorOptions{ KindFuncMap: map[string]visitor.NamedVisitFuncs{ kinds.OperationDefinition: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { varDefMap = map[string]*ast.VariableDefinition{} - visitedFragmentNames = map[string]bool{} + return visitor.ActionNoChange, nil + }, + Leave: func(p visitor.VisitFuncParams) (string, interface{}) { + if operation, ok := p.Node.(*ast.OperationDefinition); ok { + + usages := context.RecursiveVariableUsages(operation) + for _, usage := range usages { + varName := "" + if usage != nil && usage.Node != nil && usage.Node.Name != nil { + varName = usage.Node.Name.Value + } + var varType Type + varDef, ok := varDefMap[varName] + if ok { + var err error + varType, err = typeFromAST(*context.Schema(), varDef.Type) + if err != nil { + varType = nil + } + } + if varType != nil && + usage.Type != nil && + !varTypeAllowedForType(effectiveType(varType, varDef), usage.Type) { + reportError( + context, + fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ + `expecting type "%v".`, varName, varType, usage.Type), + []ast.Node{usage.Node}, + ) + } + } + + } return visitor.ActionNoChange, nil }, }, @@ -1855,47 +1858,8 @@ func VariablesInAllowedPositionRule(context *ValidationContext) *ValidationRuleI if varDefAST.Variable != nil && varDefAST.Variable.Name != nil { defName = varDefAST.Variable.Name.Value } - varDefMap[defName] = varDefAST - } - return visitor.ActionNoChange, nil - }, - }, - kinds.FragmentSpread: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - // Only visit fragments of a particular name once per operation - if spreadAST, ok := p.Node.(*ast.FragmentSpread); ok { - spreadName := "" - if spreadAST.Name != nil { - spreadName = spreadAST.Name.Value - } - if hasVisited, _ := visitedFragmentNames[spreadName]; hasVisited { - return visitor.ActionSkip, nil - } - visitedFragmentNames[spreadName] = true - } - return visitor.ActionNoChange, nil - }, - }, - kinds.Variable: visitor.NamedVisitFuncs{ - Kind: func(p visitor.VisitFuncParams) (string, interface{}) { - if variableAST, ok := p.Node.(*ast.Variable); ok && variableAST != nil { - varName := "" - if variableAST.Name != nil { - varName = variableAST.Name.Value - } - varDef, _ := varDefMap[varName] - var varType Type - if varDef != nil { - varType, _ = typeFromAST(*context.Schema(), varDef.Type) - } - inputType := context.InputType() - if varType != nil && inputType != nil && !varTypeAllowedForType(effectiveType(varType, varDef), inputType) { - return reportErrorAndReturn( - context, - fmt.Sprintf(`Variable "$%v" of type "%v" used in position `+ - `expecting type "%v".`, varName, varType, inputType), - []ast.Node{variableAST}, - ) + if defName != "" { + varDefMap[defName] = varDefAST } } return visitor.ActionNoChange, nil diff --git a/rules_no_undefined_variables_test.go b/rules_no_undefined_variables_test.go index 64449842..0b253715 100644 --- a/rules_no_undefined_variables_test.go +++ b/rules_no_undefined_variables_test.go @@ -108,7 +108,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefined(t *testing.T) { field(a: $a, b: $b, c: $c, d: $d) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$d" is not defined.`, 3, 39), + testutil.RuleError(`Variable "$d" is not defined by operation "Foo".`, 3, 39, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testing.T) { @@ -117,7 +117,7 @@ func TestValidate_NoUndefinedVariables_VariableNotDefinedByUnnamedQuery(t *testi field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 3, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) { @@ -126,8 +126,8 @@ func TestValidate_NoUndefinedVariables_MultipleVariablesNotDefined(t *testing.T) field(a: $a, b: $b, c: $c) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 3, 18), - testutil.RuleError(`Variable "$c" is not defined.`, 3, 32), + testutil.RuleError(`Variable "$a" is not defined by operation "Foo".`, 3, 18, 2, 7), + testutil.RuleError(`Variable "$c" is not defined by operation "Foo".`, 3, 32, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuery(t *testing.T) { @@ -139,7 +139,7 @@ func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByUnnamedQuer field(a: $a) } `, []gqlerrors.FormattedError{ - testutil.RuleError(`Variable "$a" is not defined.`, 6, 18), + testutil.RuleError(`Variable "$a" is not defined.`, 6, 18, 2, 7), }) } func TestValidate_NoUndefinedVariables_VariableInFragmentNotDefinedByOperation(t *testing.T) { diff --git a/type_info.go b/type_info.go index a26825e3..3c6dd2e2 100644 --- a/type_info.go +++ b/type_info.go @@ -173,12 +173,18 @@ func (ti *TypeInfo) Leave(node ast.Node) { switch kind { case kinds.SelectionSet: // pop ti.parentTypeStack - _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + if len(ti.parentTypeStack) > 0 { + _, ti.parentTypeStack = ti.parentTypeStack[len(ti.parentTypeStack)-1], ti.parentTypeStack[:len(ti.parentTypeStack)-1] + } case kinds.Field: // pop ti.fieldDefStack - _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + if len(ti.fieldDefStack) > 0 { + _, ti.fieldDefStack = ti.fieldDefStack[len(ti.fieldDefStack)-1], ti.fieldDefStack[:len(ti.fieldDefStack)-1] + } // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.Directive: ti.directive = nil case kinds.OperationDefinition: @@ -187,19 +193,27 @@ func (ti *TypeInfo) Leave(node ast.Node) { fallthrough case kinds.FragmentDefinition: // pop ti.typeStack - _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + if len(ti.typeStack) > 0 { + _, ti.typeStack = ti.typeStack[len(ti.typeStack)-1], ti.typeStack[:len(ti.typeStack)-1] + } case kinds.VariableDefinition: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.Argument: ti.argument = nil // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } case kinds.ListValue: fallthrough case kinds.ObjectField: // pop ti.inputTypeStack - _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + if len(ti.inputTypeStack) > 0 { + _, ti.inputTypeStack = ti.inputTypeStack[len(ti.inputTypeStack)-1], ti.inputTypeStack[:len(ti.inputTypeStack)-1] + } } } diff --git a/validator.go b/validator.go index 4bb0790f..b3714f09 100644 --- a/validator.go +++ b/validator.go @@ -34,6 +34,22 @@ func ValidateDocument(schema *Schema, astDoc *ast.Document, rules []ValidationRu } func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) []gqlerrors.FormattedError { + + typeInfo := NewTypeInfo(schema) + context := NewValidationContext(schema, astDoc, typeInfo) + visitors := []*visitor.VisitorOptions{} + + for _, rule := range rules { + instance := rule(context) + visitors = append(visitors, instance.VisitorOpts) + } + + // Visit the whole document with each instance of all provided rules. + visitor.Visit(astDoc, visitor.VisitWithTypeInfo(typeInfo, visitor.VisitInParallel(visitors)), nil) + return context.Errors() +} + +func visitUsingRulesOld(schema *Schema, astDoc *ast.Document, rules []ValidationRuleFn) []gqlerrors.FormattedError { typeInfo := NewTypeInfo(schema) context := NewValidationContext(schema, astDoc, typeInfo) @@ -61,7 +77,7 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul // Get the visitor function from the validation instance, and if it // exists, call it with the visitor arguments. - enterFn := visitor.GetVisitFn(instance.VisitorOpts, false, kind) + enterFn := visitor.GetVisitFn(instance.VisitorOpts, kind, false) if enterFn != nil { action, result = enterFn(p) } @@ -102,7 +118,7 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul // Get the visitor function from the validation instance, and if it // exists, call it with the visitor arguments. - leaveFn := visitor.GetVisitFn(instance.VisitorOpts, true, kind) + leaveFn := visitor.GetVisitFn(instance.VisitorOpts, kind, true) if leaveFn != nil { action, result = leaveFn(p) } @@ -126,19 +142,42 @@ func visitUsingRules(schema *Schema, astDoc *ast.Document, rules []ValidationRul return context.Errors() } +type HasSelectionSet interface { + GetKind() string + GetLoc() *ast.Location + GetSelectionSet() *ast.SelectionSet +} + +var _ HasSelectionSet = (*ast.OperationDefinition)(nil) +var _ HasSelectionSet = (*ast.FragmentDefinition)(nil) + +type VariableUsage struct { + Node *ast.Variable + Type Input +} + type ValidationContext struct { - schema *Schema - astDoc *ast.Document - typeInfo *TypeInfo - fragments map[string]*ast.FragmentDefinition - errors []gqlerrors.FormattedError + schema *Schema + astDoc *ast.Document + typeInfo *TypeInfo + errors []gqlerrors.FormattedError + fragments map[string]*ast.FragmentDefinition + variableUsages map[HasSelectionSet][]*VariableUsage + recursiveVariableUsages map[*ast.OperationDefinition][]*VariableUsage + recursivelyReferencedFragments map[*ast.OperationDefinition][]*ast.FragmentDefinition + fragmentSpreads map[HasSelectionSet][]*ast.FragmentSpread } func NewValidationContext(schema *Schema, astDoc *ast.Document, typeInfo *TypeInfo) *ValidationContext { return &ValidationContext{ - schema: schema, - astDoc: astDoc, - typeInfo: typeInfo, + schema: schema, + astDoc: astDoc, + typeInfo: typeInfo, + fragments: map[string]*ast.FragmentDefinition{}, + variableUsages: map[HasSelectionSet][]*VariableUsage{}, + recursiveVariableUsages: map[*ast.OperationDefinition][]*VariableUsage{}, + recursivelyReferencedFragments: map[*ast.OperationDefinition][]*ast.FragmentDefinition{}, + fragmentSpreads: map[HasSelectionSet][]*ast.FragmentSpread{}, } } @@ -177,7 +216,126 @@ func (ctx *ValidationContext) Fragment(name string) *ast.FragmentDefinition { f, _ := ctx.fragments[name] return f } +func (ctx *ValidationContext) FragmentSpreads(node HasSelectionSet) []*ast.FragmentSpread { + if spreads, ok := ctx.fragmentSpreads[node]; ok && spreads != nil { + return spreads + } + + spreads := []*ast.FragmentSpread{} + setsToVisit := []*ast.SelectionSet{node.GetSelectionSet()} + + for { + if len(setsToVisit) == 0 { + break + } + var set *ast.SelectionSet + // pop + set, setsToVisit = setsToVisit[len(setsToVisit)-1], setsToVisit[:len(setsToVisit)-1] + if set.Selections != nil { + for _, selection := range set.Selections { + switch selection := selection.(type) { + case *ast.FragmentSpread: + spreads = append(spreads, selection) + case *ast.Field: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + case *ast.InlineFragment: + if selection.SelectionSet != nil { + setsToVisit = append(setsToVisit, selection.SelectionSet) + } + } + } + } + ctx.fragmentSpreads[node] = spreads + } + return spreads +} +func (ctx *ValidationContext) RecursivelyReferencedFragments(operation *ast.OperationDefinition) []*ast.FragmentDefinition { + if fragments, ok := ctx.recursivelyReferencedFragments[operation]; ok && fragments != nil { + return fragments + } + + fragments := []*ast.FragmentDefinition{} + collectedNames := map[string]bool{} + nodesToVisit := []HasSelectionSet{operation} + + for { + if len(nodesToVisit) == 0 { + break + } + + var node HasSelectionSet + + node, nodesToVisit = nodesToVisit[len(nodesToVisit)-1], nodesToVisit[:len(nodesToVisit)-1] + spreads := ctx.FragmentSpreads(node) + for _, spread := range spreads { + fragName := "" + if spread.Name != nil { + fragName = spread.Name.Value + } + if res, ok := collectedNames[fragName]; !ok || !res { + collectedNames[fragName] = true + fragment := ctx.Fragment(fragName) + if fragment != nil { + fragments = append(fragments, fragment) + nodesToVisit = append(nodesToVisit, fragment) + } + } + + } + } + + ctx.recursivelyReferencedFragments[operation] = fragments + return fragments +} +func (ctx *ValidationContext) VariableUsages(node HasSelectionSet) []*VariableUsage { + if usages, ok := ctx.variableUsages[node]; ok && usages != nil { + return usages + } + usages := []*VariableUsage{} + typeInfo := NewTypeInfo(ctx.schema) + + visitor.Visit(node, visitor.VisitWithTypeInfo(typeInfo, &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.VariableDefinition: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + return visitor.ActionSkip, nil + }, + }, + kinds.Variable: visitor.NamedVisitFuncs{ + Kind: func(p visitor.VisitFuncParams) (string, interface{}) { + if node, ok := p.Node.(*ast.Variable); ok && node != nil { + usages = append(usages, &VariableUsage{ + Node: node, + Type: typeInfo.InputType(), + }) + } + return visitor.ActionNoChange, nil + }, + }, + }, + }), nil) + + ctx.variableUsages[node] = usages + return usages +} +func (ctx *ValidationContext) RecursiveVariableUsages(operation *ast.OperationDefinition) []*VariableUsage { + if usages, ok := ctx.recursiveVariableUsages[operation]; ok && usages != nil { + return usages + } + usages := ctx.VariableUsages(operation) + + fragments := ctx.RecursivelyReferencedFragments(operation) + for _, fragment := range fragments { + fragmentUsages := ctx.VariableUsages(fragment) + usages = append(usages, fragmentUsages...) + } + + ctx.recursiveVariableUsages[operation] = usages + return usages +} func (ctx *ValidationContext) Type() Output { return ctx.typeInfo.Type() }