Skip to content

Commit

Permalink
Pass options separately (#2)
Browse files Browse the repository at this point in the history
* Pass options separately

* Restore options

* godoc

* don't export

* move struct back to file

* Showcase VerifyACL

* uses ctx in handler

* renmae project to projectID

* fix name
  • Loading branch information
klaidliadon authored Oct 24, 2024
1 parent 989c1fc commit 0bf48a2
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 36 deletions.
56 changes: 40 additions & 16 deletions common.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,17 @@ package authcontrol
import (
"context"
"encoding/json"
"errors"
"net/http"
"reflect"
"strings"

"github.com/0xsequence/authcontrol/proto"
)

func defaultErrHandler(r *http.Request, w http.ResponseWriter, err error) {
type ErrHandler func(r *http.Request, w http.ResponseWriter, err error)

func errHandler(r *http.Request, w http.ResponseWriter, err error) {
rpcErr, ok := err.(proto.WebRPCError)
if !ok {
rpcErr = proto.ErrWebrpcEndpoint.WithCause(err)
Expand All @@ -32,45 +36,45 @@ type UserStore interface {
// map[service]map[method]T
type Config[T any] map[string]map[string]T

// returns the config value for the given request.
func (c Config[T]) Get(r *rcpRequest) (v T, ok bool) {
// Get returns the config value for the given request.
func (c Config[T]) Get(r *WebRPCRequest) (v T, ok bool) {
if c == nil {
return v, false
}

methodCfg, ok := c[r.serviceName][r.methodName]
methodCfg, ok := c[r.ServiceName][r.MethodName]
if !ok {
return v, false
}

return methodCfg, true
}

// rcpRequest is a parsed RPC request.
type rcpRequest struct {
packageName string
serviceName string
methodName string
// WebRPCRequest is a parsed RPC request.
type WebRPCRequest struct {
PackageName string
ServiceName string
MethodName string
}

// newRequest parses a path into an rcpRequest.
func newRequest(path string) *rcpRequest {
func ParseRequest(path string) *WebRPCRequest {
p := strings.Split(path, "/")
if len(p) < 4 {
return nil
}

t := &rcpRequest{
packageName: p[len(p)-3],
serviceName: p[len(p)-2],
methodName: p[len(p)-1],
r := &WebRPCRequest{
PackageName: p[len(p)-3],
ServiceName: p[len(p)-2],
MethodName: p[len(p)-1],
}

if t.packageName != "rpc" {
if r.PackageName != "rpc" {
return nil
}

return t
return r
}

// ACL is a list of session types, encoded as a bitfield.
Expand Down Expand Up @@ -98,3 +102,23 @@ func (a ACL) And(session ...proto.SessionType) ACL {
func (t ACL) Includes(session proto.SessionType) bool {
return t&ACL(1<<session) != 0
}

// VerifyACL checks that the given ACL config is valid for the given service.
// It can be used in unit tests to ensure that all methods are covered.
func VerifyACL[T any](acl Config[ACL]) error {
var t T
iType := reflect.TypeOf(&t).Elem()
service := iType.Name()
m, ok := acl[service]
if !ok {
return errors.New("service " + service + " not found")
}
var errList []error
for i := 0; i < iType.NumMethod(); i++ {
method := iType.Method(i)
if _, ok := m[method.Name]; !ok {
errList = append(errList, errors.New(""+service+"."+method.Name+" not found"))
}
}
return errors.Join(errList...)
}
38 changes: 38 additions & 0 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/require"

"github.com/0xsequence/authcontrol"
"github.com/0xsequence/authcontrol/proto"
)

Expand Down Expand Up @@ -55,3 +56,40 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat

return true, nil
}

func TestVerifyACL(t *testing.T) {
type Service interface {
Method1() error
Method2() error
Method3() error
}

err := authcontrol.VerifyACL[Service](nil)
require.Error(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"WrongName": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
},
})
require.Error(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"Service": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
},
})
require.Error(t, err)

err = authcontrol.VerifyACL[Service](authcontrol.Config[authcontrol.ACL]{
"Service": {
"Method1": authcontrol.NewACL(proto.SessionType_User),
"Method2": authcontrol.NewACL(proto.SessionType_User),
"Method3": authcontrol.NewACL(proto.SessionType_User),
},
})
require.NoError(t, err)
}
8 changes: 4 additions & 4 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ func GetAccessKey(ctx context.Context) (string, bool) {
// Project ID
//

// withProjectID adds the projectID to the context.
func withProjectID(ctx context.Context, projectID uint64) context.Context {
return context.WithValue(ctx, ctxKeyProjectID, projectID)
// withProjectID adds the project to the context.
func withProjectID(ctx context.Context, project uint64) context.Context {
return context.WithValue(ctx, ctxKeyProjectID, project)
}

// GetProjectID returns the projectID and if its active from the context.
// GetProjectID returns the project and if its active from the context.
// In case its not set, it will return 0.
func GetProjectID(ctx context.Context) (uint64, bool) {
v, ok := ctx.Value(ctxKeyProjectID).(uint64)
Expand Down
21 changes: 9 additions & 12 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,15 @@ import (
type Options struct {
KeyFuncs []KeyFunc
UserStore UserStore
ErrHandler func(r *http.Request, w http.ResponseWriter, err error)
ErrHandler ErrHandler
}

func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Handler {
eh := defaultErrHandler
eh := errHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

var keyFuncs []KeyFunc
if o != nil {
keyFuncs = o.KeyFuncs
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
Expand All @@ -43,9 +38,11 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
token jwt.Token
)

for _, f := range keyFuncs {
if accessKey = f(r); accessKey != "" {
break
if o != nil {
for _, f := range o.KeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
}
}

Expand Down Expand Up @@ -125,14 +122,14 @@ func Session(auth *jwtauth.JWTAuth, o *Options) func(next http.Handler) http.Han
// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Handler {
eh := defaultErrHandler
eh := errHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
req := newRequest(r.URL.Path)
req := ParseRequest(r.URL.Path)
if req == nil {
eh(r, w, proto.ErrUnauthorized.WithCausef("invalid rpc method"))
return
Expand Down
18 changes: 14 additions & 4 deletions middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func TestSession(t *testing.T) {

auth := jwtauth.New("HS256", []byte("secret"), nil)

options := authcontrol.Options{
options := &authcontrol.Options{
UserStore: mockStore{
UserAddress: false,
AdminAddress: true,
Expand All @@ -89,8 +89,8 @@ func TestSession(t *testing.T) {

r := chi.NewRouter()
r.Use(
authcontrol.Session(auth, &options),
authcontrol.AccessControl(ACLConfig, &options),
authcontrol.Session(auth, options),
authcontrol.AccessControl(ACLConfig, options),
)
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

Expand Down Expand Up @@ -193,7 +193,17 @@ func TestInvalid(t *testing.T) {
authcontrol.Session(auth, options),
authcontrol.AccessControl(ACLConfig, options),
)
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
r.Handle("/*", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
resp := map[string]any{}
resp["accessKey"], _ = authcontrol.GetAccessKey(ctx)
resp["account"], _ = authcontrol.GetAccount(ctx)
resp["project"], _ = authcontrol.GetProjectID(ctx)
resp["service"], _ = authcontrol.GetService(ctx)
resp["session"], _ = authcontrol.GetSessionType(ctx)
resp["user"], _ = authcontrol.GetUser[any](ctx)
assert.NoError(t, json.NewEncoder(w).Encode(resp))
}))

// Without JWT
ok, err := executeRequest(t, ctx, r, fmt.Sprintf("/rpc/%s/%s", ServiceName, MethodName), AccessKey, nil)
Expand Down

0 comments on commit 0bf48a2

Please sign in to comment.