From bc9c31bdaa0c70e423edf1395db5adbf7e7dc771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?David=20Sedl=C3=A1=C4=8Dek?= Date: Thu, 24 Oct 2024 16:04:23 +0200 Subject: [PATCH] VerifyACL and merge acl.Get and ParseRequest (#4) --- common.go | 77 ++++++++++++++++++++------------------------------ common_test.go | 74 +++++++++++++++++++++++++++++++++++------------- middleware.go | 12 ++------ 3 files changed, 87 insertions(+), 76 deletions(-) diff --git a/common.go b/common.go index 8696666..943d2d2 100644 --- a/common.go +++ b/common.go @@ -4,8 +4,8 @@ import ( "context" "encoding/json" "errors" + "fmt" "net/http" - "reflect" "strings" "github.com/0xsequence/authcontrol/proto" @@ -37,44 +37,47 @@ type UserStore interface { type Config[T any] map[string]map[string]T // Get returns the config value for the given request. -func (c Config[T]) Get(r *WebRPCRequest) (v T, ok bool) { +func (c Config[T]) Get(path string) (*T, error) { if c == nil { - return v, false + return nil, fmt.Errorf("cofig is nil") } - methodCfg, ok := c[r.ServiceName][r.MethodName] - if !ok { - return v, false + p := strings.Split(path, "/") + if len(p) < 4 { + return nil, fmt.Errorf("path has not enough parts") } - return methodCfg, true -} + var ( + packageName = p[len(p)-3] + serviceName = p[len(p)-2] + methodName = p[len(p)-1] + ) -// WebRPCRequest is a parsed RPC request. -type WebRPCRequest struct { - PackageName string - ServiceName string - MethodName string -} - -// newRequest parses a path into an rcpRequest. -func ParseRequest(path string) *WebRPCRequest { - p := strings.Split(path, "/") - if len(p) < 4 { - return nil + if packageName != "rpc" { + return nil, fmt.Errorf("path doesn't include rpc") } - r := &WebRPCRequest{ - PackageName: p[len(p)-3], - ServiceName: p[len(p)-2], - MethodName: p[len(p)-1], + methodCfg, ok := c[serviceName][methodName] + if !ok { + return nil, fmt.Errorf("acl not found") } - if r.PackageName != "rpc" { - return nil + return &methodCfg, nil +} + +// VerifyACL checks that the given ACL config is valid for the given service. +// It can be used in unit tests to ensure that all methods are covered. +func (acl Config[any]) VerifyACL(webrpcServices map[string][]string) error { + var errList []error + for service, methods := range webrpcServices { + for _, method := range methods { + if _, ok := acl[service][method]; !ok { + errList = append(errList, fmt.Errorf("%s.%s not found", service, method)) + } + } } - return r + return errors.Join(errList...) } // ACL is a list of session types, encoded as a bitfield. @@ -102,23 +105,3 @@ func (a ACL) And(session ...proto.SessionType) ACL { func (t ACL) Includes(session proto.SessionType) bool { return t&ACL(1<