Skip to content

Commit

Permalink
s2s token, s2s client, errors
Browse files Browse the repository at this point in the history
  • Loading branch information
david-littlefarmer committed Oct 24, 2024
1 parent fe8f8fe commit 83b0f41
Show file tree
Hide file tree
Showing 7 changed files with 220 additions and 55 deletions.
6 changes: 3 additions & 3 deletions common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func keyFunc(r *http.Request) string {
return r.Header.Get(HeaderKey)
}

func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt string) (bool, error) {
func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, path, accessKey string, jwt *string) (bool, error) {
req, err := http.NewRequest("POST", path, nil)
require.NoError(t, err)

Expand All @@ -30,8 +30,8 @@ func executeRequest(t *testing.T, ctx context.Context, handler http.Handler, pat
req.Header.Set(HeaderKey, accessKey)
}

if jwt != "" {
req.Header.Set("Authorization", "Bearer "+jwt)
if jwt != nil {
req.Header.Set("Authorization", "Bearer "+*jwt)
}

rr := httptest.NewRecorder()
Expand Down
8 changes: 8 additions & 0 deletions errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
package authcontrol

import "fmt"

var (
ErrEmptyJWTSecret error = fmt.Errorf("JWTSecret is empty")
ErrS2SClientConfigIsNil error = fmt.Errorf("S2SClientConfig is nil")
)
84 changes: 63 additions & 21 deletions http.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package authcontrol

import (
"maps"
"cmp"
"net/http"
"os"
"time"

"github.com/go-chi/jwtauth/v5"
Expand All @@ -12,47 +13,88 @@ import (
)

type S2SClientConfig struct {
Service string
ServiceName string
JWTSecret string
DebugRequests bool
Expiration time.Duration
}

func (cfg *S2SClientConfig) Validate() error {
if cfg.JWTSecret == "" {
return ErrEmptyJWTSecret
}

return nil
}

// Service-to-service HTTP client for internal communication between Sequence services.
func S2SClient(cfg *S2SClientConfig) *http.Client {
func S2SClient(cfg *S2SClientConfig) (*http.Client, error) {
if cfg == nil {
return nil, ErrS2SClientConfigIsNil
}

if cfg.JWTSecret == "" {
return nil, ErrEmptyJWTSecret
}

tokenCfg := &S2STokenConfig{
JWTSecret: cfg.JWTSecret,
ServiceName: cfg.ServiceName,
Expiration: cfg.Expiration,
}

httpClient := &http.Client{
Transport: transport.Chain(http.DefaultTransport,
traceid.Transport,
transport.SetHeaderFunc("Authorization", s2sAuthHeader(cfg.JWTSecret, map[string]any{"service": cfg.Service})),
transport.SetHeaderFunc("Authorization", s2sAuthHeader(tokenCfg)),
transport.If(cfg.DebugRequests, transport.LogRequests(transport.LogOptions{Concise: true, CURL: true})),
),
}

return httpClient
return httpClient, nil
}

// Create short-lived service-to-service JWT token for internal communication between Sequence services.
func S2SToken(jwtSecret string, claims map[string]any) string {
jwtAuth := jwtauth.New("HS256", []byte(jwtSecret), nil, jwt.WithAcceptableSkew(2*time.Minute))
func s2sAuthHeader(cfg *S2STokenConfig) func(req *http.Request) string {
return func(req *http.Request) string {
return "BEARER " + S2SToken(cfg)
}
}

now := time.Now().UTC()
const (
defaultExpiration time.Duration = 30 * time.Second
acceptableSkew time.Duration = 2 * time.Minute
)

type S2STokenConfig struct {
JWTSecret string
ServiceName string
Expiration time.Duration
}

c := maps.Clone(claims)
if c == nil {
c = map[string]any{}
func (cfg *S2STokenConfig) Validate() error {
if cfg.JWTSecret == "" {
return ErrEmptyJWTSecret
}

c["iat"] = now
return nil
}

if _, ok := c["exp"]; !ok {
c["exp"] = now.Add(30 * time.Second)
// Create short-lived service-to-service JWT token for internal communication between Sequence services with HS256 algorithm.
func S2SToken(cfg *S2STokenConfig) string {
if cfg == nil {
return ""
}

_, t, _ := jwtAuth.Encode(c)
return t
}
jwtAuth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil, jwt.WithAcceptableSkew(acceptableSkew))

func s2sAuthHeader(jwtSecret string, claims map[string]any) func(req *http.Request) string {
return func(req *http.Request) string {
return "BEARER " + S2SToken(jwtSecret, claims)
now := time.Now().UTC()
claims := map[string]any{
"service": cmp.Or(cfg.ServiceName, os.Args[0]),
"iat": now,
"exp": now.Add(cmp.Or(cfg.Expiration, defaultExpiration)),
}

_, t, _ := jwtAuth.Encode(claims)

return t
}
98 changes: 98 additions & 0 deletions http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package authcontrol_test

import (
"context"
"os"
"testing"
"time"

"github.com/go-chi/jwtauth/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/0xsequence/authcontrol"
)

func TestS2SClient(t *testing.T) {
secret := "secret"
serviceName := "test-service-name"

cfg := &authcontrol.S2SClientConfig{
JWTSecret: secret,
ServiceName: serviceName,
Expiration: 10 * time.Second,
}

err := cfg.Validate()
require.NoError(t, err)

s2sClient, err := authcontrol.S2SClient(cfg)
require.NoError(t, err)
require.NotNil(t, s2sClient)

s2sClient, err = authcontrol.S2SClient(nil)
require.Error(t, err)
require.ErrorIs(t, err, authcontrol.ErrS2SClientConfigIsNil)
require.Nil(t, s2sClient)

cfg = &authcontrol.S2SClientConfig{
JWTSecret: "",
}
s2sClient, err = authcontrol.S2SClient(cfg)
require.Error(t, err)
require.ErrorIs(t, err, authcontrol.ErrEmptyJWTSecret)
require.Nil(t, s2sClient)

cfg = &authcontrol.S2SClientConfig{
JWTSecret: "",
}
err = cfg.Validate()
require.Error(t, err)
require.ErrorIs(t, err, authcontrol.ErrEmptyJWTSecret)
}

func TestS2SToken(t *testing.T) {
ctx := context.Background()
secret := "secret"
serviceName := "test-service-name"

cfg := &authcontrol.S2STokenConfig{
JWTSecret: secret,
ServiceName: serviceName,
Expiration: 10 * time.Second,
}

err := cfg.Validate()
require.NoError(t, err)

jwtAut := jwtauth.New("HS256", []byte(secret), nil)
jwtToken := authcontrol.S2SToken(cfg)

token, err := jwtauth.VerifyToken(jwtAut, jwtToken)
require.NoError(t, err)

claims, err := token.AsMap(ctx)
require.NoError(t, err)

cServiceName := claims["service"].(string)
assert.Equal(t, serviceName, cServiceName)

cfg = &authcontrol.S2STokenConfig{
JWTSecret: secret,
Expiration: 10 * time.Second,
}

jwtToken = authcontrol.S2SToken(cfg)

token, err = jwtauth.VerifyToken(jwtAut, jwtToken)
require.NoError(t, err)

claims, err = token.AsMap(ctx)
require.NoError(t, err)

cServiceName = claims["service"].(string)
assert.Equal(t, os.Args[0], cServiceName)

jwtToken = authcontrol.S2SToken(nil)
assert.Equal(t, "", jwtToken)
}
30 changes: 19 additions & 11 deletions middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@ type Options struct {
ErrHandler ErrHandler
}

func Session(cfg *Options) func(next http.Handler) http.Handler {
auth := jwtauth.New("HS256", []byte(cfg.JWTSecret), nil)
func (o *Options) Validate() error {
if o.JWTSecret == "" {
return ErrEmptyJWTSecret
}

return nil
}

func Session(o *Options) func(next http.Handler) http.Handler {
auth := jwtauth.New("HS256", []byte(o.JWTSecret), nil)

eh := errHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

return func(next http.Handler) http.Handler {
Expand All @@ -41,8 +49,8 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
token jwt.Token
)

if cfg != nil {
for _, f := range cfg.KeyFuncs {
if o != nil {
for _, f := range o.KeyFuncs {
if accessKey = f(r); accessKey != "" {
break
}
Expand Down Expand Up @@ -82,8 +90,8 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {
ctx = withAccount(ctx, accountClaim)
sessionType = proto.SessionType_Wallet

if cfg != nil && cfg.UserStore != nil {
user, isAdmin, err := cfg.UserStore.GetUser(ctx, accountClaim)
if o != nil && o.UserStore != nil {
user, isAdmin, err := o.UserStore.GetUser(ctx, accountClaim)
if err != nil {
eh(r, w, err)
return
Expand Down Expand Up @@ -124,10 +132,10 @@ func Session(cfg *Options) func(next http.Handler) http.Handler {

// AccessControl middleware that checks if the session type is allowed to access the endpoint.
// It also sets the compute units on the context if the endpoint requires it.
func AccessControl(acl Config[ACL], cfg *Options) func(next http.Handler) http.Handler {
func AccessControl(acl Config[ACL], o *Options) func(next http.Handler) http.Handler {
eh := errHandler
if cfg != nil && cfg.ErrHandler != nil {
eh = cfg.ErrHandler
if o != nil && o.ErrHandler != nil {
eh = o.ErrHandler
}

return func(next http.Handler) http.Handler {
Expand Down
Loading

0 comments on commit 83b0f41

Please sign in to comment.