Skip to content

Commit

Permalink
Merge pull request #148 from dekart-xyz/support-google-oauth-dev-email
Browse files Browse the repository at this point in the history
Support dev email for Google OAuth flow
  • Loading branch information
delfrrr authored Nov 26, 2023
2 parents 79e92a3 + 2742406 commit d97e708
Showing 1 changed file with 70 additions and 20 deletions.
90 changes: 70 additions & 20 deletions src/server/user/claims.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"net/url"
"strings"
"sync"
"time"

"github.com/golang-jwt/jwt"
"github.com/rs/zerolog/log"
Expand Down Expand Up @@ -55,32 +56,39 @@ type ClaimsCheck struct {

var b2i = map[bool]int{false: 0, true: 1}

// NewClaimsCheck creates Context
func NewClaimsCheck(c ClaimsCheckConfig) ClaimsCheck {
func validateConfig(c ClaimsCheckConfig) {
if b2i[c.RequireIAP]+b2i[c.RequireAmazonOIDC]+b2i[c.RequireGoogleOAuth] > 1 {
log.Fatal().Msg("DEKART_REQUIRE_IAP and DEKART_REQUIRE_AMAZON_OIDC and DEKART_REQUIRE_GOOGLE_OAUTH are mutually exclusive")
} else if c.RequireIAP {
}
switch {
case c.RequireIAP:
log.Info().Msgf("Dekart configured to require IAP")
} else if c.RequireAmazonOIDC {
case c.RequireAmazonOIDC:
log.Info().Msgf("Dekart configured to require Amazon OIDC")
if c.Region == "" {
log.Fatal().Msgf("Dekart AWS_REGION is required for OIDC")
}
} else if c.RequireGoogleOAuth {
if c.GoogleOAuthClientId == "" {
log.Fatal().Msgf("Dekart DEKART_GOOGLE_OAUTH_CLIENT_ID is required for Google OAuth")
}
if c.GoogleOAuthSecret == "" {
log.Fatal().Msgf("Dekart DEKART_GOOGLE_OAUTH_SECRET is required for Google OAuth")
case c.RequireGoogleOAuth:
if c.DevClaimsEmail == "" {
if c.GoogleOAuthClientId == "" {
log.Fatal().Msgf("Dekart DEKART_GOOGLE_OAUTH_CLIENT_ID is required for Google OAuth")
}
if c.GoogleOAuthSecret == "" {
log.Fatal().Msgf("Dekart DEKART_GOOGLE_OAUTH_SECRET is required for Google OAuth")
}
}
} else {
log.Info().Msgf("Dekart configured to require Google OAuth")
default:
log.Info().Msgf("All users can read/write all entities")
}

if c.DevClaimsEmail != "" {
log.Warn().Msgf("Use DEKART_DEV_CLAIMS_EMAIL only in development environment")
}
}

func NewClaimsCheck(c ClaimsCheckConfig) ClaimsCheck {
validateConfig(c)
return ClaimsCheck{
c,
&sync.Map{},
Expand All @@ -90,12 +98,19 @@ func NewClaimsCheck(c ClaimsCheckConfig) ClaimsCheck {
// UnknownEmail is set as claims email when auth is not required
const UnknownEmail = "UNKNOWN_EMAIL"

// validateToken receives Bearer token and receives user details
func (c ClaimsCheck) validateToken(ctx context.Context, header string) *Claims {
// validateAuthToken receives Bearer token and fetches user details
func (c ClaimsCheck) validateAuthToken(ctx context.Context, header string) *Claims {
if header == "" {
return nil
}

// check dev claims email after checking header to force redirect flow in dev mode
if c.DevClaimsEmail != "" {
return &Claims{
Email: c.DevClaimsEmail,
}
}

authHeaderParts := strings.Split(header, " ")
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
log.Warn().Msg("Invalid Authorization header format")
Expand Down Expand Up @@ -142,16 +157,12 @@ func GetTokenSource(ctx context.Context) oauth2.TokenSource {
func (c ClaimsCheck) GetContext(r *http.Request) context.Context {
ctx := r.Context()
var claims *Claims
if c.DevClaimsEmail != "" {
claims = &Claims{
Email: c.DevClaimsEmail,
}
} else if c.RequireIAP {
if c.RequireIAP {
claims = c.validateJWTFromAppEngine(ctx, r.Header.Get("X-Goog-IAP-JWT-Assertion"))
} else if c.RequireAmazonOIDC {
claims = c.validateJWTFromAmazonOIDC(ctx, r.Header.Get("x-amzn-oidc-data"))
} else if c.RequireGoogleOAuth {
claims = c.validateToken(ctx, r.Header.Get("Authorization"))
claims = c.validateAuthToken(ctx, r.Header.Get("Authorization"))
} else {
claims = &Claims{
Email: UnknownEmail,
Expand Down Expand Up @@ -240,6 +251,22 @@ func (c ClaimsCheck) requestToken(state *pb.AuthState, r *http.Request) *pb.Redi

const tokenRevokeURL = "https://oauth2.googleapis.com/revoke"

func (c ClaimsCheck) requestDevToken() *pb.RedirectState {
redirectState := &pb.RedirectState{}
token := &oauth2.Token{
AccessToken: "fake-access-token",
TokenType: "Bearer",
RefreshToken: "fake-refresh-token",
Expiry: time.Now().Add(time.Hour),
}
tokenBin, err := json.Marshal(*token)
if err != nil {
log.Fatal().Err(err).Msg("Error marshalling token")
}
redirectState.TokenJson = string(tokenBin)
return redirectState
}

// Authenticate redirects to Google OAuth
func (c ClaimsCheck) Authenticate(w http.ResponseWriter, r *http.Request) {
stateBase64 := r.URL.Query().Get("state")
Expand Down Expand Up @@ -268,6 +295,12 @@ func (c ClaimsCheck) Authenticate(w http.ResponseWriter, r *http.Request) {
return
}

if c.DevClaimsEmail != "" && state.Action == pb.AuthState_ACTION_REQUEST_CODE {
//skip request code from google
state.Action = pb.AuthState_ACTION_REQUEST_TOKEN
log.Debug().Msg("Skip request code from google, use dev token")
}

log.Debug().Msgf("Authenticate state action: %s", state.Action)
switch state.Action {
case pb.AuthState_ACTION_REQUEST_CODE: // request code from google
Expand All @@ -286,7 +319,13 @@ func (c ClaimsCheck) Authenticate(w http.ResponseWriter, r *http.Request) {
}
http.Redirect(w, r, url, http.StatusFound)
case pb.AuthState_ACTION_REQUEST_TOKEN: // exchange code for token and redirect to ui
redirectState := c.requestToken(&state, r)
var redirectState *pb.RedirectState
if c.DevClaimsEmail != "" {
//skip request token from google in dev mode
redirectState = c.requestDevToken()
} else {
redirectState = c.requestToken(&state, r)
}
redirectStateBin, err := proto.Marshal(redirectState)
if err != nil {
log.Fatal().Err(err).Msg("Error marshalling token")
Expand Down Expand Up @@ -361,6 +400,11 @@ func (c ClaimsCheck) getPublicKeyFromAmazon(token *jwt.Token) (interface{}, erro
// validateJWTFromAmazonOIDC parses and validates token from x-amzn-oidc-data
// see https://docs.aws.amazon.com/elasticloadbalancing/latest/application/listener-authenticate-users.html
func (c ClaimsCheck) validateJWTFromAmazonOIDC(ctx context.Context, header string) *Claims {
if c.DevClaimsEmail != "" {
return &Claims{
Email: c.DevClaimsEmail,
}
}
if header == "" {
return nil
}
Expand All @@ -383,6 +427,12 @@ func (c ClaimsCheck) validateJWTFromAmazonOIDC(ctx context.Context, header strin
// validateJWTFromAppEngine validates a JWT found in the
// "x-goog-iap-jwt-assertion" header.
func (c ClaimsCheck) validateJWTFromAppEngine(ctx context.Context, iapJWT string) *Claims {
if c.DevClaimsEmail != "" {
return &Claims{
Email: c.DevClaimsEmail,
}
}

payload, err := idtoken.Validate(ctx, iapJWT, c.Audience)
if err != nil {
log.Warn().Err(err).Msg("Error validating IAP JWT")
Expand Down

0 comments on commit d97e708

Please sign in to comment.