Skip to content

Commit

Permalink
add active subscription to context (#69)
Browse files Browse the repository at this point in the history
  • Loading branch information
matoszz authored Nov 18, 2024
1 parent 3186ed9 commit 161bcd3
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 12 deletions.
54 changes: 53 additions & 1 deletion auth/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type AuthenticatedUser struct {
OrganizationIDs []string
// AuthenticationType is the type of authentication used to authenticate the user (JWT, PAT, API Token)
AuthenticationType AuthenticationType
// ActiveSubscription is the active subscription for the user
ActiveSubscription bool
}

// GetContextName returns the name of the context key
Expand Down Expand Up @@ -310,7 +312,7 @@ func AddOrganizationIDToContext(ctx context.Context, orgID string) error {
return addOrganizationIDsToEchoContext(ec, orgID)
}

// getOrganizationIDsFromEchoContext appends an authorized organization ID to the echo context
// addOrganizationIDsToEchoContext appends an authorized organization ID to the echo context
func addOrganizationIDsToEchoContext(c echo.Context, orgID string) error {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
Expand Down Expand Up @@ -370,3 +372,53 @@ func GetRefreshTokenContext(c context.Context) (string, error) {

return token, nil
}

// AddSubscriptionToContext appends a subscription to the context
func AddSubscriptionToContext(ctx context.Context, subscription bool) error {
ec, err := echocontext.EchoContextFromContext(ctx)
if err != nil {
return err
}

return addSubscriptionToEchoContext(ec, subscription)
}

// addSubscriptionToEchoContext appends a subscription to the echo context
func addSubscriptionToEchoContext(c echo.Context, subscription bool) error {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
if !ok {
return ErrNoAuthUser
}

a.ActiveSubscription = subscription

return nil
}

return ErrNoAuthUser
}

// getSubscriptionFromContext returns the active subscription from the echo context
func getSubscriptionFromContext(c echo.Context) (bool, error) {
if v := c.Get(ContextAuthenticatedUser.name); v != nil {
a, ok := v.(*AuthenticatedUser)
if !ok {
return false, ErrNoAuthUser
}

return a.ActiveSubscription, nil
}

return false, nil
}

// GetSubscriptionFromContext returns the active subscription from the context
func GetSubscriptionFromContext(ctx context.Context) (bool, error) {
ec, err := echocontext.EchoContextFromContext(ctx)
if err != nil {
return false, err
}

return getSubscriptionFromContext(ec)
}
52 changes: 52 additions & 0 deletions auth/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,55 @@ func TestGetOrganizationIDFromContext(t *testing.T) {
})
}
}

func TestGetSubscriptionFromContext(t *testing.T) {
validsubscription := true
invalidsubscription := false

ec := echocontext.NewTestEchoContext()

basicContext := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec)

ec.SetRequest(ec.Request().WithContext(basicContext))

invalidCtx, err := auth.NewTestContextWithValidUser(ulids.Null.String())
if err != nil {
t.Fatal()
}

validCtx, err := auth.NewTestContextWithValidUser(ulids.New().String())
if err != nil {
t.Fatal()
}

if err := auth.AddSubscriptionToContext(validCtx, true); err != nil {
t.Fatal(err)
}

testCases := []struct {
name string
ctx context.Context
expect bool
}{
{
name: "happy path",
ctx: invalidCtx,
expect: invalidsubscription,
},
{
name: "MITB BABBYYYYY",
ctx: validCtx,
expect: validsubscription,
},
}

for _, tc := range testCases {
t.Run("Get "+tc.name, func(t *testing.T) {
got, err := auth.GetSubscriptionFromContext(tc.ctx)

assert.NoError(t, err)

assert.Equal(t, tc.expect, got)
})
}
}
12 changes: 2 additions & 10 deletions auth/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,26 @@ import (
var (
// ErrNoClaims is returned when no claims are found on the request context
ErrNoClaims = errors.New("no claims found on the request context")

// ErrNoUserInfo is returned when no user info is found on the request context
ErrNoUserInfo = errors.New("no user info found on the request context")

// ErrNoAuthUser is returned when no authenticated user is found on the request context
ErrNoAuthUser = errors.New("could not identify authenticated user in request")

// ErrUnverifiedUser is returned when the user is not verified
ErrUnverifiedUser = errors.New("user is not verified")

// ErrParseBearer is returned when the bearer token could not be parsed from the authorization header
ErrParseBearer = errors.New("could not parse bearer token from authorization header")

// ErrNoAuthorization is returned when no authorization header is found in the request
ErrNoAuthorization = errors.New("no authorization header in request")

// ErrNoRequest is returned when no request is found on the context
ErrNoRequest = errors.New("no request found on the context")

// ErrNoRefreshToken is returned when no refresh token is found on the request
ErrNoRefreshToken = errors.New("no refresh token available on request")

// ErrRefreshDisabled is returned when re-authentication with refresh tokens is disabled
ErrRefreshDisabled = errors.New("re-authentication with refresh tokens disabled")

// ErrUnableToConstructValidator is returned when the validator cannot be constructed
ErrUnableToConstructValidator = errors.New("unable to construct validator")

// ErrPasswordTooWeak is returned when the password is too weak
ErrPasswordTooWeak = errors.New("password is too weak: use a combination of upper and lower case letters, numbers, and special characters")
// ErrCouldNotFetchSubscription is returned when the subscription could not be fetched
ErrCouldNotFetchSubscription = errors.New("could not fetch subscription")
)
34 changes: 33 additions & 1 deletion auth/test_tools.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (

"github.com/golang-jwt/jwt/v5"
echo "github.com/theopenlane/echox"

"github.com/theopenlane/echox/middleware/echocontext"
"github.com/theopenlane/utils/ulids"

"github.com/theopenlane/iam/tokens"
)
Expand Down Expand Up @@ -99,6 +99,7 @@ func NewTestEchoContextWithOrgID(sub, orgID string) (echo.Context, error) {
return ec, nil
}

// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down)
func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) {
ec, err := NewTestEchoContextWithOrgID(sub, orgID)
if err != nil {
Expand All @@ -111,3 +112,34 @@ func NewTestContextWithOrgID(sub, orgID string) (context.Context, error) {

return reqCtx, nil
}

// NewTestEchoContextWithOrgID creates an echo context with a fake orgID for testing purposes ONLY
func NewTestEchoContextWithSubscription(subscription bool) (echo.Context, error) {
ec := echocontext.NewTestEchoContext()

claims := newValidClaimsOrgID(ulids.New().String(), ulids.New().String())

SetAuthenticatedUserContext(ec, &AuthenticatedUser{
SubjectID: claims.UserID,
OrganizationID: claims.OrgID,
OrganizationIDs: []string{claims.OrgID},
AuthenticationType: "jwt",
ActiveSubscription: subscription,
})

return ec, nil
}

// NewTestContextWithOrgID creates a context with a fake orgID for testing purposes only (why all caps jeez keep it down)
func NewTestContextWithSubscription(subscription bool) (context.Context, error) {
ec, err := NewTestEchoContextWithSubscription(subscription)
if err != nil {
return nil, err
}

reqCtx := context.WithValue(ec.Request().Context(), echocontext.EchoContextKey, ec)

ec.SetRequest(ec.Request().WithContext(reqCtx))

return reqCtx, nil
}

0 comments on commit 161bcd3

Please sign in to comment.