Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to authcontrol #53

Merged
merged 8 commits into from
Oct 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
.vscode
coverage.out
88 changes: 88 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
linters:
# Disable all linters.
disable-all: true
# Enable specific linter
enable:
- errcheck
- gci
- wrapcheck
- ineffassign
- unused

run:
# Number of operating system threads (`GOMAXPROCS`) that can execute golangci-lint simultaneously.
# If it is explicitly set to 0 (i.e. not the default) then golangci-lint will automatically set the value to match Linux container CPU quota.
# Default: the number of logical CPUs in the machine
concurrency: 8
# Timeout for analysis, e.g. 30s, 5m.
# Default: 1m
timeout: 5m
go: '1.22.0'

output:
# Show statistics per linter.
show-stats: true
# Sort results by the order defined in `sort-order`.
sort-results: true
# Order to use when sorting results.
# Require `sort-results` to `true`.
# Possible values: `file`, `linter`, and `severity`.
#
# If the severity values are inside the following list, they are ordered in this order:
# 1. error
# 2. warning
# 3. high
# 4. medium
# 5. low
# Either they are sorted alphabetically.
sort-order:
- linter
- file
- severity

issues:
# Maximum issues count per one linter.
# Set to 0 to disable.
# Default: 50
max-issues-per-linter: 0
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 0
exclude-rules:
- linters:
- lll
source: "^//go:generate "
exclude-dirs:
- "tools"
exclude-files:
- ".*\\.gen\\.go$"
- ".*\\.ridl$"

linters-settings:
errcheck:
# List of functions to exclude from checking, where each entry is a single function to exclude.
# See https://github.com/kisielk/errcheck#excluding-functions for details.
exclude-functions:
- (net/http.ResponseWriter).Write

gci:
# Section configuration to compare against.
# Section names are case-insensitive and may contain parameters in ().
# The default order of sections is `standard > default > custom > blank > dot > alias > localmodule`,
# If `custom-order` is `true`, it follows the order of `sections` option.
# Default: ["standard", "default"]
sections:
- standard # Standard section: captures all standard packages.
- default # Default section: contains all imports that could not be matched to another section type.
- prefix(github.com/0xsequence/authcontrol) # Custom section: groups all imports with the specified Prefix.
# Skip generated files.
# Default: true
skip-generated: true
# Enable custom order of sections.
# If `true`, make the section order the same as the order of `sections`.
# Default: false
custom-order: true
# Drops lexical ordering for custom sections.
# Default: false
no-lex-order: false
36 changes: 26 additions & 10 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,13 +1,29 @@
GO_TEST = $(shell if ! command -v gotest &> /dev/null; then echo "go test"; else echo "gotest"; fi)
TEST_FLAGS ?= -p 8 -failfast -race -shuffle on

.PHONY: build
build:
go build ./...
all:
@echo "make <cmd>:"
@echo ""
@echo "commands:"
@awk -F'[ :]' '/^#+/ {comment=$$0; gsub(/^#+[ ]*/, "", comment)} !/^(_|all:)/ && /^([A-Za-z_-]+):/ && !seen[$$1]++ {printf " %-24s %s\n", $$1, (comment ? "- " comment : ""); comment=""} !/^#+/ {comment=""}' Makefile

.PHONY: proto
proto:
go generate ./proto
test-clean:
go clean -testcache

test: test-clean
go test -run=$(TEST) $(TEST_FLAGS) -json ./... | tparse --all --follow

test-rerun: test-clean
go run github.com/goware/rerun/cmd/rerun -watch ./ -run 'make test'

test-coverage:
go test -run=$(TEST) $(TEST_FLAGS) -cover -coverprofile=coverage.out -json ./... | tparse --all --follow

test-coverage-inspect: test-coverage
go tool cover -html=coverage.out

generate:
go generate -x ./...

lint:
golangci-lint run ./... --fix -c .golangci.yml

.PHONY: test
test:
go clean -testcache && $(GO_TEST) -v -p=1 ./...
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,11 @@ The package offers a Redis implementation for the cache, and a Memory version of

The methods that are used to save/load in a permanent storage the 3 entities are not implemented.
The requests are measure in compute units, if a compute unit is not specified it is assumed that the value it's 1.
A client can specify the amount of compute units by manipulating the request context using the `WithComputeUnits` function.
A client can specify the amount of compute units by manipulating the request context using the `WithCost` function.

# Increment operation

The client method `SpendComputeUnits` takes care of doing an increment operation in the cache. And works as follows:
The client method `SpendCost` takes care of doing an increment operation in the cache. And works as follows:
- It tries to fetch the usage record from the cache
- On a hit it executes the INCR operation.
- On a miss it sets it to `-1` and ask the server to populate it.
Expand Down
20 changes: 10 additions & 10 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ type QuotaCache interface {
}

type UsageCache interface {
SetComputeUnits(ctx context.Context, redisKey string, amount int64) error
ClearComputeUnits(ctx context.Context, redisKey string) (bool, error)
PeekComputeUnits(ctx context.Context, redisKey string) (int64, error)
SpendComputeUnits(ctx context.Context, redisKey string, amount, limit int64) (int64, error)
SetUsage(ctx context.Context, redisKey string, amount int64) error
ClearUsage(ctx context.Context, redisKey string) (bool, error)
PeekUsage(ctx context.Context, redisKey string) (int64, error)
SpendUsage(ctx context.Context, redisKey string, amount, limit int64) (int64, error)
}

type PermissionCache interface {
Expand Down Expand Up @@ -116,12 +116,12 @@ func (s *RedisCache) setQuota(ctx context.Context, key string, quota *proto.Acce
return nil
}

func (s *RedisCache) SetComputeUnits(ctx context.Context, redisKey string, amount int64) error {
func (s *RedisCache) SetUsage(ctx context.Context, redisKey string, amount int64) error {
cacheKey := fmt.Sprintf("%s%s", redisKeyPrefix, redisKey)
return s.client.Set(ctx, cacheKey, amount, s.ttl).Err()
}

func (s *RedisCache) ClearComputeUnits(ctx context.Context, redisKey string) (bool, error) {
func (s *RedisCache) ClearUsage(ctx context.Context, redisKey string) (bool, error) {
cacheKey := fmt.Sprintf("%s%s", redisKeyPrefix, redisKey)
count, err := s.client.Del(ctx, cacheKey).Result()
if err != nil {
Expand All @@ -130,7 +130,7 @@ func (s *RedisCache) ClearComputeUnits(ctx context.Context, redisKey string) (bo
return count != 0, nil
}

func (s *RedisCache) PeekComputeUnits(ctx context.Context, redisKey string) (int64, error) {
func (s *RedisCache) PeekUsage(ctx context.Context, redisKey string) (int64, error) {
const SpecialValue = -1
cacheKey := fmt.Sprintf("%s%s", redisKeyPrefix, redisKey)
v, err := s.client.Get(ctx, cacheKey).Int64()
Expand All @@ -153,9 +153,9 @@ func (s *RedisCache) PeekComputeUnits(ctx context.Context, redisKey string) (int
return 0, ErrCachePing
}

func (s *RedisCache) SpendComputeUnits(ctx context.Context, redisKey string, amount, limit int64) (int64, error) {
// NOTE: skip redisKeyPrefix as it's already in PeekComputeUnits
v, err := s.PeekComputeUnits(ctx, redisKey)
func (s *RedisCache) SpendUsage(ctx context.Context, redisKey string, amount, limit int64) (int64, error) {
// NOTE: skip redisKeyPrefix as it's already in PeekCost
v, err := s.PeekUsage(ctx, redisKey)
if err != nil {
return 0, err
}
Expand Down
21 changes: 11 additions & 10 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"sync/atomic"
"time"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/quotacontrol/middleware"
"github.com/0xsequence/quotacontrol/proto"
"github.com/goware/logger"
Expand Down Expand Up @@ -166,13 +167,13 @@ func (c *Client) FetchUsage(ctx context.Context, quota *proto.AccessQuota, now t
)

for i := range 3 {
usage, err := c.usageCache.PeekComputeUnits(ctx, key)
usage, err := c.usageCache.PeekUsage(ctx, key)
if err != nil {
// ping the server to prepare usage
if errors.Is(err, ErrCachePing) {
if _, err := c.quotaClient.PrepareUsage(ctx, quota.AccessKey.ProjectID, quota.Cycle, now); err != nil {
logger.Error("unexpected client error", slog.Any("error", err))
if _, err := c.usageCache.ClearComputeUnits(ctx, key); err != nil {
if _, err := c.usageCache.ClearUsage(ctx, key); err != nil {
logger.Error("unexpected cache error", slog.Any("error", err))
}
return 0, nil
Expand All @@ -197,7 +198,7 @@ func (c *Client) FetchUsage(ctx context.Context, quota *proto.AccessQuota, now t
}

func (c *Client) CheckPermission(ctx context.Context, projectID uint64, minPermission proto.UserPermission) (bool, error) {
if sessionType, _ := middleware.GetSessionType(ctx); sessionType >= proto.SessionType_Admin {
if sessionType, _ := authcontrol.GetSessionType(ctx); sessionType >= proto.SessionType_Admin {
return true, nil
}
perm, _, err := c.FetchPermission(ctx, projectID)
Expand All @@ -210,7 +211,7 @@ func (c *Client) CheckPermission(ctx context.Context, projectID uint64, minPermi
// FetchPermission fetches the user permission from cache or from the quota server.
// If an error occurs, it returns nil.
func (c *Client) FetchPermission(ctx context.Context, projectID uint64) (proto.UserPermission, *proto.ResourceAccess, error) {
userID, _ := middleware.GetAccount(ctx)
userID, _ := authcontrol.GetAccount(ctx)
logger := c.logger.With(
slog.String("op", "fetch_permission"),
slog.Uint64("project_id", projectID),
Expand All @@ -235,9 +236,9 @@ func (c *Client) FetchPermission(ctx context.Context, projectID uint64) (proto.U
return perm, access, nil
}

func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, computeUnits int64, now time.Time) (spent bool, total int64, err error) {
func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, cost int64, now time.Time) (spent bool, total int64, err error) {
// quota is nil only on unexpected errors from quota fetch
if quota == nil || computeUnits == 0 {
if quota == nil || cost == 0 {
return false, 0, nil
}

Expand All @@ -254,18 +255,18 @@ func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, compu
key := getQuotaKey(quota.AccessKey.ProjectID, quota.Cycle, now)

for i := range 3 {
total, err := c.usageCache.SpendComputeUnits(ctx, key, computeUnits, cfg.OverMax)
total, err := c.usageCache.SpendUsage(ctx, key, cost, cfg.OverMax)
if err != nil {
// limit exceeded
if errors.Is(err, proto.ErrLimitExceeded) {
c.usage.AddKeyUsage(accessKey, now, proto.AccessUsage{LimitedCompute: computeUnits})
c.usage.AddKeyUsage(accessKey, now, proto.AccessUsage{LimitedCompute: cost})
return false, total, proto.ErrLimitExceeded
}
// ping the server to prepare usage
if errors.Is(err, ErrCachePing) {
if _, err := c.quotaClient.PrepareUsage(ctx, quota.AccessKey.ProjectID, quota.Cycle, now); err != nil {
logger.Error("unexpected client error", slog.Any("error", err))
if _, err := c.usageCache.ClearComputeUnits(ctx, key); err != nil {
if _, err := c.usageCache.ClearUsage(ctx, key); err != nil {
logger.Error("unexpected cache error", slog.Any("error", err))
}
return false, 0, nil
Expand All @@ -284,7 +285,7 @@ func (c *Client) SpendQuota(ctx context.Context, quota *proto.AccessQuota, compu

}

usage, event := cfg.GetSpendResult(computeUnits, total)
usage, event := cfg.GetSpendResult(cost, total)
if quota.AccessKey.AccessKey == "" {
c.usage.AddProjectUsage(quota.AccessKey.ProjectID, now, usage)
} else {
Expand Down
19 changes: 3 additions & 16 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@ import (
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"

"github.com/0xsequence/quotacontrol"
"github.com/0xsequence/quotacontrol/middleware"
"github.com/0xsequence/quotacontrol/proto"
"github.com/go-chi/jwtauth/v5"
"github.com/goware/logger"

"github.com/goware/cachestore/redis"
"github.com/stretchr/testify/require"
)

func newConfig() quotacontrol.Config {
Expand Down Expand Up @@ -63,16 +60,6 @@ func (c *spendingCounter) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

func mustJWT(t *testing.T, auth *jwtauth.JWTAuth, claims map[string]interface{}) string {
t.Helper()
if claims == nil {
return ""
}
_, token, err := auth.Encode(claims)
require.NoError(t, err)
return token
}

func executeRequest(ctx context.Context, handler http.Handler, path, accessKey, jwt string) (bool, http.Header, error) {
req, err := http.NewRequest("POST", path, nil)
if err != nil {
Expand All @@ -98,10 +85,10 @@ func executeRequest(ctx context.Context, handler http.Handler, path, accessKey,
return true, rr.Header(), nil
}

type addCredits int64
type addCost int64

func (i addCredits) Middleware(h http.Handler) http.Handler {
func (i addCost) Middleware(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
h.ServeHTTP(w, r.WithContext(middleware.AddComputeUnits(r.Context(), int64(i))))
h.ServeHTTP(w, r.WithContext(middleware.AddCost(r.Context(), int64(i))))
})
}
16 changes: 16 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"time"

"github.com/0xsequence/quotacontrol/middleware"
"github.com/0xsequence/quotacontrol/proto"
"github.com/goware/cachestore/redis"
)

Expand All @@ -27,7 +28,22 @@ type Config struct {
DefaultUsage *int64 `toml:"default_usage"`
LRUSize int `toml:"lru_size"`
LRUExpiration Duration `toml:"lru_expiration"`
ErrorConfig ErrorConfig `toml:"error_config"`

// DangerMode is used for debugging
DangerMode bool `toml:"danger_mode"`
}

type ErrorConfig struct {
MessageQuota string `toml:"quota_message"`
MessageRate string `toml:"ratelimit_message"`
}

func (e ErrorConfig) Apply() {
if e.MessageQuota != "" {
proto.ErrLimitExceeded.Message = e.MessageQuota
}
if e.MessageRate != "" {
proto.ErrRateLimit.Message = e.MessageRate
}
}
Loading
Loading