Skip to content

Commit

Permalink
rpc: more typed errors
Browse files Browse the repository at this point in the history
  • Loading branch information
patrislav committed Jul 11, 2024
1 parent 8fa73e4 commit 560cf58
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 32 deletions.
45 changes: 32 additions & 13 deletions rpc/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpc

import (
"context"
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -55,16 +56,16 @@ func (s *RPC) federateAccount(
tntData := tenant.FromContext(ctx)

if intent.Data.SessionID != sess.ID {
return nil, fmt.Errorf("sessionId mismatch")
return nil, proto.ErrWebrpcBadRequest.WithCausef("sessionId mismatch")
}

authProvider, err := s.getAuthProvider(intent.Data.IdentityType)
if err != nil {
return nil, fmt.Errorf("get auth provider: %w", err)
return nil, proto.ErrWebrpcBadRequest.WithCausef("get auth provider: %w", err)
}

if intent.Data.IdentityType == intents.IdentityType_Guest {
return nil, fmt.Errorf("cannot federate a guest account")
return nil, proto.ErrWebrpcBadRequest.WithCausef("cannot federate a guest account")
}

var verifCtx *proto.VerificationContext
Expand All @@ -75,34 +76,52 @@ func (s *RPC) federateAccount(
}
dbVerifCtx, found, err := s.VerificationContexts.Get(ctx, authID)
if err != nil {
return nil, fmt.Errorf("getting verification context: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("getting verification context: %w", err)
}
if found && dbVerifCtx != nil {
verifCtx, _, err = crypto.DecryptData[*proto.VerificationContext](ctx, dbVerifCtx.EncryptedKey, dbVerifCtx.Ciphertext, tntData.KMSKeys)
if err != nil {
return nil, fmt.Errorf("decrypting verification context data: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("decrypting verification context data: %w", err)
}

if time.Now().After(verifCtx.ExpiresAt) {
return nil, fmt.Errorf("auth session expired")
return nil, proto.ErrChallengeExpired
}

if !dbVerifCtx.CorrespondsTo(verifCtx) {
return nil, fmt.Errorf("malformed verification context data")
return nil, proto.ErrWebrpcInternalError.WithCausef("malformed verification context data")
}
}

ident, err := authProvider.Verify(ctx, verifCtx, sess.ID, intent.Data.Answer)
if err != nil {
return nil, fmt.Errorf("verifying identity: %w", err)
if verifCtx != nil {
now := time.Now()
verifCtx.Attempts += 1
verifCtx.LastAttemptAt = &now

encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], verifCtx)
if err != nil {
return nil, proto.ErrWebrpcInternalError.WithCausef("encrypt data: %w", err)
}
if err := s.VerificationContexts.UpdateData(ctx, dbVerifCtx, encryptedKey, algorithm, ciphertext); err != nil {
return nil, proto.ErrWebrpcInternalError.WithCausef("update verification context: %w", err)
}
}

var wErr proto.WebRPCError
if errors.As(err, &wErr) {
return nil, wErr
}
return nil, proto.ErrAnswerIncorrect.WithCausef("verifying answer: %w", err)
}

_, found, err = s.Accounts.Get(ctx, tntData.ProjectID, ident)
if err != nil {
return nil, fmt.Errorf("retrieving account: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("retrieving account: %w", err)
}
if found {
return nil, fmt.Errorf("account already exists")
return nil, proto.ErrAccountAlreadyLinked
}

accData := &proto.AccountData{
Expand All @@ -114,7 +133,7 @@ func (s *RPC) federateAccount(

encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], accData)
if err != nil {
return nil, fmt.Errorf("encrypting account data: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("encrypting account data: %w", err)
}

account := &data.Account{
Expand All @@ -130,11 +149,11 @@ func (s *RPC) federateAccount(
}

if _, err := s.Wallets.FederateAccount(waasapi.Context(ctx), account.UserID, waasapi.ConvertToAPIIntent(intent.ToIntent())); err != nil {
return nil, fmt.Errorf("creating account with WaaS API: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("creating account with WaaS API: %w", err)
}

if err := s.Accounts.Put(ctx, account); err != nil {
return nil, fmt.Errorf("save account: %w", err)
return nil, proto.ErrWebrpcInternalError.WithCausef("save account: %w", err)
}

outAcc := &intents.Account{
Expand Down
43 changes: 24 additions & 19 deletions rpc/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package rpc

import (
"context"
"errors"
"fmt"
"strings"
"time"
Expand All @@ -27,28 +28,28 @@ func (s *RPC) RegisterSession(

intent, sessionID, err := parseIntent(protoIntent)
if err != nil {
return nil, nil, fmt.Errorf("parse intent: %w", err)
return nil, nil, proto.ErrWebrpcBadRequest.WithCausef("parse intent: %w", err)
}

if intent.Name != intents.IntentName_openSession {
return nil, nil, fmt.Errorf("unexpected intent name: %q", intent.Name)
return nil, nil, proto.ErrWebrpcBadRequest.WithCausef("unexpected intent name: %q", intent.Name)
}

ctx, span := tracing.Intent(ctx, intent)
defer span.End()

intentTyped, err := intents.NewIntentTypedFromIntent[intents.IntentDataOpenSession](intent)
if err != nil {
return nil, nil, err
return nil, nil, proto.ErrWebrpcBadRequest.WithCause(err)
}

if sessionID != intentTyped.Data.SessionID {
return nil, nil, fmt.Errorf("signing session and session to register must match")
return nil, nil, proto.ErrWebrpcBadRequest.WithCausef("signing session and session to register must match")
}

authProvider, err := s.getAuthProvider(intentTyped.Data.IdentityType)
if err != nil {
return nil, nil, fmt.Errorf("get auth provider: %w", err)
return nil, nil, proto.ErrWebrpcBadRequest.WithCause(err)
}

sessionHash := ethcoder.Keccak256Hash([]byte(strings.ToLower(sessionID))).String()
Expand All @@ -65,20 +66,20 @@ func (s *RPC) RegisterSession(
}
dbVerifCtx, found, err := s.VerificationContexts.Get(ctx, authID)
if err != nil {
return nil, nil, fmt.Errorf("getting verification context: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("getting verification context: %w", err)
}
if found && dbVerifCtx != nil {
verifCtx, _, err = crypto.DecryptData[*proto.VerificationContext](ctx, dbVerifCtx.EncryptedKey, dbVerifCtx.Ciphertext, tntData.KMSKeys)
if err != nil {
return nil, nil, fmt.Errorf("decrypting verification context data: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("decrypting verification context data: %w", err)
}

if time.Now().After(verifCtx.ExpiresAt) {
return nil, nil, fmt.Errorf("verification context expired")
return nil, nil, proto.ErrChallengeExpired
}

if !dbVerifCtx.CorrespondsTo(verifCtx) {
return nil, nil, fmt.Errorf("malformed verification context data")
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("malformed verification context data")
}
}

Expand All @@ -91,29 +92,33 @@ func (s *RPC) RegisterSession(

encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], verifCtx)
if err != nil {
return nil, nil, fmt.Errorf("encrypt data: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("encrypt data: %w", err)
}
if err := s.VerificationContexts.UpdateData(ctx, dbVerifCtx, encryptedKey, algorithm, ciphertext); err != nil {
return nil, nil, fmt.Errorf("update verification context: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("update verification context: %w", err)
}
}

return nil, nil, fmt.Errorf("verifying answer: %w", err)
var wErr proto.WebRPCError
if errors.As(err, &wErr) {
return nil, nil, wErr
}
return nil, nil, proto.ErrAnswerIncorrect.WithCausef("verifying answer: %w", err)
}

// always use normalized email address
ident.Email = email.Normalize(ident.Email)

account, accountFound, err := s.Accounts.Get(ctx, tntData.ProjectID, ident)
if err != nil {
return nil, nil, fmt.Errorf("failed to retrieve account: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("failed to retrieve account: %w", err)
}

if !accountFound {
if !intentTyped.Data.ForceCreateAccount && ident.Email != "" {
accs, err := s.Accounts.ListByEmail(ctx, tntData.ProjectID, ident.Email)
if err != nil {
return nil, nil, fmt.Errorf("failed to perform email check: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("failed to perform email check: %w", err)
}
if len(accs) > 0 {
cause := string(accs[0].Identity.Type) + "|" + accs[0].Email
Expand All @@ -132,7 +137,7 @@ func (s *RPC) RegisterSession(
}
encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], accData)
if err != nil {
return nil, nil, fmt.Errorf("encrypting account data: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("encrypting account data: %w", err)
}

account = &data.Account{
Expand All @@ -150,12 +155,12 @@ func (s *RPC) RegisterSession(

res, err := s.Wallets.RegisterSession(waasapi.Context(ctx), account.UserID, waasapi.ConvertToAPIIntent(intent))
if err != nil {
return nil, nil, fmt.Errorf("registering session with WaaS API: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("registering session with WaaS API: %w", err)
}

if !accountFound {
if err := s.Accounts.Put(ctx, account); err != nil {
return nil, nil, fmt.Errorf("save account: %w", err)
return nil, nil, proto.ErrWebrpcInternalError.WithCausef("save account: %w", err)
}
}

Expand All @@ -171,7 +176,7 @@ func (s *RPC) RegisterSession(

encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], sessData)
if err != nil {
return nil, convertIntentResponse(res), fmt.Errorf("encrypting session data: %w", err)
return nil, convertIntentResponse(res), proto.ErrWebrpcInternalError.WithCausef("encrypting session data: %w", err)
}

dbSess := &data.Session{
Expand All @@ -188,7 +193,7 @@ func (s *RPC) RegisterSession(
}

if err := s.Sessions.Put(ctx, dbSess); err != nil {
return nil, convertIntentResponse(res), fmt.Errorf("save session: %w", err)
return nil, convertIntentResponse(res), proto.ErrWebrpcInternalError.WithCausef("save session: %w", err)
}

retSess := &proto.Session{
Expand Down

0 comments on commit 560cf58

Please sign in to comment.