diff --git a/graphql/e2e/auth/schema.graphql b/graphql/e2e/auth/schema.graphql index 4148b373968..ddfe04675ef 100644 --- a/graphql/e2e/auth/schema.graphql +++ b/graphql/e2e/auth/schema.graphql @@ -974,3 +974,43 @@ type Home @auth( favouriteMember: HomeMember } # union testing - end + +""" +This types are used to validate nested filting. +""" +type Nested_X @auth( + query: { rule: """ + query { + queryNested_X(filter: {y: {s: {eq: "y"}}}){ + __typename + } + } + """} +) { + b: Boolean @search + y: Nested_Y @hasInverse(field: x) @search +} + +type Nested_Y implements Nested_Z @auth( + query: { + or: [ + { rule: "{ $rbac: { eq: \"positive\" } }" } + { and: [ + { rule: "{ $rbac: { eq: \"uncertain\" } }" } + { rule: """ + query { + queryNested_Y(filter: {x: {b: true}}){ + __typename + } + } + """ + } + ] + }] + }) { + s: String @search(by: [hash]) +} + +interface Nested_Z { + x: Nested_X @search +} diff --git a/graphql/resolve/auth_delete_test.yaml b/graphql/resolve/auth_delete_test.yaml index 66d7b97f361..c28d0b6d3e9 100644 --- a/graphql/resolve/auth_delete_test.yaml +++ b/graphql/resolve/auth_delete_test.yaml @@ -880,3 +880,30 @@ } } +- name: "Delete selection by nested filter" + gqlquery: | + mutation { + deleteNested_X(filter: {y: {s: {eq: "-"}}}) { + numUids + } + } + variables: + jwtvar: + dgmutations: + - deletejson: | + [{ + "uid": "uid(x)" + },{ + "Nested_Z.x": {"uid": "uid(x)"}, + "uid": "uid(Nested_Y_3)" + }] + dgquery: |- + query { + x as deleteNested_X(func: type(Nested_X)) @filter((uid(deleteNested_X_y))) { + uid + Nested_Y_3 as Nested_X.y + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "-")) { + deleteNested_X_y as Nested_Z.x + } + } diff --git a/graphql/resolve/auth_query_test.yaml b/graphql/resolve/auth_query_test.yaml index a0083d44266..b71ca86f787 100644 --- a/graphql/resolve/auth_query_test.yaml +++ b/graphql/resolve/auth_query_test.yaml @@ -2148,3 +2148,112 @@ Person.id : uid } } + +- + name: "Query positive auth rules with nested filter" + gqlquery: | + query{ + queryNested_X(filter: {y: {s: {eq: "-"}}}) { + b + y { s } + } + } + jwtvar: + rbac: positive + dgquery: |- + query { + queryNested_X(func: uid(Nested_XRoot)) { + Nested_X.b : Nested_X.b + Nested_X.y : Nested_X.y @filter(uid(Nested_Y_1)) { + Nested_Y.s : Nested_Y.s + dgraph.uid : uid + } + dgraph.uid : uid + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "-")) { + queryNested_X_y as Nested_Z.x + } + Nested_XRoot as var(func: uid(Nested_X_3)) @filter(uid(Nested_X_Auth4)) + Nested_X_3 as var(func: type(Nested_X)) @filter((uid(queryNested_X_y))) + Nested_X_Auth4 as var(func: uid(Nested_X_3)) @filter((uid(Nested_X_Auth4_y))) @cascade + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "y")) { + Nested_X_Auth4_y as Nested_Z.x + } + var(func: uid(Nested_XRoot)) { + Nested_Y_2 as Nested_X.y + } + Nested_Y_1 as var(func: uid(Nested_Y_2)) + } + +- + name: "Query negative auth rules with nested filter" + gqlquery: | + query{ + queryNested_X(filter: {y: {s: {eq: "-"}}}) { + b + y { s } + } + } + jwtvar: + rbac: negative + dgquery: |- + query { + queryNested_X(func: uid(Nested_XRoot)) { + Nested_X.b : Nested_X.b + dgraph.uid : uid + } + queryNested_X_y as var() + Nested_XRoot as var(func: uid(Nested_X_3)) @filter(uid(Nested_X_Auth4)) + Nested_X_3 as var(func: type(Nested_X)) @filter((uid(queryNested_X_y))) + Nested_X_Auth4 as var(func: uid(Nested_X_3)) @filter((uid(Nested_X_Auth4_y))) @cascade + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "y")) { + Nested_X_Auth4_y as Nested_Z.x + } + } + +- + name: "Query uncertain auth rules with nested filter" + gqlquery: | + query{ + queryNested_X(filter: {y: {s: {eq: "-"}}}) { + b + y { s } + } + } + jwtvar: + rbac: uncertain + dgquery: |- + query { + queryNested_X(func: uid(Nested_XRoot)) { + Nested_X.b : Nested_X.b + Nested_X.y : Nested_X.y @filter(uid(Nested_Y_3)) { + Nested_Y.s : Nested_Y.s + dgraph.uid : uid + } + dgraph.uid : uid + } + var(func: uid(queryNested_X_yRoot)) { + queryNested_X_y as Nested_Z.x + } + queryNested_X_yRoot as var(func: uid(Nested_Y_1)) @filter(uid(Nested_Y_Auth2)) + Nested_Y_1 as var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "-")) + Nested_Y_Auth2 as var(func: uid(Nested_Y_1)) @filter((uid(Nested_Y_Auth2_x))) @cascade + var(func: type(Nested_X)) @filter(eq(Nested_X.b, true)) { + Nested_Y_Auth2_x as Nested_X.y + } + Nested_XRoot as var(func: uid(Nested_X_6)) @filter(uid(Nested_X_Auth7)) + Nested_X_6 as var(func: type(Nested_X)) @filter((uid(queryNested_X_y))) + Nested_X_Auth7 as var(func: uid(Nested_X_6)) @filter((uid(Nested_X_Auth7_y))) @cascade + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "y")) { + Nested_X_Auth7_y as Nested_Z.x + } + var(func: uid(Nested_XRoot)) { + Nested_Y_4 as Nested_X.y + } + Nested_Y_3 as var(func: uid(Nested_Y_4)) @filter(uid(Nested_Y_Auth5)) + Nested_Y_Auth5 as var(func: uid(Nested_Y_4)) @filter((uid(Nested_Y_Auth5_x))) @cascade + var(func: type(Nested_X)) @filter(eq(Nested_X.b, true)) { + Nested_Y_Auth5_x as Nested_X.y + } + } + \ No newline at end of file diff --git a/graphql/resolve/mutation_rewriter.go b/graphql/resolve/mutation_rewriter.go index d78ad8bf433..6c22cb28a85 100644 --- a/graphql/resolve/mutation_rewriter.go +++ b/graphql/resolve/mutation_rewriter.go @@ -809,7 +809,7 @@ func (arw *AddRewriter) FromMutationResult( return nil, errs } // No errors are thrown while rewriting queries by Ids. - return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw), nil + return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw, mutation.Alias()), nil } // FromMutationResult rewrites the query part of a GraphQL update mutation into a Dgraph query. @@ -845,7 +845,7 @@ func (urw *UpdateRewriter) FromMutationResult( parentVarName: mutation.MutatedType().Name() + "Root", } authRw.hasAuthRules = hasAuthRules(mutation.QueryField(), authRw) - return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw), nil + return rewriteAsQueryByIds(mutation.QueryField(), uids, authRw, mutation.Alias()), nil } func (arw *AddRewriter) MutatedRootUIDs( @@ -998,7 +998,8 @@ func RewriteUpsertQueryFromMutation( addTypeFunc(dgQuery[0], m.MutatedType().DgraphName()) } - _ = addFilter(dgQuery[0], m.MutatedType(), filter) + _, varQry := addFilter(dgQuery[0], m.MutatedType(), filter, authRw, m.Alias()) + dgQuery = append(dgQuery, varQry...) } else { // It means this is called from upsert with Add mutation. // nodeID will be uid of the node to be upserted. We add UID func @@ -1115,7 +1116,7 @@ func (drw *deleteRewriter) Rewrite( } // these queries are responsible for querying the queryField - queryFieldQry := rewriteAsQuery(queryField, queryAuthRw) + queryFieldQry := rewriteAsQuery(queryField, queryAuthRw, MutationQueryVar) // we don't want the `x` query to show up in GraphQL JSON response while querying the query // field. So, need to make it `var` query and remove any children from it as there can be diff --git a/graphql/resolve/query.go b/graphql/resolve/query.go index f657663c30c..51968c113f5 100644 --- a/graphql/resolve/query.go +++ b/graphql/resolve/query.go @@ -127,7 +127,6 @@ func (qr *queryResolver) rewriteAndExecute(ctx context.Context, query schema.Que query.ResponseName())) } qry := dgraph.AsString(dgQuery) - queryTimer := newtimer(ctx, &dgraphQueryDuration.OffsetDuration) queryTimer.Start() resp, err := qr.executor.Execute(ctx, &dgoapi.Request{Query: qry, ReadOnly: true}, query) diff --git a/graphql/resolve/query_rewriter.go b/graphql/resolve/query_rewriter.go index 975ae69f47c..60b233e4d90 100644 --- a/graphql/resolve/query_rewriter.go +++ b/graphql/resolve/query_rewriter.go @@ -157,7 +157,7 @@ func (qr *queryRewriter) Rewrite( case schema.SimilarByEmbeddingQuery: return rewriteAsSimilarByEmbeddingQuery(gqlQuery, authRw), nil case schema.FilterQuery: - return rewriteAsQuery(gqlQuery, authRw), nil + return rewriteAsQuery(gqlQuery, authRw, gqlQuery.Alias()), nil case schema.PasswordQuery: return passwordQuery(gqlQuery, authRw) case schema.AggregateQuery: @@ -257,7 +257,8 @@ func aggregateQuery(query schema.Query, authRw *authRewriter) []*dql.GraphQuery // Add filter filter, _ := query.ArgValue("filter").(map[string]interface{}) - _ = addFilter(dgQuery[0], mainType, filter) + _, varQry := addFilter(dgQuery[0], mainType, filter, authRw, query.Alias()) + dgQuery = append(dgQuery, varQry...) dgQuery = authRw.addAuthQueries(mainType, dgQuery, rbac) @@ -451,7 +452,8 @@ func addUID(dgQuery *dql.GraphQuery) { func rewriteAsQueryByIds( field schema.Field, uids []uint64, - authRw *authRewriter) []*dql.GraphQuery { + authRw *authRewriter, + queryName string) []*dql.GraphQuery { if field == nil { return nil } @@ -475,7 +477,8 @@ func rewriteAsQueryByIds( addUIDFunc(dgQuery[0], intersection(ids, uids)) } - addArgumentsToField(dgQuery[0], field) + includedQueries := addArgumentsToField(dgQuery[0], field, authRw, queryName) + dgQuery = append(dgQuery, includedQueries...) // The function getQueryByIds is called for passwordQuery or fetching query result types // after making a mutation. In both cases, we want the selectionSet to use the `query` auth @@ -501,11 +504,15 @@ func rewriteAsQueryByIds( // addArgumentsToField adds various different arguments to a field, such as // filter, order and pagination. -func addArgumentsToField(dgQuery *dql.GraphQuery, field schema.Field) { +func addArgumentsToField(dgQuery *dql.GraphQuery, + field schema.Field, + auth *authRewriter, + queryName string) []*dql.GraphQuery { filter, _ := field.ArgValue("filter").(map[string]interface{}) - _ = addFilter(dgQuery, field.Type(), filter) + _, varQry := addFilter(dgQuery, field.Type(), filter, auth, queryName) addOrder(dgQuery, field) addPagination(dgQuery, field) + return varQry } func addTopLevelTypeFilter(query *dql.GraphQuery, field schema.Field) { @@ -545,7 +552,7 @@ func rewriteAsGet( } if len(xidArgToVal) == 0 { - dgQuery = rewriteAsQueryByIds(query, []uint64{uid}, auth) + dgQuery = rewriteAsQueryByIds(query, []uint64{uid}, auth, query.Alias()) // Add the type filter to the top level get query. When the auth has been written into the // query the top level get query may be present in query's children. @@ -774,7 +781,7 @@ func rewriteAsSimilarByIdQuery( }, Order: []*pb.Order{{Attr: "val(distance)", Desc: false}}, } - addArgumentsToField(sortQuery, query) + addArgumentsToField(sortQuery, query, auth, query.Alias()) dgQuery = append(dgQuery, aggQuery, similarQuery, sortQuery) return dgQuery @@ -803,7 +810,7 @@ func rewriteAsSimilarByIdQuery( func rewriteAsSimilarByEmbeddingQuery( query schema.Query, auth *authRewriter) []*dql.GraphQuery { - dgQuery := rewriteAsQuery(query, auth) + dgQuery := rewriteAsQuery(query, auth, query.Alias()) // Remember dgQuery[0].Children as result type for the last block // in the rewritten query @@ -951,13 +958,15 @@ func addCommonRules( return []*dql.GraphQuery{dgQuery}, rbac } -func rewriteAsQuery(field schema.Field, authRw *authRewriter) []*dql.GraphQuery { +func rewriteAsQuery(field schema.Field, authRw *authRewriter, queryName string) []*dql.GraphQuery { dgQuery, rbac := addCommonRules(field, field.Type(), authRw) if rbac == schema.Negative { return dgQuery } - addArgumentsToField(dgQuery[0], field) + varQry := addArgumentsToField(dgQuery[0], field, authRw, queryName) + dgQuery = append(dgQuery, varQry...) + selectionAuth := addSelectionSetFrom(dgQuery[0], field, authRw) // we don't need to query uid for auth queries, as they always have at least one field in their // selection set. @@ -1321,14 +1330,15 @@ func (authRw *authRewriter) rewriteRuleNode( // build // Todo2 as var(func: uid(Todo1)) @cascade { ...auth query 1... } varName := authRw.varGen.Next(typ, "", "", authRw.isWritingAuth) - r1 := rewriteAsQuery(qry, authRw) + r1 := rewriteAsQuery(qry, authRw, varName) r1[0].Var = varName r1[0].Attr = "var" if len(r1[0].Cascade) == 0 { r1[0].Cascade = append(r1[0].Cascade, "__all__") } - return []*dql.GraphQuery{r1[0]}, &dql.FilterTree{ + // return all queries, including the nested var queries. + return r1, &dql.FilterTree{ Func: &dql.Function{ Name: "uid", Args: []dql.Arg{{Value: varName}}, @@ -1465,7 +1475,7 @@ func buildAggregateFields( // Filter for aggregate Fields. This is added to all count aggregate fields // and mainField fieldFilter, _ := f.ArgValue("filter").(map[string]interface{}) - _ = addFilter(mainField, constructedForType, fieldFilter) + _, varQry := addFilter(mainField, constructedForType, fieldFilter, auth, f.Alias()) // Add type filter in case the Dgraph predicate for which the aggregate // field belongs to is a reverse edge @@ -1489,7 +1499,7 @@ func buildAggregateFields( Attr: "count(" + constructedForDgraphPredicate + ")", } // Add filter to count aggregation field. - _ = addFilter(aggregateChild, constructedForType, fieldFilter) + addFilter(aggregateChild, constructedForType, fieldFilter, auth, f.Alias()) // Add type filter in case the Dgraph predicate for which the aggregate // field belongs to is a reverse edge @@ -1597,6 +1607,7 @@ func buildAggregateFields( // not added to them. aggregateChildren = append(aggregateChildren, otherAggregateChildren...) retAuthQueries = append(retAuthQueries, fieldAuth...) + retAuthQueries = append(retAuthQueries, varQry...) return aggregateChildren, retAuthQueries } @@ -1681,10 +1692,13 @@ func addSelectionSetFrom( filter, _ := f.ArgValue("filter").(map[string]interface{}) // if this field has been filtered out by the filter, then don't add it in DQL query - if includeField := addFilter(child, f.Type(), filter); !includeField { + includeField, varQry := addFilter(child, f.Type(), filter, auth, f.Alias()) + if !includeField { continue } + authQueries = append(authQueries, varQry...) + // Add type filter in case the Dgraph predicate is a reverse edge if strings.HasPrefix(f.DgraphPredicate(), "~") { addTypeFilter(child, f.Type()) @@ -1889,9 +1903,16 @@ func idFilter(filter map[string]interface{}, idField schema.FieldDefinition) []u // addFilter adds a filter to the input DQL query. It returns false if the field for which the // filter was specified should not be included in the DQL query. // Currently, it would only be false for a union field when no memberTypes are queried. -func addFilter(q *dql.GraphQuery, typ schema.Type, filter map[string]interface{}) bool { +func addFilter(q *dql.GraphQuery, + typ schema.Type, + filter map[string]interface{}, + auth *authRewriter, + queryName string) (bool, []*dql.GraphQuery) { + + varQry := []*dql.GraphQuery{} + if len(filter) == 0 { - return true + return true, varQry } // There are two cases here. @@ -1913,18 +1934,19 @@ func addFilter(q *dql.GraphQuery, typ schema.Type, filter map[string]interface{} } if typ.IsUnion() { - if filter, includeField := buildUnionFilter(typ, filter); includeField { + if filter, varq, includeField := buildUnionFilter(typ, filter, auth, queryName); includeField { q.Filter = filter + varQry = varq } else { - return false + return false, varQry } } else { - q.Filter = buildFilter(typ, filter) + q.Filter, varQry = buildFilter(typ, filter, auth, queryName) } if filterAtRoot { addTypeFilter(q, typ) } - return true + return true, varQry } // buildFilter builds a Dgraph dql.FilterTree from a GraphQL 'filter' arg. @@ -1935,6 +1957,8 @@ func addFilter(q *dql.GraphQuery, typ schema.Type, filter map[string]interface{} // filter: { title: { anyofterms: "GraphQL" }, isPublished: true, ... } // or // filter: { title: { anyofterms: "GraphQL" }, and: { not: { ... } } } +// or +// filter: { : { ... }, ... } // etc // // typ is the GraphQL type we are filtering on, and is needed to turn for example @@ -1953,8 +1977,12 @@ func addFilter(q *dql.GraphQuery, typ schema.Type, filter map[string]interface{} // ATM those will probably generate junk that might cause a Dgraph error. And // bubble back to the user as a GraphQL error when the query fails. Really, // they should fail query validation and never get here. -func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree { +func buildFilter(typ schema.Type, + filter map[string]interface{}, + auth *authRewriter, + queryName string) (*dql.FilterTree, []*dql.GraphQuery) { + var varQry []*dql.GraphQuery var ands []*dql.FilterTree var or *dql.FilterTree // Get a stable ordering so we generate the same thing each time. @@ -1970,6 +1998,10 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree if filter[field] == nil { continue } + + // derive variable name for nested objects + qn := queryName + "_" + field + switch field { // In 'and', 'or' and 'not' cases, filter[field] must be a map[string]interface{} @@ -1988,12 +2020,14 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree // ... and: [{}] switch v := filter[field].(type) { case map[string]interface{}: - ft := buildFilter(typ, v) + ft, qs := buildFilter(typ, v, auth, qn) ands = append(ands, ft) + varQry = append(varQry, qs...) case []interface{}: for _, obj := range v { - ft := buildFilter(typ, obj.(map[string]interface{})) + ft, qs := buildFilter(typ, obj.(map[string]interface{}), auth, qn) ands = append(ands, ft) + varQry = append(varQry, qs...) } } case "or": @@ -2008,12 +2042,15 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree // ... or: [{}] switch v := filter[field].(type) { case map[string]interface{}: - or = buildFilter(typ, v) + cond, qs := buildFilter(typ, v, auth, qn) + or = cond + varQry = append(varQry, qs...) case []interface{}: ors := make([]*dql.FilterTree, 0, len(v)) for _, obj := range v { - ft := buildFilter(typ, obj.(map[string]interface{})) + ft, qs := buildFilter(typ, obj.(map[string]interface{}), auth, qn) ors = append(ors, ft) + varQry = append(varQry, qs...) } or = &dql.FilterTree{ Child: ors, @@ -2025,13 +2062,85 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree // we are here ^^ // -> // @filter(anyofterms(Post.title, "GraphQL") AND NOT eq(Post.isPublished, true)) - not := buildFilter(typ, filter[field].(map[string]interface{})) + not, qs := buildFilter(typ, filter[field].(map[string]interface{}), auth, qn) ands = append(ands, &dql.FilterTree{ Op: "not", Child: []*dql.FilterTree{not}, }) + varQry = append(varQry, qs...) default: + // Handle nested object filtering + // + // filter: { : { ... }, ... } + // we are here ^^ + // -> + // var() @filter(){ + // nested_field_name as + // } + // root() @filter(var(nested_field_name)) + fd := typ.Field(field) + if fd != nil && fd.HasSearchDirective() { + + if inv := fd.Inverse(); inv != nil { + + fil, qs := buildFilter(fd.Type(), filter[field].(map[string]interface{}), auth, qn) + varQry = append(varQry, qs...) + + // add the uids of the nested object + ands = append(ands, &dql.FilterTree{ + Op: "and", + Child: []*dql.FilterTree{{ + Func: &dql.Function{ + Name: "uid", + Args: []dql.Arg{{Value: qn}}, + }, + }}, + }) + + // generate filter var query for nested object + nestedQry := &dql.GraphQuery{ + Attr: "var", + Func: &dql.Function{ + Name: "type", + Args: []dql.Arg{{Value: fd.Type().Name()}}, + }, + Filter: fil, + Children: []*dql.GraphQuery{{ + Attr: inv.DgraphPredicate(), + Var: qn, + }}, + } + + // add auth queries to nested field + nestedQrys := []*dql.GraphQuery{nestedQry} + + if !auth.isWritingAuth { + wr := &authRewriter{ + authVariables: auth.authVariables, + varGen: auth.varGen, + selector: auth.selector, + parentVarName: qn + "Root", + isWritingAuth: auth.isWritingAuth, + } + + rbac := wr.evaluateStaticRules(fd.Type()) + if rbac == schema.Uncertain { + nestedQrys = wr.addAuthQueries(fd.Type(), nestedQrys, rbac) + } else if rbac == schema.Negative { + nestedQry.Attr = "var()" + nestedQry.Var = qn + nestedQry.Func = nil + nestedQry.Filter = nil + nestedQry.Children = nil + } + } + + varQry = append(varQry, nestedQrys...) + continue + } + } + //// It's a base case like: //// title: { anyofterms: "GraphQL" } -> anyofterms(Post.title: "GraphQL") //// numLikes: { between : { min : 10, max:100 }} @@ -2048,7 +2157,9 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree // the filters with null values will be ignored in query rewriting. if fn == "eq" { hasFilterMap := map[string]interface{}{"not": map[string]interface{}{"has": []interface{}{field}}} - ands = append(ands, buildFilter(typ, hasFilterMap)) + ft, qs := buildFilter(typ, hasFilterMap, auth, qn) + ands = append(ands, ft) + varQry = append(varQry, qs...) } continue } @@ -2192,7 +2303,7 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree var andFt *dql.FilterTree if len(ands) == 0 { - return or + return or, varQry } else if len(ands) == 1 { andFt = ands[0] } else if len(ands) > 1 { @@ -2203,13 +2314,13 @@ func buildFilter(typ schema.Type, filter map[string]interface{}) *dql.FilterTree } if or == nil { - return andFt + return andFt, varQry } return &dql.FilterTree{ Op: "or", Child: []*dql.FilterTree{andFt, or}, - } + }, varQry } func buildHasFilterList(typ schema.Type, fieldsSlice []interface{}) []*dql.FilterTree { @@ -2271,12 +2382,17 @@ func buildMultiPolygon(multipolygon map[string]interface{}, buf *bytes.Buffer) { x.Check2(buf.WriteString("]")) } -func buildUnionFilter(typ schema.Type, filter map[string]interface{}) (*dql.FilterTree, bool) { +func buildUnionFilter(typ schema.Type, + filter map[string]interface{}, + auth *authRewriter, + queryName string) (*dql.FilterTree, []*dql.GraphQuery, bool) { + + var varQry []*dql.GraphQuery memberTypesList, ok := filter["memberTypes"].([]interface{}) // if memberTypes was specified to be an empty list like: { memberTypes: [], ...}, // then we don't need to include the field, on which the filter was specified, in the query. if ok && len(memberTypesList) == 0 { - return nil, false + return nil, varQry, false } ft := &dql.FilterTree{ @@ -2294,11 +2410,13 @@ func buildUnionFilter(typ schema.Type, filter map[string]interface{}) (*dql.Filt memberTypeFt = &dql.FilterTree{Func: buildTypeFunc(memberType.DgraphName())} } else { // else we need to query only the nodes which match the filter for that member type + ft, qs := buildFilter(memberType, memberTypeFilter, auth, queryName) + varQry = qs memberTypeFt = &dql.FilterTree{ Op: "and", Child: []*dql.FilterTree{ {Func: buildTypeFunc(memberType.DgraphName())}, - buildFilter(memberType, memberTypeFilter), + ft, }, } } @@ -2306,7 +2424,7 @@ func buildUnionFilter(typ schema.Type, filter map[string]interface{}) (*dql.Filt } // return true because we want to include the field with filter in query - return ft, true + return ft, varQry, true } func maybeQuoteArg(fn string, arg interface{}) string { diff --git a/graphql/resolve/query_test.yaml b/graphql/resolve/query_test.yaml index ab15599d020..fe93d817b3c 100644 --- a/graphql/resolve/query_test.yaml +++ b/graphql/resolve/query_test.yaml @@ -3517,4 +3517,142 @@ dgraph.uid : uid ProjectDotProduct.vector_distance : val(distance) } - } \ No newline at end of file + } + +- name: "query nested type with interface" + gqlquery: | + query { + queryNested_X(filter: {s: {eq: ""}, y: {s: {eq: ""}}}) { + s + } + } + dgquery: |- + query { + queryNested_X(func: type(Nested_X)) @filter((eq(Nested_X.s, "") AND (uid(queryNested_X_y)))) { + Nested_X.s : Nested_X.s + dgraph.uid : uid + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "")) { + queryNested_X_y as Nested_Z.x + } + } + +- name: "query nested type from type with interface" + gqlquery: | + query { + queryNested_Y(filter: {s: {eq: ""}, x: {s: {eq: ""}}}) { + s + } + } + dgquery: |- + query { + queryNested_Y(func: type(Nested_Y)) @filter((eq(Nested_Y.s, "") AND (uid(queryNested_Y_x)))) { + Nested_Y.s : Nested_Y.s + dgraph.uid : uid + } + var(func: type(Nested_X)) @filter(eq(Nested_X.s, "")) { + queryNested_Y_x as Nested_X.y + } + } + +- name: "query nested type from interface" + gqlquery: | + query { + queryNested_Z(filter: { x: {s: {eq: ""}}}) { + x { s } + } + } + dgquery: |- + query { + queryNested_Z(func: type(Nested_Z)) @filter((uid(queryNested_Z_x))) { + dgraph.type + Nested_Z.x : Nested_Z.x { + Nested_X.s : Nested_X.s + dgraph.uid : uid + } + dgraph.uid : uid + } + var(func: type(Nested_X)) @filter(eq(Nested_X.s, "")) { + queryNested_Z_x as Nested_X.y + } + } +- name: "query deeply nested object" + gqlquery: | + query { + queryNested_Z(filter: { x: { y: {s: {eq: ""}}}}) { + x { s } + } + } + dgquery: |- + query { + queryNested_Z(func: type(Nested_Z)) @filter((uid(queryNested_Z_x))) { + dgraph.type + Nested_Z.x : Nested_Z.x { + Nested_X.s : Nested_X.s + dgraph.uid : uid + } + dgraph.uid : uid + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "")) { + queryNested_Z_x_y as Nested_Z.x + } + var(func: type(Nested_X)) @filter((uid(queryNested_Z_x_y))) { + queryNested_Z_x as Nested_X.y + } + } + +- name: "query nested with AND condition" + gqlquery: | + query { + queryNested_X(filter: {s: {eq: ""}, and: { y: {s: {eq: ""}}}}) { + s + } + } + dgquery: |- + query { + queryNested_X(func: type(Nested_X)) @filter(((uid(queryNested_X_and_y)) AND eq(Nested_X.s, ""))) { + Nested_X.s : Nested_X.s + dgraph.uid : uid + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "")) { + queryNested_X_and_y as Nested_Z.x + } + } + +- name: "query nested with OR condition" + gqlquery: | + query { + queryNested_X(filter: {s: {eq: ""}, or: { y: {s: {eq: ""}}}}) { + s + } + } + dgquery: |- + query { + queryNested_X(func: type(Nested_X)) @filter((eq(Nested_X.s, "") OR ((uid(queryNested_X_or_y))))) { + Nested_X.s : Nested_X.s + dgraph.uid : uid + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "")) { + queryNested_X_or_y as Nested_Z.x + } + } + +- name: "query nested with aggregate function" + gqlquery: | + query { + aggregateNested_X(filter: {s: {eq: ""}, or: { y: {s: {eq: ""}}}}) { + sMax + } + } + dgquery: |- + query { + aggregateNested_X() { + Nested_XAggregateResult.sMax : max(val(sVar)) + } + var(func: type(Nested_X)) @filter((eq(Nested_X.s, "") OR ((uid(aggregateNested_X_or_y))))) { + sVar as Nested_X.s + } + var(func: type(Nested_Y)) @filter(eq(Nested_Y.s, "")) { + aggregateNested_X_or_y as Nested_Z.x + } + } diff --git a/graphql/resolve/schema.graphql b/graphql/resolve/schema.graphql index 93b928dbd01..2ce40b84f6e 100644 --- a/graphql/resolve/schema.graphql +++ b/graphql/resolve/schema.graphql @@ -548,3 +548,19 @@ type ProjectDotProduct { title: String description_v: [Float!] @embedding @search(by: ["hnsw(metric: dotproduct, exponent: 4)"]) } + +""" +This types are used to validate nested filting. +""" +type Nested_X { + s: String @search(by: [hash]) + y: Nested_Y @hasInverse(field: x) @search +} + +type Nested_Y implements Nested_Z{ + s: String @search(by: [hash]) +} + +interface Nested_Z { + x: Nested_X @search +} \ No newline at end of file diff --git a/graphql/schema/gqlschema.go b/graphql/schema/gqlschema.go index d6d32ed5c81..778671d37e5 100644 --- a/graphql/schema/gqlschema.go +++ b/graphql/schema/gqlschema.go @@ -1480,6 +1480,13 @@ func addPaginationArguments(fld *ast.FieldDefinition) { // getFilterTypes converts search arguments of a field to graphql filter types. func getFilterTypes(schema *ast.Schema, fld *ast.FieldDefinition, filterName string) []string { + + // Return the object filter if the field is an object that is searchable. + fldType := schema.Types[fld.Type.Name()] + if isCustomType(schema, fld.Type) && hasFilterable(fldType) && hasSearchDirective(fld) { + return []string{fld.Type.Name() + "Filter"} + } + searchArgs := getSearchArgs(fld) filterNames := make([]string, len(searchArgs)) diff --git a/graphql/schema/gqlschema_test.yml b/graphql/schema/gqlschema_test.yml index a82500b4426..00e50b2bd7b 100644 --- a/graphql/schema/gqlschema_test.yml +++ b/graphql/schema/gqlschema_test.yml @@ -404,7 +404,7 @@ invalid_schemas: ] - - name: "Search will error on type that can't have the @search" + name: "Search will error on type that require @hasInverse directive" input: | type X { y: Y @search @@ -413,8 +413,22 @@ invalid_schemas: y: String } errlist: [ - {"message": "Type X; Field y: has the @search directive but fields of type Y - can't have the @search directive.", + {"message": "Type X; Field y: has the @search directive for type Y but also requires + the @hasInverse directive.", + "locations":[{"line":2, "column":9}]} + ] + - + name: "Search will error on interface that require @hasInverse directive" + input: | + type X { + y: Y @search + } + interface Y { + y: String + } + errlist: [ + {"message": "Type X; Field y: has the @search directive for type Y but also requires + the @hasInverse directive.", "locations":[{"line":2, "column":9}]} ] @@ -3270,6 +3284,38 @@ valid_schemas: A } + - + name: "Correct search on object type" + input: | + type X { + y: Y @hasInverse(field: x) @search + y2: Y @search + y3: Y @hasInverse(field: x3) @search + } + type Y { + x: X + x2: X @hasInverse(field: y2) + x3: X @hasInverse(field: y3) @search + } + + - + name: "Correct search on interface type" + input: | + type X { + y: Y @hasInverse(field: x) @search + y2: Y @search + y3: Y @hasInverse(field: x3) @search + y4: Y + y5: Y @hasInverse(field: x5) + } + interface Y { + x: X + x2: X @hasInverse(field: y2) + x3: X @hasInverse(field: y3) @search + x4: X @hasInverse(field: y4) @search + x5: X @search + } + - name: "dgraph directive with correct reverse field works" input: | diff --git a/graphql/schema/rules.go b/graphql/schema/rules.go index c0774dce714..02b31607832 100644 --- a/graphql/schema/rules.go +++ b/graphql/schema/rules.go @@ -1075,6 +1075,21 @@ func searchValidation( return nil } + // If the field is an object, it is require to have an inverse edge for filtering. + // It's not enough to just check for the @hasInverse directive as it + // may be defined in the inverse type. + if isCustomType(sch, field.Type) { + if !hasInverse(sch, typ, field) { + errs = append(errs, gqlerror.ErrorPosf( + dir.Position, + "Type %s; Field %s: has the @search directive for type %s "+ + "but also requires the @hasInverse directive.", + typ.Name, field.Name, field.Type.Name())) + return errs + } + return nil + } + errs = append(errs, gqlerror.ErrorPosf( dir.Position, "Type %s; Field %s: has the @search directive but fields of type %s "+ diff --git a/graphql/schema/wrappers.go b/graphql/schema/wrappers.go index 0d8d6362756..d00b634c2ea 100644 --- a/graphql/schema/wrappers.go +++ b/graphql/schema/wrappers.go @@ -285,6 +285,7 @@ type FieldDefinition interface { IsID() bool IsExternal() bool HasIDDirective() bool + HasSearchDirective() bool HasEmbeddingDirective() bool EmbeddingSearchMetric() string HasInterfaceArg() bool @@ -2405,6 +2406,18 @@ func hasEmbeddingDirective(fd *ast.FieldDefinition) bool { return id != nil } +func (fd *fieldDefinition) HasSearchDirective() bool { + if fd.fieldDef == nil { + return false + } + return hasSearchDirective(fd.fieldDef) +} + +func hasSearchDirective(fd *ast.FieldDefinition) bool { + id := fd.Directives.ForName(searchDirective) + return id != nil +} + func (fd *fieldDefinition) HasInterfaceArg() bool { if fd.fieldDef == nil { return false @@ -2429,6 +2442,47 @@ func hasInterfaceArg(fd *ast.FieldDefinition) bool { return false } +// hasInverse checks if an inverse predicate is configured for an object. +func hasInverse(sch *ast.Schema, typ *ast.Definition, fd *ast.FieldDefinition) bool { + // check that the @hasInverse directive is provided + id := fd.Directives.ForName(inverseDirective) + if id != nil { + return true + } + + // also check the reference type. + refType := sch.Types[fd.Type.Name()] + for _, refField := range refType.Fields { + + refFieldType := sch.Types[refField.Type.Name()] + if refField.Type.Name() != typ.Name && !typ.OneOf(refFieldType.Interfaces...) { + continue + } + + refFieldDir := refField.Directives.ForName(inverseDirective) + if refFieldDir == nil { + continue + } + + invField := refFieldDir.Arguments.ForName("field") + if invField == nil { + continue + } + + invFieldName := invField.Value.Raw + if invFieldName == fd.Name { + return true + } + } + return false +} + +func isCustomType(sch *ast.Schema, t *ast.Type) bool { + _, ok := inbuiltTypeToDgraph[t.Name()] + return !ok && (sch.Types[t.Name()].Kind == ast.Object || + sch.Types[t.Name()].Kind == ast.Interface) +} + func isID(fd *ast.FieldDefinition) bool { return fd.Type.Name() == "ID" } @@ -2451,29 +2505,69 @@ func (fd *fieldDefinition) ParentType() Type { func (fd *fieldDefinition) Inverse() FieldDefinition { - invDirective := fd.fieldDef.Directives.ForName(inverseDirective) - if invDirective == nil { + if fd.fieldDef == nil { return nil } - invFieldArg := invDirective.Arguments.ForName(inverseArg) - if invFieldArg == nil { - return nil // really not possible - } + invDirective := fd.fieldDef.Directives.ForName(inverseDirective) + if invDirective != nil { - typeWrapper := fd.Type() - // typ must exist if the schema passed GQL validation - typ := fd.inSchema.schema.Types[typeWrapper.Name()] + invFieldArg := invDirective.Arguments.ForName(inverseArg) + if invFieldArg == nil { + return nil // really not possible + } - // fld must exist if the schema passed our validation - fld := typ.Fields.ForName(invFieldArg.Value.Raw) + typeWrapper := fd.Type() + // typ must exist if the schema passed GQL validation + typ := fd.inSchema.schema.Types[typeWrapper.Name()] - return &fieldDefinition{ - fieldDef: fld, - inSchema: fd.inSchema, - dgraphPredicate: fd.dgraphPredicate, - parentType: typeWrapper, + // fld must exist if the schema passed our validation + fld := typ.Fields.ForName(invFieldArg.Value.Raw) + + return &fieldDefinition{ + fieldDef: fld, + inSchema: fd.inSchema, + dgraphPredicate: fd.dgraphPredicate, + parentType: typeWrapper, + } + } else { + // also check the inverse type especially when querying from an interface + // and not the implemented type. In this case the interface won't have the + // inverse field. + typeWrapper := fd.Type() + // typ must exist if the schema passed GQL validation + typ := fd.inSchema.schema.Types[typeWrapper.Name()] + + for _, refField := range typ.Fields { + + refFieldType := fd.inSchema.schema.Types[refField.Type.Name()] + if refField.Type.Name() != typ.Name && + !fd.inSchema.schema.Types[fd.ParentType().Name()].OneOf(refFieldType.Interfaces...) { + continue + } + + refFieldDir := refField.Directives.ForName(inverseDirective) + if refFieldDir == nil { + continue + } + + invField := refFieldDir.Arguments.ForName("field") + if invField == nil { + continue + } + + invFieldName := invField.Value.Raw + if invFieldName == fd.Name() { + return &fieldDefinition{ + fieldDef: refField, + inSchema: fd.inSchema, + dgraphPredicate: fd.dgraphPredicate, + parentType: typeWrapper, + } + } + } } + return nil } func (fd *fieldDefinition) WithMemberType(memberType string) FieldDefinition {