Skip to content

Commit

Permalink
Rewrite + Improve Query Execution (#302)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkajla12 authored Mar 5, 2024
1 parent d7af249 commit 6030794
Show file tree
Hide file tree
Showing 16 changed files with 1,378 additions and 768 deletions.
66 changes: 33 additions & 33 deletions pkg/authz/check/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,16 @@ func (svc CheckService) getWithPolicyMatch(ctx context.Context, checkPipeline *p
defer checkPipeline.ReleaseServiceLock()

listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
listParams.WithLimit(MaxWarrants)
warrantSpecs, _, _, err := svc.warrantSvc.List(
ctx,
warrant.FilterParams{
ObjectType: []string{spec.ObjectType},
ObjectId: []string{spec.ObjectId},
Relation: []string{spec.Relation},
SubjectType: []string{spec.Subject.ObjectType},
SubjectId: []string{spec.Subject.ObjectId},
SubjectRelation: []string{spec.Subject.Relation},
ObjectType: spec.ObjectType,
ObjectId: spec.ObjectId,
Relation: spec.Relation,
SubjectType: spec.Subject.ObjectType,
SubjectId: spec.Subject.ObjectId,
SubjectRelation: spec.Subject.Relation,
},
listParams,
)
Expand All @@ -91,16 +91,16 @@ func (svc CheckService) getWithPolicyMatch(ctx context.Context, checkPipeline *p
}

// if a warrant without a policy is found, match it
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
return &warrant, nil
for _, w := range warrantSpecs {
if w.Policy == "" {
return &w, nil
}
}

for _, warrant := range warrantSpecs {
if warrant.Policy != "" {
if policyMatched := evalWarrantPolicy(warrant, spec.Context); policyMatched {
return &warrant, nil
for _, w := range warrantSpecs {
if w.Policy != "" {
if policyMatched := evalWarrantPolicy(w, spec.Context); policyMatched {
return &w, nil
}
}
}
Expand All @@ -123,13 +123,13 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, checkPipeline *
}

listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
listParams.WithLimit(MaxWarrants)
warrantSpecs, _, _, err = svc.warrantSvc.List(
ctx,
warrant.FilterParams{
ObjectType: []string{objectType},
ObjectId: []string{objectId},
Relation: []string{relation},
ObjectType: objectType,
ObjectId: objectId,
Relation: relation,
},
listParams,
)
Expand All @@ -138,12 +138,12 @@ func (svc CheckService) getMatchingSubjects(ctx context.Context, checkPipeline *
}

matchingSpecs := make([]warrant.WarrantSpec, 0)
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
matchingSpecs = append(matchingSpecs, warrant)
for _, w := range warrantSpecs {
if w.Policy == "" {
matchingSpecs = append(matchingSpecs, w)
} else {
if policyMatched := evalWarrantPolicy(warrant, checkCtx); policyMatched {
matchingSpecs = append(matchingSpecs, warrant)
if policyMatched := evalWarrantPolicy(w, checkCtx); policyMatched {
matchingSpecs = append(matchingSpecs, w)
}
}
}
Expand All @@ -167,14 +167,14 @@ func (svc CheckService) getMatchingSubjectsBySubjectType(ctx context.Context, ch
}

listParams := service.DefaultListParams(warrant.WarrantListParamParser{})
listParams.Limit = MaxWarrants
listParams.WithLimit(MaxWarrants)
warrantSpecs, _, _, err = svc.warrantSvc.List(
ctx,
warrant.FilterParams{
ObjectType: []string{objectType},
ObjectId: []string{objectId},
Relation: []string{relation},
SubjectType: []string{subjectType},
ObjectType: objectType,
ObjectId: objectId,
Relation: relation,
SubjectType: subjectType,
},
listParams,
)
Expand All @@ -183,12 +183,12 @@ func (svc CheckService) getMatchingSubjectsBySubjectType(ctx context.Context, ch
}

matchingSpecs := make([]warrant.WarrantSpec, 0)
for _, warrant := range warrantSpecs {
if warrant.Policy == "" {
matchingSpecs = append(matchingSpecs, warrant)
for _, w := range warrantSpecs {
if w.Policy == "" {
matchingSpecs = append(matchingSpecs, w)
} else {
if policyMatched := evalWarrantPolicy(warrant, checkCtx); policyMatched {
matchingSpecs = append(matchingSpecs, warrant)
if policyMatched := evalWarrantPolicy(w, checkCtx); policyMatched {
matchingSpecs = append(matchingSpecs, w)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions pkg/authz/query/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ func queryV1(svc QueryService, w http.ResponseWriter, r *http.Request) error {
return service.NewInvalidParameterError("lastId", "invalid lastId")
}

listParams.NextCursor = lastIdCursor
listParams.WithNextCursor(lastIdCursor)
} else if r.URL.Query().Has("afterId") {
afterIdCursor, err := service.NewCursorFromBase64String(r.URL.Query().Get("afterId"), QueryListParamParser{}, listParams.SortBy)
if err != nil {
return service.NewInvalidParameterError("afterId", "invalid afterId")
}

listParams.NextCursor = afterIdCursor
listParams.WithNextCursor(afterIdCursor)
}

results, _, nextCursor, err := svc.Query(r.Context(), query, listParams)
Expand Down
79 changes: 41 additions & 38 deletions pkg/authz/query/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,95 +84,98 @@ func (parser parser) Parse(query string) (*ast, error) {
return ast, nil
}

func NewQueryFromString(queryString string) (*Query, error) {
func NewQueryFromString(queryString string) (Query, error) {
var query Query

queryParser, err := newParser()
if err != nil {
return nil, errors.Wrap(err, "error creating query from string")
return Query{}, errors.Wrap(err, "error creating query from string")
}

ast, err := queryParser.Parse(queryString)
if err != nil {
return nil, service.NewInvalidParameterError("q", err.Error())
return Query{}, service.NewInvalidParameterError("q", err.Error())
}

if ast.SelectClause == nil {
return nil, service.NewInvalidParameterError("q", "must contain a 'select' clause")
return Query{}, service.NewInvalidParameterError("q", "must contain a 'select' clause")
}

if ast.SelectClause.ObjectTypesOrRelations == nil && ast.SelectClause.SubjectTypes == nil {
return nil, service.NewInvalidParameterError("q", "incomplete 'select' clause")
return Query{}, service.NewInvalidParameterError("q", "incomplete 'select' clause")
}

if ast.ForClause != nil && ast.WhereClause != nil {
return nil, service.NewInvalidParameterError("q", "cannot contain both a 'for' clause and a 'where' clause")
return Query{}, service.NewInvalidParameterError("q", "cannot contain both a 'for' clause and a 'where' clause")
}

query.Expand = !ast.SelectClause.Explicit

if ast.SelectClause.SubjectTypes != nil { // Querying for subjects
if len(ast.SelectClause.SubjectTypes) == 0 {
return nil, service.NewInvalidParameterError("q", "must contain one or more types of subjects to select")
return Query{}, service.NewInvalidParameterError("q", "must contain one or more types of subjects to select")
}

if ast.SelectClause.ObjectTypesOrRelations == nil || len(ast.SelectClause.ObjectTypesOrRelations) == 0 {
return nil, service.NewInvalidParameterError("q", "must select one or more relations for subjects to match on the object")
return Query{}, service.NewInvalidParameterError("q", "must select one or more relations for subjects to match on the object")
}

if ast.WhereClause != nil {
return nil, service.NewInvalidParameterError("q", "cannot contain a 'where' clause when selecting subjects")
return Query{}, service.NewInvalidParameterError("q", "cannot contain a 'where' clause when selecting subjects")
}

query.SelectSubjects = &SelectSubjects{
Relations: ast.SelectClause.ObjectTypesOrRelations,
SubjectTypes: ast.SelectClause.SubjectTypes,
}

if ast.ForClause != nil {
objectType, objectId, colonFound := strings.Cut(ast.ForClause.Object, ":")
if !colonFound {
return nil, service.NewInvalidParameterError("q", "'for' clause contains invalid object")
}

query.SelectSubjects.ForObject = &Resource{
Type: objectType,
Id: objectId,
}
if ast.ForClause == nil {
return Query{}, service.NewInvalidParameterError("q", "must contain a 'for' clause")
}

objectType, objectId, colonFound := strings.Cut(ast.ForClause.Object, ":")
if !colonFound {
return Query{}, service.NewInvalidParameterError("q", "'for' clause contains invalid object")
}

query.SelectSubjects.ForObject = &Resource{
Type: objectType,
Id: objectId,
}
} else { // Querying for objects
if ast.SelectClause.ObjectTypesOrRelations == nil || len(ast.SelectClause.ObjectTypesOrRelations) == 0 {
return nil, service.NewInvalidParameterError("q", "must contain one or more types of objects to select")
return Query{}, service.NewInvalidParameterError("q", "must contain one or more types of objects to select")
}

if ast.ForClause != nil {
return nil, service.NewInvalidParameterError("q", "cannot contain a 'for' clause when selecting objects")
return Query{}, service.NewInvalidParameterError("q", "cannot contain a 'for' clause when selecting objects")
}

query.SelectObjects = &SelectObjects{
ObjectTypes: ast.SelectClause.ObjectTypesOrRelations,
Relations: []string{warrant.Wildcard},
}

if ast.WhereClause != nil {
if ast.WhereClause.Relations == nil || len(ast.WhereClause.Relations) == 0 {
return nil, service.NewInvalidParameterError("q", "must contain one or more relations the subject must have on matching objects")
}

subjectType, subjectId, colonFound := strings.Cut(ast.WhereClause.Subject, ":")
if !colonFound {
return nil, service.NewInvalidParameterError("q", "'where' clause contains invalid subject")
}

query.SelectObjects.Relations = ast.WhereClause.Relations
query.SelectObjects.WhereSubject = &Resource{
Type: subjectType,
Id: subjectId,
}
if ast.WhereClause == nil {
return Query{}, service.NewInvalidParameterError("q", "must contain a 'where' clause")
}

if ast.WhereClause.Relations == nil || len(ast.WhereClause.Relations) == 0 {
return Query{}, service.NewInvalidParameterError("q", "must contain one or more relations the subject must have on matching objects")
}

subjectType, subjectId, colonFound := strings.Cut(ast.WhereClause.Subject, ":")
if !colonFound {
return Query{}, service.NewInvalidParameterError("q", "'where' clause contains invalid subject")
}

query.SelectObjects.Relations = ast.WhereClause.Relations
query.SelectObjects.WhereSubject = &Resource{
Type: subjectType,
Id: subjectId,
}
}

query.rawString = queryString

return &query, nil
return query, nil
}
22 changes: 13 additions & 9 deletions pkg/authz/query/resultset.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,24 @@ type ResultSet struct {
}

func (rs *ResultSet) List() *ResultSetNode {
if rs == nil {
return nil
}

return rs.head
}

func (rs *ResultSet) Add(objectType string, objectId string, warrant warrant.WarrantSpec, isImplicit bool) {
if _, exists := rs.m[key(objectType, objectId)]; !exists {
// Add warrant to list
newNode := &ResultSetNode{
ObjectType: objectType,
ObjectId: objectId,
Warrant: warrant,
IsImplicit: isImplicit,
next: nil,
}
newNode := &ResultSetNode{
ObjectType: objectType,
ObjectId: objectId,
Warrant: warrant,
IsImplicit: isImplicit,
next: nil,
}

if existingRes, exists := rs.m[key(objectType, objectId)]; !exists || (existingRes.IsImplicit && !isImplicit) {
// Add warrant to list
if rs.head == nil {
rs.head = newNode
}
Expand Down
Loading

0 comments on commit 6030794

Please sign in to comment.