Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass interface to Client + other improvements #44

Merged
merged 5 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ type Notifier interface {
Notify(access *proto.AccessKey) error
}

func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client {
func NewClient(logger logger.Logger, service proto.Service, cfg Config, qc proto.QuotaControl) *Client {
options := redis.Options{
Addr: fmt.Sprintf("%s:%d", cfg.Redis.Host, cfg.Redis.Port),
DB: cfg.Redis.DBIndex,
Expand All @@ -41,6 +41,12 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client
quotaCache = NewLRU(quotaCache, cfg.LRUSize, lruExpiration)
}

if qc == nil {
qc = proto.NewQuotaControlClient(cfg.URL, &http.Client{
Transport: bearerToken(cfg.AuthToken),
})
}

var ticker *time.Ticker
if cfg.UpdateFreq.Duration > 0 {
ticker = time.NewTicker(cfg.UpdateFreq.Duration)
Expand All @@ -52,14 +58,12 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client
usage: &usageTracker{
Usage: make(map[time.Time]usageRecord),
},
usageCache: cache,
quotaCache: quotaCache,
permCache: PermissionCache(cache),
quotaClient: proto.NewQuotaControlClient(cfg.URL, &http.Client{
Transport: bearerToken(cfg.AuthToken),
}),
ticker: ticker,
logger: logger.With(slog.String("service", "quotacontrol")),
usageCache: cache,
quotaCache: quotaCache,
permCache: PermissionCache(cache),
quotaClient: qc,
ticker: ticker,
logger: logger.With(slog.String("service", "quotacontrol")),
}
}

Expand Down
2 changes: 1 addition & 1 deletion common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func newConfig() quotacontrol.Config {

func newQuotaClient(cfg quotacontrol.Config, service proto.Service) *quotacontrol.Client {
logger := logger.NewLogger(logger.LogLevel_DEBUG).With(slog.String("client", "client"))
return quotacontrol.NewClient(logger, service, cfg)
return quotacontrol.NewClient(logger, service, cfg, nil)
}

type hitCounter int64
Expand Down
69 changes: 40 additions & 29 deletions handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) {

r := chi.NewRouter()
r.Use(
middleware.Session(client, auth, nil),
middleware.Session(client, auth, nil, nil),
addCredits(_credits*2).Middleware,
addCredits(_credits*-1).Middleware,
middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil),
Expand Down Expand Up @@ -360,7 +360,7 @@ func TestJWT(t *testing.T) {
r := chi.NewRouter()

r.Use(
middleware.Session(client, auth, nil),
middleware.Session(client, auth, nil, nil),
middleware.EnsureUsage(client, nil),
middleware.SpendUsage(client, nil),
)
Expand Down Expand Up @@ -434,7 +434,7 @@ func TestJWTAccess(t *testing.T) {

r := chi.NewRouter()
r.Use(
middleware.Session(client, auth, nil),
middleware.Session(client, auth, nil, nil),
middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil),
middleware.EnsurePermission(client, UserPermission_READ_WRITE, nil),
)
Expand Down Expand Up @@ -496,14 +496,16 @@ func TestJWTAccess(t *testing.T) {
}

func TestSession(t *testing.T) {
auth := jwtauth.New("HS256", []byte("secret"), nil)

project := uint64(7)
key := proto.GenerateAccessKey(project)
service := proto.Service_Indexer
address := "accountId"
serviceName := "serviceName"
const (
project = uint64(7)
accessKey = "AQAAAAAAAAAHkL0mNSrn6Sm3oHs0xfa_DnY"
service = proto.Service_Indexer
address = "walletAddress"
userAddress = "userAddress"
serviceName = "serviceName"
)

auth := jwtauth.New("HS256", []byte("secret"), nil)
counter := hitCounter(0)

cfg := newConfig()
Expand All @@ -516,26 +518,28 @@ func TestSession(t *testing.T) {
MethodAccount = "MethodAccount"
MethodAccessKey = "MethodAccessKey"
MethodProject = "MethodProject"
MethodUser = "MethodUser"
MethodAdmin = "MethodAdmin"
MethodService = "MethodService"
)

ACL := middleware.ACL{
ACL := middleware.ServiceConfig[middleware.ACL]{
"Service": {
MethodPublic: proto.SessionType_Public,
MethodAccount: proto.SessionType_Account,
MethodAccessKey: proto.SessionType_AccessKey,
MethodProject: proto.SessionType_Project,
MethodAdmin: proto.SessionType_Admin,
MethodService: proto.SessionType_Service,
MethodPublic: middleware.NewACL(proto.SessionType_Public.AndUp()...),
MethodAccount: middleware.NewACL(proto.SessionType_Wallet.AndUp()...),
MethodAccessKey: middleware.NewACL(proto.SessionType_AccessKey.AndUp()...),
MethodProject: middleware.NewACL(proto.SessionType_Project.AndUp()...),
MethodUser: middleware.NewACL(proto.SessionType_User.AndUp()...),
MethodAdmin: middleware.NewACL(proto.SessionType_Admin.AndUp()...),
MethodService: middleware.NewACL(proto.SessionType_Service.AndUp()...),
},
}

r := chi.NewRouter()
r.Use(
middleware.Session(client, auth, nil),
middleware.Session(client, auth, server.Store, nil),
middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil),
middleware.AccessControl(ACL, middleware.Cost{}, 1, nil),
middleware.AccessControl(ACL, nil, 1, nil),
)
r.Handle("/*", &counter)

Expand All @@ -547,9 +551,10 @@ func TestSession(t *testing.T) {
OverWarn: 7,
OverMax: 10,
}
server.Store.AddUser(ctx, userAddress)
server.Store.SetAccessLimit(ctx, project, &limit)
server.Store.SetUserPermission(ctx, project, address, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project})
server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project})
server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: accessKey, ProjectID: project})

type testCase struct {
AccessKey string
Expand All @@ -561,32 +566,35 @@ func TestSession(t *testing.T) {
Session: proto.SessionType_Public,
},
{
Session: proto.SessionType_Account,
Session: proto.SessionType_Wallet,
},
{
AccessKey: key,
Session: proto.SessionType_AccessKey,
AccessKey: accessKey,
},
{
Session: proto.SessionType_Project,
},
{
AccessKey: key,
Session: proto.SessionType_Project,
AccessKey: accessKey,
},
{
Session: proto.SessionType_User,
},
{
Session: proto.SessionType_Admin,
},
{
AccessKey: key,
Session: proto.SessionType_Admin,
AccessKey: accessKey,
},
{
Session: proto.SessionType_Service,
},
{
AccessKey: key,
Session: proto.SessionType_Service,
AccessKey: accessKey,
},
}

Expand All @@ -602,29 +610,32 @@ func TestSession(t *testing.T) {
MethodAccount,
MethodAccessKey,
MethodProject,
MethodUser,
MethodAdmin,
MethodService,
}

for service := range ACL {
for _, method := range methods {
minSession := ACL[service][method]
types := ACL[service][method]
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s/%s/%s", method, tc.Session, tc.AccessKey), func(t *testing.T) {
var claims middleware.Claims
switch tc.Session {
case proto.SessionType_Account:
case proto.SessionType_Wallet:
claims = middleware.Claims{"account": address}
case proto.SessionType_Project:
claims = middleware.Claims{"account": address, "project": project}
case proto.SessionType_User:
claims = middleware.Claims{"account": userAddress}
case proto.SessionType_Admin:
claims = middleware.Claims{"account": address, "admin": true}
case proto.SessionType_Service:
claims = middleware.Claims{"service": serviceName}
}

ok, h, err := executeRequest(ctx, r, "/rpc/"+service+"/"+method, tc.AccessKey, mustJWT(t, auth, claims))
if tc.Session < minSession {
if !types.Includes(tc.Session) {
assert.Error(t, err)
assert.False(t, ok)
return
Expand All @@ -639,7 +650,7 @@ func TestSession(t *testing.T) {
case proto.SessionType_AccessKey, proto.SessionType_Project:
assert.Equal(t, quotaRPM, rateLimit)
assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), h.Get(middleware.HeaderQuotaLimit))
case proto.SessionType_Account, proto.SessionType_Admin:
case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User:
assert.Equal(t, accountRPM, rateLimit)
case proto.SessionType_Service:
assert.Equal(t, serviceRPM, rateLimit)
Expand Down
19 changes: 19 additions & 0 deletions mem.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ func NewMemoryStore() *MemoryStore {
ByProjectID: map[uint64]*proto.AccessUsage{},
ByAccessKey: map[string]*proto.AccessUsage{},
},
users: map[string]struct{}{},
permissions: map[uint64]map[string]userPermission{},
}
}
Expand All @@ -33,6 +34,7 @@ type MemoryStore struct {
cycles map[uint64]proto.Cycle
accessKeys map[string]proto.AccessKey
usage usageRecord
users map[string]struct{}
permissions map[uint64]map[string]userPermission
}

Expand Down Expand Up @@ -167,6 +169,23 @@ func (m *MemoryStore) ResetUsage(ctx context.Context, accessKey string) error {
return nil
}

func (m *MemoryStore) AddUser(ctx context.Context, userID string) error {
m.Lock()
m.users[userID] = struct{}{}
m.Unlock()
return nil
}

func (m *MemoryStore) GetUser(ctx context.Context, userID string) (any, error) {
m.Lock()
v, ok := m.users[userID]
m.Unlock()
if !ok {
return nil, nil
}
return v, nil
}

func (m *MemoryStore) GetUserPermission(ctx context.Context, projectID uint64, userID string) (proto.UserPermission, *proto.ResourceAccess, error) {
m.Lock()
defer m.Unlock()
Expand Down
50 changes: 37 additions & 13 deletions middleware/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,28 +40,19 @@ func (c Claims) String() string {

type ServiceConfig[T any] map[string]map[string]T

type (
ACL = ServiceConfig[proto.SessionType]
Cost = ServiceConfig[int64]
)

func (s ServiceConfig[T]) GetConfig(r *rcpRequest) (v T, ok bool) {
if r.Package != "rpc" || s == nil {
if s == nil || r.Package != "rpc" {
return v, false
}

serviceACL, ok := s[r.Service]
serviceCfg, ok := s[r.Service]
if !ok {
return v, false
}

// get method's ACL
cfg, ok := serviceACL[r.Method]
methodCfg, ok := serviceCfg[r.Method]
if !ok {
return v, false
}

return cfg, true
return methodCfg, true
}

type rcpRequest struct {
Expand Down Expand Up @@ -96,3 +87,36 @@ func swapHeader(h http.Header, from, to string) {
h.Del(from)
}
}

// ACL is a list of session types, encoded as a bitfield.
// SessionType(n) is represented by n=-the bit.
type ACL uint64

func NewACL(t ...proto.SessionType) ACL {
var types ACL
for _, v := range t {
types = types.And(v)
}
return types
}

func (t ACL) And(types ...proto.SessionType) ACL {
for _, v := range types {
t |= 1 << v
}
return t
}

func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

func (t ACL) SessionTypes() []proto.SessionType {
var types []proto.SessionType
for i := proto.SessionType(0); i < proto.SessionType_Max; i++ {
if t.Includes(i) {
types = append(types, i)
}
}
return types
}
18 changes: 15 additions & 3 deletions middleware/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ func (k *contextKey) String() string {
var (
ctxKeySessionType = &contextKey{"SessionType"}
ctxKeyAccount = &contextKey{"Account"}
ctxKeyUser = &contextKey{"User"}
ctxKeyService = &contextKey{"Service"}
ctxKeyAccessKey = &contextKey{"AccessKey"}
ctxKeyAccessQuota = &contextKey{"AccessQuota"}
Expand All @@ -28,8 +29,8 @@ var (
ctxKeySpending = &contextKey{"Spending"}
)

// withSessionType adds the access key to the context.
func withSessionType(ctx context.Context, accessType proto.SessionType) context.Context {
// WithSessionType adds the access key to the context.
func WithSessionType(ctx context.Context, accessType proto.SessionType) context.Context {
return context.WithValue(ctx, ctxKeySessionType, accessType)
}

Expand All @@ -43,7 +44,7 @@ func GetSessionType(ctx context.Context) proto.SessionType {
}

// WithAccount adds the account to the context.
func withAccount(ctx context.Context, account string) context.Context {
func WithAccount(ctx context.Context, account string) context.Context {
return context.WithValue(ctx, ctxKeyAccount, account)
}

Expand All @@ -53,6 +54,17 @@ func GetAccount(ctx context.Context) (string, bool) {
return v, ok
}

// WithUser adds the user to the context.
func WithUser(ctx context.Context, user any) context.Context {
return context.WithValue(ctx, ctxKeyUser, user)
}

// GetUser returns the user from the context.
func GetUser[T any](ctx context.Context) (T, bool) {
v, ok := ctx.Value(ctxKeyUser).(T)
return v, ok
}

// withService adds the service to the context.
func withService(ctx context.Context, service string) context.Context {
return context.WithValue(ctx, ctxKeyService, service)
Expand Down
Loading
Loading