From 161bcd375f6e884becb679e7bdea614c63d684ab Mon Sep 17 00:00:00 2001 From: Matt Anderson <42154938+matoszz@users.noreply.github.com> Date: Mon, 18 Nov 2024 08:20:15 -0700 Subject: [PATCH] add active subscription to context (#69) --- auth/context.go | 54 +++++++++++++++++++++++++++++++++++++++++++- auth/context_test.go | 52 ++++++++++++++++++++++++++++++++++++++++++ auth/errors.go | 12 ++-------- auth/test_tools.go | 34 +++++++++++++++++++++++++++- 4 files changed, 140 insertions(+), 12 deletions(-) diff --git a/auth/context.go b/auth/context.go index ea8c3e3..0d14a12 100644 --- a/auth/context.go +++ b/auth/context.go @@ -52,6 +52,8 @@ type AuthenticatedUser struct { OrganizationIDs []string // AuthenticationType is the type of authentication used to authenticate the user (JWT, PAT, API Token) AuthenticationType AuthenticationType + // ActiveSubscription is the active subscription for the user + ActiveSubscription bool } // GetContextName returns the name of the context key @@ -310,7 +312,7 @@ func AddOrganizationIDToContext(ctx context.Context, orgID string) error { return addOrganizationIDsToEchoContext(ec, orgID) } -// getOrganizationIDsFromEchoContext appends an authorized organization ID to the echo context +// addOrganizationIDsToEchoContext appends an authorized organization ID to the echo context func addOrganizationIDsToEchoContext(c echo.Context, orgID string) error { if v := c.Get(ContextAuthenticatedUser.name); v != nil { a, ok := v.(*AuthenticatedUser) @@ -370,3 +372,53 @@ func GetRefreshTokenContext(c context.Context) (string, error) { return token, nil } + +// AddSubscriptionToContext appends a subscription to the context +func AddSubscriptionToContext(ctx context.Context, subscription bool) error { + ec, err := echocontext.EchoContextFromContext(ctx) + if err != nil { + return err + } + + return addSubscriptionToEchoContext(ec, subscription) +} + +// addSubscriptionToEchoContext appends a subscription to the echo context +func addSubscriptionToEchoContext(c echo.Context, subscription bool) error { + if v := c.Get(ContextAuthenticatedUser.name); v != nil { + a, ok := v.(*AuthenticatedUser) + if !ok { + return ErrNoAuthUser + } + + a.ActiveSubscription = subscription + + return nil + } + + return ErrNoAuthUser +} + +// getSubscriptionFromContext returns the active subscription from the echo context +func getSubscriptionFromContext(c echo.Context) (bool, error) { + if v := c.Get(ContextAuthenticatedUser.name); v != nil { + a, ok := v.(*AuthenticatedUser) + if !ok { + return false, ErrNoAuthUser + } + + return a.ActiveSubscription, nil + } + + return false, nil +} + +// GetSubscriptionFromContext returns the active subscription from the context +func GetSubscriptionFromContext(ctx context.Context) (bool, error) { + ec, err := echocontext.EchoContextFromContext(ctx) + if err != nil { + return false, err + } + + return getSubscriptionFromContext(ec) +} diff --git a/auth/context_test.go b/auth/context_test.go index a86a459..9218576 100644 --- a/auth/context_test.go +++ b/auth/context_test.go @@ -129,3 +129,55 @@ func TestGetOrganizationIDFromContext(t *testing.T) { }) } } + +func TestGetSubscriptionFromContext(t *testing.T) { + validsubscription := true + invalidsubscription := false + + ec := echocontext.NewTestEchoContext() + + basicContext := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec) + + ec.SetRequest(ec.Request().WithContext(basicContext)) + + invalidCtx, err := auth.NewTestContextWithValidUser(ulids.Null.String()) + if err != nil { + t.Fatal() + } + + validCtx, err := auth.NewTestContextWithValidUser(ulids.New().String()) + if err != nil { + t.Fatal() + } + + if err := auth.AddSubscriptionToContext(validCtx, true); err != nil { + t.Fatal(err) + } + + testCases := []struct { + name string + ctx context.Context + expect bool + }{ + { + name: "happy path", + ctx: invalidCtx, + expect: invalidsubscription, + }, + { + name: "MITB BABBYYYYY", + ctx: validCtx, + expect: validsubscription, + }, + } + + for _, tc := range testCases { + t.Run("Get "+tc.name, func(t *testing.T) { + got, err := auth.GetSubscriptionFromContext(tc.ctx) + + assert.NoError(t, err) + + assert.Equal(t, tc.expect, got) + }) + } +} diff --git a/auth/errors.go b/auth/errors.go index 998d721..bf9fce3 100644 --- a/auth/errors.go +++ b/auth/errors.go @@ -7,34 +7,26 @@ import ( var ( // ErrNoClaims is returned when no claims are found on the request context ErrNoClaims = errors.New("no claims found on the request context") - // ErrNoUserInfo is returned when no user info is found on the request context ErrNoUserInfo = errors.New("no user info found on the request context") - // ErrNoAuthUser is returned when no authenticated user is found on the request context ErrNoAuthUser = errors.New("could not identify authenticated user in request") - // ErrUnverifiedUser is returned when the user is not verified ErrUnverifiedUser = errors.New("user is not verified") - // ErrParseBearer is returned when the bearer token could not be parsed from the authorization header ErrParseBearer = errors.New("could not parse bearer token from authorization header") - // ErrNoAuthorization is returned when no authorization header is found in the request ErrNoAuthorization = errors.New("no authorization header in request") - // ErrNoRequest is returned when no request is found on the context ErrNoRequest = errors.New("no request found on the context") - // ErrNoRefreshToken is returned when no refresh token is found on the request ErrNoRefreshToken = errors.New("no refresh token available on request") - // ErrRefreshDisabled is returned when re-authentication with refresh tokens is disabled ErrRefreshDisabled = errors.New("re-authentication with refresh tokens disabled") - // ErrUnableToConstructValidator is returned when the validator cannot be constructed ErrUnableToConstructValidator = errors.New("unable to construct validator") - // ErrPasswordTooWeak is returned when the password is too weak ErrPasswordTooWeak = errors.New("password is too weak: use a combination of upper and lower case letters, numbers, and special characters") + // ErrCouldNotFetchSubscription is returned when the subscription could not be fetched + ErrCouldNotFetchSubscription = errors.New("could not fetch subscription") ) diff --git a/auth/test_tools.go b/auth/test_tools.go index 8972830..8da4fa9 100644 --- a/auth/test_tools.go +++ b/auth/test_tools.go @@ -6,8 +6,8 @@ import ( "github.com/golang-jwt/jwt/v5" echo "github.com/theopenlane/echox" - "github.com/theopenlane/echox/middleware/echocontext" + "github.com/theopenlane/utils/ulids" "github.com/theopenlane/iam/tokens" ) @@ -99,6 +99,7 @@ func NewTestEchoContextWithOrgID(sub, orgID string) (echo.Context, error) { return ec, nil } +// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down) func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) { ec, err := NewTestEchoContextWithOrgID(sub, orgID) if err != nil { @@ -111,3 +112,34 @@ func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) { return reqCtx, nil } + +// NewTestEchoContextWithOrgID creates an echo context with a fake orgID for testing purposes ONLY +func NewTestEchoContextWithSubscription(subscription bool) (echo.Context, error) { + ec := echocontext.NewTestEchoContext() + + claims := newValidClaimsOrgID(ulids.New().String(), ulids.New().String()) + + SetAuthenticatedUserContext(ec, &AuthenticatedUser{ + SubjectID: claims.UserID, + OrganizationID: claims.OrgID, + OrganizationIDs: []string{claims.OrgID}, + AuthenticationType: "jwt", + ActiveSubscription: subscription, + }) + + return ec, nil +} + +// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down) +func NewTestContextWithSubscription(subscription bool) (context.Context, error) { + ec, err := NewTestEchoContextWithSubscription(subscription) + if err != nil { + return nil, err + } + + reqCtx := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec) + + ec.SetRequest(ec.Request().WithContext(reqCtx)) + + return reqCtx, nil +}