From e89643989866cfa4dbc20bfd3b1ff6b358410ad5 Mon Sep 17 00:00:00 2001 From: Patryk Kalinowski Date: Fri, 31 May 2024 13:05:47 +0200 Subject: [PATCH] WIP: commit/reveal OIDC flow --- rpc/auth/oidc/provider.go | 214 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 214 insertions(+) create mode 100644 rpc/auth/oidc/provider.go diff --git a/rpc/auth/oidc/provider.go b/rpc/auth/oidc/provider.go new file mode 100644 index 00000000..11cac5e2 --- /dev/null +++ b/rpc/auth/oidc/provider.go @@ -0,0 +1,214 @@ +package oidc + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/0xsequence/ethkit/go-ethereum/common/hexutil" + ethcrypto "github.com/0xsequence/ethkit/go-ethereum/crypto" + "github.com/0xsequence/go-sequence/intents" + "github.com/0xsequence/waas-authenticator/proto" + "github.com/0xsequence/waas-authenticator/rpc/auth" + "github.com/0xsequence/waas-authenticator/rpc/tenant" + "github.com/0xsequence/waas-authenticator/rpc/tracing" + "github.com/goware/cachestore" + "github.com/goware/cachestore/cachestorectl" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "golang.org/x/sync/errgroup" +) + +type AuthProvider struct { + client HTTPClient + store cachestore.Store[jwk.Key] +} + +func NewAuthProvider(cacheBackend cachestore.Backend, client HTTPClient) (auth.Provider, error) { + if client == nil { + client = http.DefaultClient + } + store, err := cachestorectl.Open[jwk.Key](cacheBackend) + if err != nil { + return nil, err + } + return &AuthProvider{ + client: client, + store: store, + }, nil +} + +func (v *AuthProvider) InitiateAuth( + ctx context.Context, + verifCtx *proto.VerificationContext, + verifier string, + intent *intents.Intent, + storeFn auth.StoreVerificationContextFn, +) (*intents.IntentResponseAuthInitiated, error) { + tnt := tenant.FromContext(ctx) + + if verifCtx != nil { + return nil, fmt.Errorf("cannot reuse an old ID token") + } + + _, expiresAt, err := v.extractVerifier(verifier) + if err != nil { + return nil, err + } + + if time.Now().After(expiresAt) { + return nil, fmt.Errorf("token expired") + } + + verifCtx = &proto.VerificationContext{ + ProjectID: tnt.ProjectID, + SessionID: intent.Signers()[0], + IdentityType: proto.IdentityType_OIDC, + Verifier: verifier, + ExpiresAt: expiresAt, + } + + if err := storeFn(ctx, verifCtx); err != nil { + return nil, err + } + + res := &intents.IntentResponseAuthInitiated{ + SessionID: verifCtx.SessionID, + IdentityType: intents.IdentityType_Email, + ExpiresIn: int(verifCtx.ExpiresAt.Sub(time.Now()).Seconds()), + } + return res, nil +} + +func (v *AuthProvider) Verify(ctx context.Context, verifCtx *proto.VerificationContext, sessionID string, answer string) (ident proto.Identity, err error) { + ctx, span := tracing.Span(ctx, "AuthProvider.Verify") + defer func() { + if err != nil { + span.RecordError(err) + } + span.End() + }() + + if verifCtx == nil { + return proto.Identity{}, fmt.Errorf("auth session not found") + } + + tokHash, expiresAt, err := v.extractVerifier(verifCtx.Verifier) + if err != nil { + return proto.Identity{}, err + } + + tok, err := jwt.Parse([]byte(answer), jwt.WithVerify(false), jwt.WithValidate(false)) + if err != nil { + return proto.Identity{}, fmt.Errorf("parse JWT: %w", err) + } + + issuer := normalizeIssuer(tok.Issuer()) + idp := getOIDCProvider(ctx, issuer) + if idp == nil { + return proto.Identity{}, fmt.Errorf("issuer %q not valid for this tenant", issuer) + } + + expectedHash := hexutil.Encode(ethcrypto.Keccak256([]byte(answer))) + if tokHash != expectedHash { + return proto.Identity{}, fmt.Errorf("invalid token hash") + } + + if !tok.Expiration().Equal(expiresAt) { + return proto.Identity{}, fmt.Errorf("invalid exp claim") + } + + ks := &operationKeySet{ + ctx: ctx, + iss: issuer, + store: v.store, + getKeySet: v.GetKeySet, + } + + if _, err := jws.Verify([]byte(answer), jws.WithKeySet(ks, jws.WithMultipleKeysPerKeyID(false))); err != nil { + return proto.Identity{}, fmt.Errorf("signature verification: %w", err) + } + + validateOptions := []jwt.ValidateOption{ + jwt.WithValidator(withIssuer(idp.Issuer)), + jwt.WithAcceptableSkew(10 * time.Second), + jwt.WithValidator(withAudience(idp.Audience)), + } + + if err := jwt.Validate(tok, validateOptions...); err != nil { + return proto.Identity{}, fmt.Errorf("JWT validation: %w", err) + } + + identity := proto.Identity{ + Type: proto.IdentityType_OIDC, + Issuer: issuer, + Subject: tok.Subject(), + Email: getEmailFromToken(tok), + } + return identity, nil +} + +func (v *AuthProvider) ValidateTenant(ctx context.Context, tenant *proto.TenantData) error { + var wg errgroup.Group + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + for i, provider := range tenant.OIDCProviders { + provider := provider + + if provider.Issuer == "" { + return fmt.Errorf("provider %d: empty issuer", i) + } + + if len(provider.Audience) < 1 { + return fmt.Errorf("provider %d: at least one audience is required", i) + } + + wg.Go(func() error { + if _, err := v.GetKeySet(ctx, provider.Issuer); err != nil { + return err + } + return nil + }) + } + + return wg.Wait() +} + +func (v *AuthProvider) GetKeySet(ctx context.Context, issuer string) (set jwk.Set, err error) { + ctx, span := tracing.Span(ctx, "AuthProvider.GetKeySet") + defer func() { + if err != nil { + span.RecordError(err) + } + span.End() + }() + + jwksURL, err := fetchJWKSURL(ctx, v.client, issuer) + if err != nil { + return nil, fmt.Errorf("fetch issuer keys: %w", err) + } + + keySet, err := jwk.Fetch(ctx, jwksURL, jwk.WithHTTPClient(tracing.WrapClientWithContext(ctx, v.client))) + if err != nil { + return nil, fmt.Errorf("fetch issuer keys: %w", err) + } + return keySet, nil +} + +func (v *AuthProvider) extractVerifier(verifier string) (tokHash string, expiresAt time.Time, err error) { + parts := strings.SplitN(verifier, ";", 2) + + tokHash = parts[0] + exp, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return "", time.Time{}, fmt.Errorf("parse exp: %w", err) + } + expiresAt = time.Unix(exp, 0) + + return tokHash, expiresAt, nil +}