Skip to content

Commit

Permalink
pass cycle to getLimit (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
klaidliadon authored Feb 28, 2024
1 parent a7712f2 commit 0ec705b
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func (m *MemoryStore) SetAccessLimit(ctx context.Context, projectID uint64, conf
return nil
}

func (m *MemoryStore) GetAccessLimit(ctx context.Context, projectID uint64) (*proto.Limit, error) {
func (m *MemoryStore) GetAccessLimit(ctx context.Context, projectID uint64, cycle *proto.Cycle) (*proto.Limit, error) {
m.Lock()
limit, ok := m.limits[projectID]
m.Unlock()
Expand Down
6 changes: 3 additions & 3 deletions middleware/access.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func VerifyAccessKey(client Client, eh ErrorHandler) func(next http.Handler) htt
return
}

quota, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), getTime(ctx))
quota, err := client.FetchKeyQuota(ctx, accessKey, r.Header.Get(HeaderOrigin), GetTime(ctx))
if err != nil {
eh(w, r, next, err)
return
Expand Down Expand Up @@ -97,7 +97,7 @@ func EnsureUsage(client Client, eh ErrorHandler) func(next http.Handler) http.Ha
return
}

usage, err := client.FetchUsage(ctx, quota, getTime(ctx))
usage, err := client.FetchUsage(ctx, quota, GetTime(ctx))
if err != nil {
eh(w, r, next, err)
return
Expand Down Expand Up @@ -163,7 +163,7 @@ func SpendUsage(client Client, eh ErrorHandler) func(next http.Handler) http.Han
return
}

ok, err := client.SpendQuota(ctx, quota, cu, getTime(ctx))
ok, err := client.SpendQuota(ctx, quota, cu, GetTime(ctx))
if err != nil {
eh(w, r, next, err)
return
Expand Down
4 changes: 2 additions & 2 deletions middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ func WithTime(ctx context.Context, now time.Time) context.Context {
return context.WithValue(ctx, ctxKeyTime, now)
}

// getTime returns the time from the context. If the time is not set, it returns the current time.
func getTime(ctx context.Context) time.Time {
// GetTime returns the time from the context. If the time is not set, it returns the current time.
func GetTime(ctx context.Context) time.Time {
v, ok := ctx.Value(ctxKeyTime).(time.Time)
if !ok {
return time.Now().Truncate(time.Hour * 24)
Expand Down
29 changes: 18 additions & 11 deletions quotacontrol.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
)

type LimitStore interface {
GetAccessLimit(ctx context.Context, projectID uint64) (*proto.Limit, error)
GetAccessLimit(ctx context.Context, projectID uint64, cycle *proto.Cycle) (*proto.Limit, error)
}

type AccessKeyStore interface {
Expand Down Expand Up @@ -78,7 +78,7 @@ func (q qcHandler) GetTimeRange(ctx context.Context, projectID uint64, from, to
if from != nil && to != nil {
return *from, *to, nil
}
now := time.Now()
now := middleware.GetTime(ctx)
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, now)
if err != nil {
return time.Time{}, time.Time{}, err
Expand Down Expand Up @@ -153,14 +153,16 @@ func (q qcHandler) ClearUsage(ctx context.Context, projectID uint64, now time.Ti
}

func (q qcHandler) GetProjectQuota(ctx context.Context, projectID uint64, now time.Time) (*proto.AccessQuota, error) {
limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID)
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, now)
if err != nil {
return nil, err
}
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, now)

limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID, cycle)
if err != nil {
return nil, err
}

record := proto.AccessQuota{
Limit: limit,
Cycle: cycle,
Expand All @@ -178,11 +180,11 @@ func (q qcHandler) GetAccessQuota(ctx context.Context, accessKey string, now tim
if err != nil {
return nil, err
}
limit, err := q.store.LimitStore.GetAccessLimit(ctx, access.ProjectID)
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, access.ProjectID, now)
if err != nil {
return nil, err
}
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, access.ProjectID, now)
limit, err := q.store.LimitStore.GetAccessLimit(ctx, access.ProjectID, cycle)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -279,7 +281,11 @@ func (q qcHandler) GetDefaultAccessKey(ctx context.Context, projectID uint64) (*
}

func (q qcHandler) CreateAccessKey(ctx context.Context, projectID uint64, displayName string, allowedOrigins []string, allowedServices []*proto.Service) (*proto.AccessKey, error) {
limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID)
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, middleware.GetTime(ctx))
if err != nil {
return nil, err
}
limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID, cycle)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -472,17 +478,18 @@ func (q qcHandler) GetProjectStatus(ctx context.Context, projectID uint64) (*pro
status := proto.ProjectStatus{
ProjectID: projectID,
}
limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID)

now := middleware.GetTime(ctx)
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, now)
if err != nil {
return nil, err
}
status.Limit = limit

now := time.Now()
cycle, err := q.store.CycleStore.GetAccessCycle(ctx, projectID, now)
limit, err := q.store.LimitStore.GetAccessLimit(ctx, projectID, cycle)
if err != nil {
return nil, err
}
status.Limit = limit

key := getQuotaKey(projectID, cycle, now)

Expand Down

0 comments on commit 0ec705b

Please sign in to comment.