Skip to content

Commit

Permalink
Merge pull request #1524 from canack/general-fixes
Browse files Browse the repository at this point in the history
refactor: improve efficiency, Bug fixes, Code readability
  • Loading branch information
tolgaOzen authored Aug 27, 2024
2 parents 7c496dd + fe68edb commit ff123a5
Show file tree
Hide file tree
Showing 14 changed files with 69 additions and 79 deletions.
10 changes: 5 additions & 5 deletions internal/authn/oidc/authn.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (oidc *Authn) Authenticate(requestContext context.Context) error {

// Retrieve the key ID from the JWT header and find the corresponding key in the JWKS.
if keyID, ok := token.Header["kid"].(string); ok {
return oidc.getKeyWithRetry(keyID, requestContext)
return oidc.getKeyWithRetry(requestContext, keyID)
}
slog.Error("jwt does not contain a key ID")
// If the JWT does not contain a key ID, return an error.
Expand Down Expand Up @@ -188,7 +188,7 @@ func (oidc *Authn) Authenticate(requestContext context.Context) error {
}

// getKeyWithRetry attempts to retrieve the key for the given keyID with retries using a custom backoff strategy.
func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface{}, error) {
func (oidc *Authn) getKeyWithRetry(ctx context.Context, keyID string) (interface{}, error) {
var rawKey interface{}
var err error

Expand All @@ -207,7 +207,7 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
oidc.mu.Unlock()

// Try to fetch the keyID once
rawKey, err = oidc.fetchKey(keyID, ctx)
rawKey, err = oidc.fetchKey(ctx, keyID)
if err == nil {
oidc.mu.Lock()
if _, ok := oidc.globalRetryKeyIds[keyID]; ok {
Expand All @@ -232,7 +232,7 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
// Retry mechanism
retries := 0
for retries <= oidc.backoffMaxRetries {
rawKey, err = oidc.fetchKey(keyID, ctx)
rawKey, err = oidc.fetchKey(ctx, keyID)
if err == nil {
if retries != 0 {
oidc.mu.Lock()
Expand Down Expand Up @@ -298,7 +298,7 @@ func (oidc *Authn) getKeyWithRetry(keyID string, ctx context.Context) (interface
}

// fetchKey attempts to fetch the JWKS and retrieve the key for the given keyID.
func (oidc *Authn) fetchKey(keyID string, ctx context.Context) (interface{}, error) {
func (oidc *Authn) fetchKey(ctx context.Context, keyID string) (interface{}, error) {
// Log the attempt to find the key.
slog.DebugContext(ctx, "attempting to find key in JWKS", "kid", keyID)

Expand Down
38 changes: 19 additions & 19 deletions internal/engines/bulk.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func (bc *BulkChecker) CollectAndSortRequests() {
}
}

// Signal to stop collecting requests and close the channel
// StopCollectingRequests Signal to stop collecting requests and close the channel
func (bc *BulkChecker) StopCollectingRequests() {
bc.mu.Lock()
defer bc.mu.Unlock()
Expand All @@ -145,24 +145,24 @@ func (bc *BulkChecker) sortRequests() {
}

// ExecuteRequests begins processing permission check requests from the sorted list.
func (c *BulkChecker) ExecuteRequests(size uint32) error {
func (bc *BulkChecker) ExecuteRequests(size uint32) error {
// Stop collecting new requests and close the RequestChan to ensure no more requests are added
c.StopCollectingRequests()
bc.StopCollectingRequests()

// Wait for request collection to complete before proceeding
c.wg.Wait()
bc.wg.Wait()

// Track the number of successful permission checks
successCount := int64(0)
// Semaphore to control the maximum number of concurrent permission checks
sem := semaphore.NewWeighted(int64(c.concurrencyLimit))
sem := semaphore.NewWeighted(int64(bc.concurrencyLimit))
var mu sync.Mutex

// Lock the mutex to prevent race conditions while sorting and copying the list of requests
c.mu.Lock()
c.sortRequests() // Sort requests based on id
listCopy := append([]BulkCheckerRequest{}, c.list...) // Create a copy of the list to avoid modifying the original during processing
c.mu.Unlock() // Unlock the mutex after sorting and copying
bc.mu.Lock()
bc.sortRequests() // Sort requests based on id
listCopy := append([]BulkCheckerRequest{}, bc.list...) // Create a copy of the list to avoid modifying the original during processing
bc.mu.Unlock() // Unlock the mutex after sorting and copying

// Pre-allocate a slice to store the results of the permission checks
results := make([]base.CheckResult, len(listCopy))
Expand All @@ -180,17 +180,17 @@ func (c *BulkChecker) ExecuteRequests(size uint32) error {
req := currentRequest

// Use errgroup to manage the goroutines, which allows for error handling and synchronization
c.g.Go(func() error {
bc.g.Go(func() error {
// Acquire a slot in the semaphore to control concurrency
if err := sem.Acquire(c.ctx, 1); err != nil {
if err := sem.Acquire(bc.ctx, 1); err != nil {
return err // Return an error if semaphore acquisition fails
}
defer sem.Release(1) // Ensure the semaphore slot is released after processing

var result base.CheckResult
if req.Result == base.CheckResult_CHECK_RESULT_UNSPECIFIED {
// Perform the permission check if the result is not already specified
cr, err := c.checker.Check(c.ctx, req.Request)
cr, err := bc.checker.Check(bc.ctx, req.Request)
if err != nil {
return err // Return an error if the check fails
}
Expand All @@ -212,17 +212,17 @@ func (c *BulkChecker) ExecuteRequests(size uint32) error {
ct := ""
if processedIndex+1 < len(listCopy) {
// If there is a next item, create a continuous token with the next ID
if c.typ == BULK_ENTITY {
if bc.typ == BULK_ENTITY {
ct = utils.NewContinuousToken(listCopy[processedIndex+1].Request.GetEntity().GetId()).Encode().String()
} else if c.typ == BULK_SUBJECT {
} else if bc.typ == BULK_SUBJECT {
ct = utils.NewContinuousToken(listCopy[processedIndex+1].Request.GetSubject().GetId()).Encode().String()
}
}
// Depending on the type of check (entity or subject), call the appropriate callback
if c.typ == BULK_ENTITY {
c.callback(listCopy[processedIndex].Request.GetEntity().GetId(), ct)
} else if c.typ == BULK_SUBJECT {
c.callback(listCopy[processedIndex].Request.GetSubject().GetId(), ct)
if bc.typ == BULK_ENTITY {
bc.callback(listCopy[processedIndex].Request.GetEntity().GetId(), ct)
} else if bc.typ == BULK_SUBJECT {
bc.callback(listCopy[processedIndex].Request.GetSubject().GetId(), ct)
}
}
}
Expand All @@ -235,7 +235,7 @@ func (c *BulkChecker) ExecuteRequests(size uint32) error {
}

// Wait for all goroutines to complete and check for any errors
if err := c.g.Wait(); err != nil {
if err := bc.g.Wait(); err != nil {
return err // Return the error if any goroutine returned an error
}

Expand Down
1 change: 0 additions & 1 deletion internal/engines/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,6 @@ func GenerateKey(key *base.PermissionCheckRequest, isRelational bool) string {
if entityRelationString != "" {
parts = append(parts, fmt.Sprintf("%s@%s", entityRelationString, subjectString))
}

} else {
parts = append(parts, attribute.EntityAndCallOrAttributeToString(
key.GetEntity(),
Expand Down
6 changes: 3 additions & 3 deletions internal/invoke/invoke.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ func (invoker *DirectInvoker) Check(ctx context.Context, request *base.Permissio
},
}, err
}
duration := time.Now().Sub(start)
duration := time.Since(start)
invoker.checkDurationHistogram.Record(ctx, duration.Microseconds())

// Increase the check count in the response metadata.
Expand Down Expand Up @@ -276,7 +276,7 @@ func (invoker *DirectInvoker) LookupEntity(ctx context.Context, request *base.Pe

resp, err := invoker.lo.LookupEntity(ctx, request)

duration := time.Now().Sub(start)
duration := time.Since(start)
invoker.lookupEntityDurationHistogram.Record(ctx, duration.Microseconds())

// Increase the lookup entity count in the metrics.
Expand Down Expand Up @@ -323,7 +323,7 @@ func (invoker *DirectInvoker) LookupEntityStream(ctx context.Context, request *b

resp := invoker.lo.LookupEntityStream(ctx, request, server)

duration := time.Now().Sub(start)
duration := time.Since(start)
invoker.lookupEntityDurationHistogram.Record(ctx, duration.Microseconds())

// Increase the lookup entity count in the metrics.
Expand Down
7 changes: 3 additions & 4 deletions internal/servers/dataServer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ func (r *DataServer) ReadRelationships(ctx context.Context, request *v1.Relation
return nil, status.Error(GetStatus(err), err.Error())
}

duration := time.Now().Sub(start)
duration := time.Since(start)
r.readRelationshipsHistogram.Record(ctx, duration.Microseconds())

return &v1.RelationshipReadResponse{
Expand Down Expand Up @@ -140,7 +140,7 @@ func (r *DataServer) ReadAttributes(ctx context.Context, request *v1.AttributeRe
return nil, status.Error(GetStatus(err), err.Error())
}

duration := time.Now().Sub(start)
duration := time.Since(start)
r.readAttributesHistogram.Record(ctx, duration.Microseconds())

return &v1.AttributeReadResponse{
Expand Down Expand Up @@ -176,7 +176,6 @@ func (r *DataServer) Write(ctx context.Context, request *v1.DataWriteRequest) (*
relationshipsMap := map[string]struct{}{}

for _, tup := range request.GetTuples() {

key := tuple.ToString(tup)

if _, ok := relationshipsMap[key]; ok {
Expand Down Expand Up @@ -241,7 +240,7 @@ func (r *DataServer) Write(ctx context.Context, request *v1.DataWriteRequest) (*
return nil, status.Error(GetStatus(err), err.Error())
}

duration := time.Now().Sub(start)
duration := time.Since(start)
r.writeDataHistogram.Record(ctx, duration.Microseconds())

return &v1.DataWriteResponse{
Expand Down
37 changes: 14 additions & 23 deletions pkg/attribute/attribute.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,44 +56,40 @@ func Attribute(attribute string) (*base.Attribute, error) {
case "boolean":
boolVal, err := strconv.ParseBool(v[1])
if err != nil {
return nil, fmt.Errorf("failed to parse boolean: %v", err)
return nil, fmt.Errorf("failed to parse boolean: %w", err)
}
wrapped = &base.BooleanValue{Data: boolVal}
case "boolean[]":
var ba []bool
val := strings.Split(v[1], ",")
for _, value := range val {
var ba = make([]bool, len(val))
for i, value := range val {
boolVal, err := strconv.ParseBool(value)
if err != nil {
return nil, fmt.Errorf("failed to parse boolean: %v", err)
return nil, fmt.Errorf("failed to parse boolean: %w", err)
}
ba = append(ba, boolVal)
ba[i] = boolVal
}
wrapped = &base.BooleanArrayValue{Data: ba}
case "string":
wrapped = &base.StringValue{Data: v[1]}
case "string[]":
var sa []string
val := strings.Split(v[1], ",")
for _, value := range val {
sa = append(sa, value)
}
var sa = strings.Split(v[1], ",")
wrapped = &base.StringArrayValue{Data: sa}
case "double":
doubleVal, err := strconv.ParseFloat(v[1], 64)
if err != nil {
return nil, fmt.Errorf("failed to parse float: %v", err)
return nil, fmt.Errorf("failed to parse float: %w", err)
}
wrapped = &base.DoubleValue{Data: doubleVal}
case "double[]":
var da []float64
val := strings.Split(v[1], ",")
for _, value := range val {
var da = make([]float64, len(val))
for i, value := range val {
doubleVal, err := strconv.ParseFloat(value, 64)
if err != nil {
return nil, fmt.Errorf("failed to parse float: %v", err)
}
da = append(da, doubleVal)
da[i] = doubleVal
}
wrapped = &base.DoubleArrayValue{Data: da}
case "integer":
Expand All @@ -103,15 +99,14 @@ func Attribute(attribute string) (*base.Attribute, error) {
}
wrapped = &base.IntegerValue{Data: int32(intVal)}
case "integer[]":

var ia []int32
val := strings.Split(v[1], ",")
for _, value := range val {
var ia = make([]int32, len(val))
for i, value := range val {
intVal, err := strconv.ParseInt(value, 10, 32)
if err != nil {
return nil, fmt.Errorf("failed to parse integer: %v", err)
}
ia = append(ia, int32(intVal))
ia[i] = int32(intVal)
}
wrapped = &base.IntegerArrayValue{Data: ia}
default:
Expand Down Expand Up @@ -243,11 +238,7 @@ func AnyToString(any *anypb.Any) string {
if err := any.UnmarshalTo(stringVal); err != nil {
return "undefined"
}
var strs []string
for _, v := range stringVal.GetData() {
strs = append(strs, v)
}
str = strings.Join(strs, ",")
str = strings.Join(stringVal.GetData(), ",")
case "type.googleapis.com/base.v1.DoubleValue":
doubleVal := &base.DoubleValue{}
if err := any.UnmarshalTo(doubleVal); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions pkg/balancer/errors.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
package balancer

import (
"fmt"
"errors"
)

var (
// ErrSubConnMissing indicates that a SubConn (sub-connection) was expected but not found.
ErrSubConnMissing = fmt.Errorf("sub-connection is missing or not found")
ErrSubConnMissing = errors.New("sub-connection is missing or not found")
// ErrSubConnResetFailure indicates an error occurred while trying to reset the SubConn.
ErrSubConnResetFailure = fmt.Errorf("failed to reset the sub-connection")
ErrSubConnResetFailure = errors.New("failed to reset the sub-connection")
)
9 changes: 5 additions & 4 deletions pkg/balancer/hashring.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package balancer

import (
"errors"
"fmt"
"log"
"log/slog"
Expand Down Expand Up @@ -86,7 +87,7 @@ func (b *consistentHashBalancer) UpdateClientConnState(s balancer.ClientConnStat

// If there are no addresses from the resolver, log an error.
if len(s.ResolverState.Addresses) == 0 {
b.ResolverError(fmt.Errorf("produced zero addresses"))
b.ResolverError(errors.New("produced zero addresses"))
return balancer.ErrBadResolverState
}

Expand Down Expand Up @@ -206,14 +207,14 @@ func (b *consistentHashBalancer) mergeErrors() error {

// If only one of the errors is nil, return the other error.
if b.lastConnectionError == nil {
return fmt.Errorf("last resolver error: %v", b.lastResolverError)
return fmt.Errorf("last resolver error: %w", b.lastResolverError)
}
if b.lastResolverError == nil {
return fmt.Errorf("last connection error: %v", b.lastConnectionError)
return fmt.Errorf("last connection error: %w", b.lastConnectionError)
}

// If both errors are present, concatenate them.
return fmt.Errorf("last connection error: %v; last resolver error: %v", b.lastConnectionError, b.lastResolverError)
return errors.Join(b.lastConnectionError, b.lastResolverError)
}

// UpdateSubConnState -
Expand Down
2 changes: 1 addition & 1 deletion pkg/cmd/migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func migrateDown() func(cmd *cobra.Command, args []string) error {

p, err := strconv.ParseInt(flags[target], 10, 64)
if err != nil {
return nil
return err
}

if p == 0 {
Expand Down
12 changes: 6 additions & 6 deletions pkg/development/coverage/coverage.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ func Run(shape file.Shape) SchemaCoverageInfo {

schemaCoverageInfo := SchemaCoverageInfo{}

var refs []SchemaCoverage
for _, en := range definitions {
refs = append(refs, references(en))
var refs = make([]SchemaCoverage, len(definitions))
for i, en := range definitions {
refs[i] = references(en)
}

// Iterate through the schema coverage references
Expand Down Expand Up @@ -269,16 +269,16 @@ func relationships(en string, relationships []string) []string {

// attributes - Get attributes for a given entity
func attributes(en string, attributes []string) []string {
var attrs []string
for _, attr := range attributes {
var attrs = make([]string, len(attributes))
for i, attr := range attributes {
a, err := attribute.Attribute(attr)
if err != nil {
return []string{}
}
if a.GetEntity().GetType() != en {
continue
}
attrs = append(attrs, fmt.Sprintf("%s#%s", a.GetEntity().GetType(), a.GetAttribute()))
attrs[i] = fmt.Sprintf("%s#%s", a.GetEntity().GetType(), a.GetAttribute())
}
return attrs
}
Expand Down
1 change: 0 additions & 1 deletion pkg/development/development.go
Original file line number Diff line number Diff line change
Expand Up @@ -452,7 +452,6 @@ func (c *Development) RunWithShape(ctx context.Context, shape *file.Shape) (erro

// Each SubjectFilter in the current scenario is processed
for _, filter := range scenario.SubjectFilters {

subjectReference := tuple.RelationReference(filter.SubjectReference)
if err != nil {
errors = append(errors, Error{
Expand Down
Loading

0 comments on commit ff123a5

Please sign in to comment.