Skip to content

Commit

Permalink
Update query level incrementing + result aggregation logic so it alwa…
Browse files Browse the repository at this point in the history
…ys returns the correct implicit flag for results
  • Loading branch information
kkajla12 committed Mar 18, 2024
1 parent 31d2a4e commit 49fab0e
Show file tree
Hide file tree
Showing 2 changed files with 301 additions and 25 deletions.
78 changes: 53 additions & 25 deletions pkg/authz/query/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
for _, matchedWarrant := range matchedWarrants {
if matchedWarrant.Subject.Relation != "" {
// handle group warrants
userset, err := svc.query(ctx, Query{
subset, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectSubjects: &SelectSubjects{
Relations: []string{matchedWarrant.Subject.Relation},
Expand All @@ -352,8 +352,8 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
return nil, err
}

for res := userset.List(); res != nil; res = res.Next() {
if res.ObjectType != query.SelectObjects.WhereSubject.Type || res.ObjectId != query.SelectObjects.WhereSubject.Id {
for sub := subset.List(); sub != nil; sub = sub.Next() {
if sub.ObjectType != query.SelectObjects.WhereSubject.Type || sub.ObjectId != query.SelectObjects.WhereSubject.Id {
continue
}

Expand All @@ -367,27 +367,29 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res

for _, w := range expandedWildcardWarrants {
if w.ObjectId != warrant.Wildcard {
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(w.ObjectType, w.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
}
} else {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
}
} else if query.SelectObjects.WhereSubject == nil ||
(matchedWarrant.Subject.ObjectType == query.SelectObjects.WhereSubject.Type &&
matchedWarrant.Subject.ObjectId == query.SelectObjects.WhereSubject.Id) {
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.ObjectType, matchedWarrant.ObjectId, relation, matchedWarrant, false)
}
}

if query.Expand {
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
if err != nil {
return nil, err
}

resultSet = resultSet.Union(implicitResultSet)
for res := implicitResultSet.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand Down Expand Up @@ -417,7 +419,7 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
for _, matchedWarrant := range matchedWarrants {
if matchedWarrant.Subject.Relation != "" {
// handle group warrants
userset, err := svc.query(ctx, Query{
subset, err := svc.query(ctx, Query{
Expand: query.Expand,
SelectSubjects: &SelectSubjects{
Relations: []string{matchedWarrant.Subject.Relation},
Expand All @@ -433,21 +435,23 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
return nil, err
}

for res := userset.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, matchedWarrant, level > 0)
for sub := subset.List(); sub != nil; sub = sub.Next() {
resultSet.Add(sub.ObjectType, sub.ObjectId, relation, matchedWarrant, sub.IsImplicit || level > 0)
}
} else if query.SelectSubjects.SubjectTypes[0] == matchedWarrant.Subject.ObjectType {
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, level > 0)
resultSet.Add(matchedWarrant.Subject.ObjectType, matchedWarrant.Subject.ObjectId, relation, matchedWarrant, false)
}
}

if query.Expand {
implicitResultSet, err := svc.queryRule(ctx, query, level, objectTypeDef.Relations[relation])
implicitResultSet, err := svc.queryRule(ctx, query, level+1, relation, objectTypeDef.Relations[relation])
if err != nil {
return nil, err
}

return resultSet.Union(implicitResultSet), nil
for res := implicitResultSet.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand All @@ -456,14 +460,14 @@ func (svc QueryService) query(ctx context.Context, query Query, level int) (*Res
}
}

func (svc QueryService) queryRule(ctx context.Context, query Query, level int, rule objecttype.RelationRule) (*ResultSet, error) {
func (svc QueryService) queryRule(ctx context.Context, query Query, level int, relation string, rule objecttype.RelationRule) (*ResultSet, error) {
switch rule.InheritIf {
case "":
return NewResultSet(), nil
case objecttype.InheritIfAllOf:
var resultSet *ResultSet
for _, r := range rule.Rules {
res, err := svc.queryRule(ctx, query, level, r)
res, err := svc.queryRule(ctx, query, level, relation, r)
if err != nil {
return nil, err
}
Expand All @@ -479,7 +483,7 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
case objecttype.InheritIfAnyOf:
var resultSet *ResultSet
for _, r := range rule.Rules {
res, err := svc.queryRule(ctx, query, level, r)
res, err := svc.queryRule(ctx, query, level, relation, r)
if err != nil {
return nil, err
}
Expand All @@ -498,15 +502,25 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
switch {
case query.SelectObjects != nil:
if rule.OfType == "" && rule.WithRelation == "" {
return svc.query(ctx, Query{
results, err := svc.query(ctx, Query{
Expand: true,
SelectObjects: &SelectObjects{
ObjectTypes: query.SelectObjects.ObjectTypes,
WhereSubject: query.SelectObjects.WhereSubject,
Relations: []string{rule.InheritIf},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet := NewResultSet()
for res := results.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}

return resultSet, nil
} else {
indirectWarrants, err := svc.listWarrants(ctx, warrant.FilterParams{
ObjectType: rule.OfType,
Expand Down Expand Up @@ -535,27 +549,39 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
Relations: []string{rule.WithRelation},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet = resultSet.Union(inheritedResults)
for res := inheritedResults.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
}
case query.SelectSubjects != nil:
if rule.OfType == "" && rule.WithRelation == "" {
return svc.query(ctx, Query{
results, err := svc.query(ctx, Query{
Expand: true,
SelectSubjects: &SelectSubjects{
SubjectTypes: query.SelectSubjects.SubjectTypes,
Relations: []string{rule.InheritIf},
ForObject: query.SelectSubjects.ForObject,
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet := NewResultSet()
for res := results.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}

return resultSet, nil
} else {
userset, err := svc.listWarrants(ctx, warrant.FilterParams{
ObjectType: query.SelectSubjects.ForObject.Type,
Expand Down Expand Up @@ -584,12 +610,14 @@ func (svc QueryService) queryRule(ctx context.Context, query Query, level int, r
},
},
Context: query.Context,
}, level+1)
}, 0)
if err != nil {
return nil, err
}

resultSet = resultSet.Union(subset)
for res := subset.List(); res != nil; res = res.Next() {
resultSet.Add(res.ObjectType, res.ObjectId, relation, res.Warrant, res.IsImplicit || level > 0)
}
}

return resultSet, nil
Expand Down
Loading

0 comments on commit 49fab0e

Please sign in to comment.