Skip to content

Commit

Permalink
VerifyACL and merge acl.Get and ParseRequest (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-littlefarmer authored Oct 24, 2024
1 parent 0bf48a2 commit bc9c31b
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 76 deletions.
77 changes: 30 additions & 47 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"reflect"
"strings"

"github.com/0xsequence/authcontrol/proto"
Expand Down Expand Up @@ -37,44 +37,47 @@ type UserStore interface {
type Config[T any] map[string]map[string]T

// Get returns the config value for the given request.
func (c Config[T]) Get(r *WebRPCRequest) (v T, ok bool) {
func (c Config[T]) Get(path string) (*T, error) {
if c == nil {
return v, false
return nil, fmt.Errorf("cofig is nil")
}

methodCfg, ok := c[r.ServiceName][r.MethodName]
if !ok {
return v, false
p := strings.Split(path, "/")
if len(p) < 4 {
return nil, fmt.Errorf("path has not enough parts")
}

return methodCfg, true
}
var (
packageName = p[len(p)-3]
serviceName = p[len(p)-2]
methodName = p[len(p)-1]
)

// WebRPCRequest is a parsed RPC request.
type WebRPCRequest struct {
PackageName string
ServiceName string
MethodName string
}

// newRequest parses a path into an rcpRequest.
func ParseRequest(path string) *WebRPCRequest {
p := strings.Split(path, "/")
if len(p) < 4 {
return nil
if packageName != "rpc" {
return nil, fmt.Errorf("path doesn't include rpc")
}

r := &WebRPCRequest{
PackageName: p[len(p)-3],
ServiceName: p[len(p)-2],
MethodName: p[len(p)-1],
methodCfg, ok := c[serviceName][methodName]
if !ok {
return nil, fmt.Errorf("acl not found")
}

if r.PackageName != "rpc" {
return nil
return &methodCfg, nil
}

// VerifyACL checks that the given ACL config is valid for the given service.
// It can be used in unit tests to ensure that all methods are covered.
func (acl Config[any]) VerifyACL(webrpcServices map[string][]string) error {
var errList []error
for service, methods := range webrpcServices {
for _, method := range methods {
if _, ok := acl[service][method]; !ok {
errList = append(errList, fmt.Errorf("%s.%s not found", service, method))
}
}
}

return r
return errors.Join(errList...)
}

// ACL is a list of session types, encoded as a bitfield.
Expand Down Expand Up @@ -102,23 +105,3 @@ func (a ACL) And(session ...proto.SessionType) ACL {
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

// VerifyACL checks that the given ACL config is valid for the given service.
// It can be used in unit tests to ensure that all methods are covered.
func VerifyACL[T any](acl Config[ACL]) error {
var t T
iType := reflect.TypeOf(&t).Elem()
service := iType.Name()
m, ok := acl[service]
if !ok {
return errors.New("service " + service + " not found")
}
var errList []error
for i := 0; i < iType.NumMethod(); i++ {
method := iType.Method(i)
if _, ok := m[method.Name]; !ok {
errList = append(errList, errors.New(""+service+"."+method.Name+" not found"))
}
}
return errors.Join(errList...)
}
74 changes: 54 additions & 20 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@ package authcontrol_test
import (
"context"
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"testing"

"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/0xsequence/authcontrol"
Expand Down Expand Up @@ -58,38 +60,70 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat
}

func TestVerifyACL(t *testing.T) {
type Service interface {
Method1() error
Method2() error
Method3() error
services := map[string][]string{
"Service1": {
"Method1",
"Method2",
"Method3",
},
"Service2": {
"Method1",
},
}

err := authcontrol.VerifyACL[Service](nil)
require.Error(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"WrongName": {
// Valid ACL config
acl := authcontrol.Config[any]{
"Service1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
},
})
require.Error(t, err)
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
},
}

err := acl.VerifyACL(services)
assert.NoError(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"Service": {
// Wrong Service
acl = authcontrol.Config[any]{
"WrongService1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
},
})
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
},
}

err = acl.VerifyACL(services)
require.Error(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"Service": {
expectedErrors := []error{
errors.New("Service1.Method1 not found"),
errors.New("Service1.Method2 not found"),
errors.New("Service1.Method3 not found"),
}
assert.Equal(t, errors.Join(expectedErrors...).Error(), err.Error())

// Wrong Methods
acl = authcontrol.Config[any]{
"Service1": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
},
})
require.NoError(t, err)
"Service2": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
},
}

err = acl.VerifyACL(services)
require.Error(t, err)

expectedErrors = []error{
errors.New("Service1.Method2 not found"),
errors.New("Service1.Method3 not found"),
}
assert.Equal(t, errors.Join(expectedErrors...).Error(), err.Error())
}
12 changes: 3 additions & 9 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,9 @@ func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Han

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := ParseRequest(r.URL.Path)
if req == nil {
eh(r, w, proto.ErrUnauthorized.WithCausef("invalid rpc method"))
return
}

acl, ok := acl.Get(req)
if !ok {
eh(r, w, proto.ErrUnauthorized.WithCausef("rpc method not found"))
acl, err := acl.Get(r.URL.Path)
if err != nil {
eh(r, w, proto.ErrUnauthorized.WithCausef("get acl: %w", err))
return
}

Expand Down

0 comments on commit bc9c31b

Please sign in to comment.