diff --git a/client.go b/client.go index b8d66af..9f00f9f 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,7 @@ type Notifier interface { Notify(access *proto.AccessKey) error } -func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client { +func NewClient(logger logger.Logger, service proto.Service, cfg Config, qc proto.QuotaControl) *Client { options := redis.Options{ Addr: fmt.Sprintf("%s:%d", cfg.Redis.Host, cfg.Redis.Port), DB: cfg.Redis.DBIndex, @@ -41,6 +41,12 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client quotaCache = NewLRU(quotaCache, cfg.LRUSize, lruExpiration) } + if qc == nil { + qc = proto.NewQuotaControlClient(cfg.URL, &http.Client{ + Transport: bearerToken(cfg.AuthToken), + }) + } + var ticker *time.Ticker if cfg.UpdateFreq.Duration > 0 { ticker = time.NewTicker(cfg.UpdateFreq.Duration) @@ -52,14 +58,12 @@ func NewClient(logger logger.Logger, service proto.Service, cfg Config) *Client usage: &usageTracker{ Usage: make(map[time.Time]usageRecord), }, - usageCache: cache, - quotaCache: quotaCache, - permCache: PermissionCache(cache), - quotaClient: proto.NewQuotaControlClient(cfg.URL, &http.Client{ - Transport: bearerToken(cfg.AuthToken), - }), - ticker: ticker, - logger: logger.With(slog.String("service", "quotacontrol")), + usageCache: cache, + quotaCache: quotaCache, + permCache: PermissionCache(cache), + quotaClient: qc, + ticker: ticker, + logger: logger.With(slog.String("service", "quotacontrol")), } } diff --git a/common_test.go b/common_test.go index 2845ff7..2bfd1d7 100644 --- a/common_test.go +++ b/common_test.go @@ -35,7 +35,7 @@ func newConfig() quotacontrol.Config { func newQuotaClient(cfg quotacontrol.Config, service proto.Service) *quotacontrol.Client { logger := logger.NewLogger(logger.LogLevel_DEBUG).With(slog.String("client", "client")) - return quotacontrol.NewClient(logger, service, cfg) + return quotacontrol.NewClient(logger, service, cfg, nil) } type hitCounter int64 diff --git a/handler_test.go b/handler_test.go index f13e2da..9798fc1 100644 --- a/handler_test.go +++ b/handler_test.go @@ -52,7 +52,7 @@ func TestMiddlewareUseAccessKey(t *testing.T) { r := chi.NewRouter() r.Use( - middleware.Session(client, auth, nil), + middleware.Session(client, auth, nil, nil), addCredits(_credits*2).Middleware, addCredits(_credits*-1).Middleware, middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil), @@ -360,7 +360,7 @@ func TestJWT(t *testing.T) { r := chi.NewRouter() r.Use( - middleware.Session(client, auth, nil), + middleware.Session(client, auth, nil, nil), middleware.EnsureUsage(client, nil), middleware.SpendUsage(client, nil), ) @@ -434,7 +434,7 @@ func TestJWTAccess(t *testing.T) { r := chi.NewRouter() r.Use( - middleware.Session(client, auth, nil), + middleware.Session(client, auth, nil, nil), middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil), middleware.EnsurePermission(client, UserPermission_READ_WRITE, nil), ) @@ -496,14 +496,16 @@ func TestJWTAccess(t *testing.T) { } func TestSession(t *testing.T) { - auth := jwtauth.New("HS256", []byte("secret"), nil) - - project := uint64(7) - key := proto.GenerateAccessKey(project) - service := proto.Service_Indexer - address := "accountId" - serviceName := "serviceName" + const ( + project = uint64(7) + accessKey = "AQAAAAAAAAAHkL0mNSrn6Sm3oHs0xfa_DnY" + service = proto.Service_Indexer + address = "walletAddress" + userAddress = "userAddress" + serviceName = "serviceName" + ) + auth := jwtauth.New("HS256", []byte("secret"), nil) counter := hitCounter(0) cfg := newConfig() @@ -516,26 +518,28 @@ func TestSession(t *testing.T) { MethodAccount = "MethodAccount" MethodAccessKey = "MethodAccessKey" MethodProject = "MethodProject" + MethodUser = "MethodUser" MethodAdmin = "MethodAdmin" MethodService = "MethodService" ) - ACL := middleware.ACL{ + ACL := middleware.ServiceConfig[middleware.ACL]{ "Service": { - MethodPublic: proto.SessionType_Public, - MethodAccount: proto.SessionType_Account, - MethodAccessKey: proto.SessionType_AccessKey, - MethodProject: proto.SessionType_Project, - MethodAdmin: proto.SessionType_Admin, - MethodService: proto.SessionType_Service, + MethodPublic: middleware.NewACL(proto.SessionType_Public.OrHigher()...), + MethodAccount: middleware.NewACL(proto.SessionType_Wallet.OrHigher()...), + MethodAccessKey: middleware.NewACL(proto.SessionType_AccessKey.OrHigher()...), + MethodProject: middleware.NewACL(proto.SessionType_Project.OrHigher()...), + MethodUser: middleware.NewACL(proto.SessionType_User.OrHigher()...), + MethodAdmin: middleware.NewACL(proto.SessionType_Admin.OrHigher()...), + MethodService: middleware.NewACL(proto.SessionType_Service.OrHigher()...), }, } r := chi.NewRouter() r.Use( - middleware.Session(client, auth, nil), + middleware.Session(client, auth, server.Store, nil), middleware.RateLimit(cfg.RateLimiter, cfg.Redis, nil), - middleware.AccessControl(ACL, middleware.Cost{}, 1, nil), + middleware.AccessControl(ACL, nil, 1, nil), ) r.Handle("/*", &counter) @@ -547,9 +551,10 @@ func TestSession(t *testing.T) { OverWarn: 7, OverMax: 10, } + server.Store.AddUser(ctx, userAddress) server.Store.SetAccessLimit(ctx, project, &limit) server.Store.SetUserPermission(ctx, project, address, proto.UserPermission_READ, proto.ResourceAccess{ProjectID: project}) - server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: key, ProjectID: project}) + server.Store.InsertAccessKey(ctx, &proto.AccessKey{Active: true, AccessKey: accessKey, ProjectID: project}) type testCase struct { AccessKey string @@ -561,32 +566,35 @@ func TestSession(t *testing.T) { Session: proto.SessionType_Public, }, { - Session: proto.SessionType_Account, + Session: proto.SessionType_Wallet, }, { - AccessKey: key, Session: proto.SessionType_AccessKey, + AccessKey: accessKey, }, { Session: proto.SessionType_Project, }, { - AccessKey: key, Session: proto.SessionType_Project, + AccessKey: accessKey, + }, + { + Session: proto.SessionType_User, }, { Session: proto.SessionType_Admin, }, { - AccessKey: key, Session: proto.SessionType_Admin, + AccessKey: accessKey, }, { Session: proto.SessionType_Service, }, { - AccessKey: key, Session: proto.SessionType_Service, + AccessKey: accessKey, }, } @@ -602,21 +610,24 @@ func TestSession(t *testing.T) { MethodAccount, MethodAccessKey, MethodProject, + MethodUser, MethodAdmin, MethodService, } for service := range ACL { for _, method := range methods { - minSession := ACL[service][method] + types := ACL[service][method] for _, tc := range testCases { t.Run(fmt.Sprintf("%s/%s/%s", method, tc.Session, tc.AccessKey), func(t *testing.T) { var claims middleware.Claims switch tc.Session { - case proto.SessionType_Account: + case proto.SessionType_Wallet: claims = middleware.Claims{"account": address} case proto.SessionType_Project: claims = middleware.Claims{"account": address, "project": project} + case proto.SessionType_User: + claims = middleware.Claims{"account": userAddress} case proto.SessionType_Admin: claims = middleware.Claims{"account": address, "admin": true} case proto.SessionType_Service: @@ -624,7 +635,7 @@ func TestSession(t *testing.T) { } ok, h, err := executeRequest(ctx, r, "/rpc/"+service+"/"+method, tc.AccessKey, mustJWT(t, auth, claims)) - if tc.Session < minSession { + if !types.Includes(tc.Session) { assert.Error(t, err) assert.False(t, ok) return @@ -639,7 +650,7 @@ func TestSession(t *testing.T) { case proto.SessionType_AccessKey, proto.SessionType_Project: assert.Equal(t, quotaRPM, rateLimit) assert.Equal(t, strconv.FormatInt(limit.FreeMax, 10), h.Get(middleware.HeaderQuotaLimit)) - case proto.SessionType_Account, proto.SessionType_Admin: + case proto.SessionType_Wallet, proto.SessionType_Admin, proto.SessionType_User: assert.Equal(t, accountRPM, rateLimit) case proto.SessionType_Service: assert.Equal(t, serviceRPM, rateLimit) diff --git a/mem.go b/mem.go index d07c66b..b3f5301 100644 --- a/mem.go +++ b/mem.go @@ -17,6 +17,7 @@ func NewMemoryStore() *MemoryStore { ByProjectID: map[uint64]*proto.AccessUsage{}, ByAccessKey: map[string]*proto.AccessUsage{}, }, + users: map[string]struct{}{}, permissions: map[uint64]map[string]userPermission{}, } } @@ -33,6 +34,7 @@ type MemoryStore struct { cycles map[uint64]proto.Cycle accessKeys map[string]proto.AccessKey usage usageRecord + users map[string]struct{} permissions map[uint64]map[string]userPermission } @@ -167,6 +169,23 @@ func (m *MemoryStore) ResetUsage(ctx context.Context, accessKey string) error { return nil } +func (m *MemoryStore) AddUser(ctx context.Context, userID string) error { + m.Lock() + m.users[userID] = struct{}{} + m.Unlock() + return nil +} + +func (m *MemoryStore) GetUser(ctx context.Context, userID string) (any, error) { + m.Lock() + v, ok := m.users[userID] + m.Unlock() + if !ok { + return nil, nil + } + return v, nil +} + func (m *MemoryStore) GetUserPermission(ctx context.Context, projectID uint64, userID string) (proto.UserPermission, *proto.ResourceAccess, error) { m.Lock() defer m.Unlock() diff --git a/middleware/common.go b/middleware/common.go index 3bb1f6a..32c7808 100644 --- a/middleware/common.go +++ b/middleware/common.go @@ -40,28 +40,19 @@ func (c Claims) String() string { type ServiceConfig[T any] map[string]map[string]T -type ( - ACL = ServiceConfig[proto.SessionType] - Cost = ServiceConfig[int64] -) - func (s ServiceConfig[T]) GetConfig(r *rcpRequest) (v T, ok bool) { - if r.Package != "rpc" || s == nil { + if s == nil || r.Package != "rpc" { return v, false } - - serviceACL, ok := s[r.Service] + serviceCfg, ok := s[r.Service] if !ok { return v, false } - - // get method's ACL - cfg, ok := serviceACL[r.Method] + methodCfg, ok := serviceCfg[r.Method] if !ok { return v, false } - - return cfg, true + return methodCfg, true } type rcpRequest struct { @@ -96,3 +87,36 @@ func swapHeader(h http.Header, from, to string) { h.Del(from) } } + +// ACL is a list of session types, encoded as a bitfield. +// SessionType(n) is represented by n=-the bit. +type ACL uint64 + +func NewACL(t ...proto.SessionType) ACL { + var types ACL + for _, v := range t { + types = types.And(v) + } + return types +} + +func (t ACL) And(types ...proto.SessionType) ACL { + for _, v := range types { + t |= 1 << v + } + return t +} + +func (t ACL) Includes(session proto.SessionType) bool { + return t&ACL(1< 0 { projectID := uint64(projectClaim) if quota, err = client.FetchProjectQuota(ctx, projectID, now); err != nil { eh(w, err) @@ -116,7 +142,7 @@ func Session(client Client, auth *jwtauth.JWTAuth, eh ErrHandler, keyFuncs ...Ke sessionType = max(sessionType, proto.SessionType_AccessKey) } - ctx = withSessionType(ctx, sessionType) + ctx = WithSessionType(ctx, sessionType) if quota != nil { ctx = withAccessQuota(ctx, quota) diff --git a/proto/clients/quotacontrol.gen.ts b/proto/clients/quotacontrol.gen.ts index b2c3cf9..e667afb 100644 --- a/proto/clients/quotacontrol.gen.ts +++ b/proto/clients/quotacontrol.gen.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -// quota-control v0.1.0 f99ad8b9b6530c14703d10020bbe02d636527a9c +// quota-control v0.1.0 58eee330d15ad7d819698e8ef6ae774962e4bdf8 // -- // Code generated by webrpc-gen@v0.18.6 with typescript@v0.12.0 generator. DO NOT EDIT. // @@ -12,7 +12,7 @@ export const WebRPCVersion = "v1" export const WebRPCSchemaVersion = "v0.1.0" // Schema hash generated from your RIDL schema -export const WebRPCSchemaHash = "f99ad8b9b6530c14703d10020bbe02d636527a9c" +export const WebRPCSchemaHash = "58eee330d15ad7d819698e8ef6ae774962e4bdf8" // // Types @@ -25,7 +25,8 @@ export enum Service { Indexer = 'Indexer', Relayer = 'Relayer', Metadata = 'Metadata', - Marketplace = 'Marketplace' + Marketplace = 'Marketplace', + Builder = 'Builder' } export enum EventType { @@ -37,9 +38,10 @@ export enum EventType { export enum SessionType { Public = 'Public', - Account = 'Account', + Wallet = 'Wallet', AccessKey = 'AccessKey', Project = 'Project', + User = 'User', Admin = 'Admin', Service = 'Service' } diff --git a/proto/proto.gen.go b/proto/proto.gen.go index 97fe53a..77a8002 100644 --- a/proto/proto.gen.go +++ b/proto/proto.gen.go @@ -1,4 +1,4 @@ -// quota-control v0.1.0 f99ad8b9b6530c14703d10020bbe02d636527a9c +// quota-control v0.1.0 58eee330d15ad7d819698e8ef6ae774962e4bdf8 // -- // Code generated by webrpc-gen@v0.18.6 with golang@v0.14.8 generator. DO NOT EDIT. // @@ -33,7 +33,7 @@ func WebRPCSchemaVersion() string { // Schema hash generated from your RIDL schema func WebRPCSchemaHash() string { - return "f99ad8b9b6530c14703d10020bbe02d636527a9c" + return "58eee330d15ad7d819698e8ef6ae774962e4bdf8" } // @@ -49,6 +49,7 @@ const ( Service_Relayer Service = 3 Service_Metadata Service = 4 Service_Marketplace Service = 5 + Service_Builder Service = 6 ) var Service_name = map[uint16]string{ @@ -58,6 +59,7 @@ var Service_name = map[uint16]string{ 3: "Relayer", 4: "Metadata", 5: "Marketplace", + 6: "Builder", } var Service_value = map[string]uint16{ @@ -67,6 +69,7 @@ var Service_value = map[string]uint16{ "Relayer": 3, "Metadata": 4, "Marketplace": 5, + "Builder": 6, } func (x Service) String() string { @@ -146,29 +149,32 @@ type SessionType uint16 const ( SessionType_Public SessionType = 0 - SessionType_Account SessionType = 1 + SessionType_Wallet SessionType = 1 SessionType_AccessKey SessionType = 2 SessionType_Project SessionType = 3 - SessionType_Admin SessionType = 4 - SessionType_Service SessionType = 5 + SessionType_User SessionType = 4 + SessionType_Admin SessionType = 5 + SessionType_Service SessionType = 6 ) var SessionType_name = map[uint16]string{ 0: "Public", - 1: "Account", + 1: "Wallet", 2: "AccessKey", 3: "Project", - 4: "Admin", - 5: "Service", + 4: "User", + 5: "Admin", + 6: "Service", } var SessionType_value = map[string]uint16{ "Public": 0, - "Account": 1, + "Wallet": 1, "AccessKey": 2, "Project": 3, - "Admin": 4, - "Service": 5, + "User": 4, + "Admin": 5, + "Service": 6, } func (x SessionType) String() string { diff --git a/proto/proto.go b/proto/proto.go index 42516be..4ddd6a8 100644 --- a/proto/proto.go +++ b/proto/proto.go @@ -8,10 +8,21 @@ import ( "time" ) +const SessionType_Max SessionType = SessionType_Service + 1 + func Ptr[T any](v T) *T { return &v } +// AndUp returns a list of all session types from the current one up to the maximum. +func (s SessionType) OrHigher() []SessionType { + list := make([]SessionType, 0, SessionType_Service-s+1) + for i := s; i < SessionType_Max; i++ { + list = append(list, i) + } + return list +} + func (u *AccessUsage) Add(usage AccessUsage) { u.LimitedCompute += usage.LimitedCompute u.ValidCompute += usage.ValidCompute diff --git a/proto/proto.ridl b/proto/proto.ridl index 4cf66b0..b9b60dd 100644 --- a/proto/proto.ridl +++ b/proto/proto.ridl @@ -10,6 +10,7 @@ enum Service: uint16 - Relayer - Metadata - Marketplace + - Builder struct Limit - maxKeys: int64 @@ -69,9 +70,10 @@ enum EventType: uint16 enum SessionType: uint16 - Public - - Account + - Wallet - AccessKey - Project + - User - Admin - Service diff --git a/test/acl.go b/test/acl.go index 420c8d9..ec048d7 100644 --- a/test/acl.go +++ b/test/acl.go @@ -7,7 +7,7 @@ import ( "github.com/0xsequence/quotacontrol/middleware" ) -func VerifyACL[T any](acl middleware.ACL) error { +func VerifyACL[T any](acl middleware.ServiceConfig[middleware.ACL]) error { var t T iType := reflect.TypeOf(&t).Elem() service := iType.Name() diff --git a/test/acl_test.go b/test/acl_test.go index 469aa39..5877e26 100644 --- a/test/acl_test.go +++ b/test/acl_test.go @@ -14,19 +14,19 @@ func TestVerifyACL(t *testing.T) { Method2() error } - acl := middleware.ACL{ + acl := middleware.ServiceConfig[middleware.ACL]{ "Service": { - "Method1": proto.SessionType_Account, + "Method1": middleware.NewACL(proto.SessionType_Wallet.OrHigher()...), }, } err := VerifyACL[Service](acl) assert.Error(t, err) - acl = middleware.ACL{ + acl = middleware.ServiceConfig[middleware.ACL]{ "Service": { - "Method1": proto.SessionType_Account, - "Method2": proto.SessionType_Account, + "Method1": middleware.NewACL(proto.SessionType_Wallet.OrHigher()...), + "Method2": middleware.NewACL(proto.SessionType_Wallet.OrHigher()...), }, }