diff --git a/config/config.go b/config/config.go index 690418a6..08691f16 100644 --- a/config/config.go +++ b/config/config.go @@ -9,18 +9,19 @@ import ( ) type Config struct { - Mode Mode `toml:"-"` - Region string `toml:"region"` - Service ServiceConfig `toml:"service"` - Admin AdminConfig `toml:"admin"` - Endpoints EndpointsConfig `toml:"endpoints"` - KMS KMSConfig `toml:"kms"` - SES SESConfig `toml:"ses"` - Builder BuilderConfig `toml:"builder"` - Database DatabaseConfig `toml:"database"` - Signing SigningConfig `toml:"signing"` - Telemetry telemetry.Config `toml:"telemetry"` - Tracing TracingConfig `toml:"tracing"` + Mode Mode `toml:"-"` + Region string `toml:"region"` + Service ServiceConfig `toml:"service"` + Admin AdminConfig `toml:"admin"` + Endpoints EndpointsConfig `toml:"endpoints"` + KMS KMSConfig `toml:"kms"` + SES SESConfig `toml:"ses"` + Builder BuilderConfig `toml:"builder"` + Database DatabaseConfig `toml:"database"` + Signing SigningConfig `toml:"signing"` + Telemetry telemetry.Config `toml:"telemetry"` + Tracing TracingConfig `toml:"tracing"` + Migrations MigrationsConfig `toml:"migrations"` } type AdminConfig struct { diff --git a/config/migrations.go b/config/migrations.go new file mode 100644 index 00000000..e743fde2 --- /dev/null +++ b/config/migrations.go @@ -0,0 +1,11 @@ +package config + +type MigrationsConfig struct { + OIDCToStytch []OIDCToStytchConfig `toml:"oidc_to_stytch"` +} + +type OIDCToStytchConfig struct { + SequenceProject uint64 `toml:"sequence_project"` + StytchProject string `toml:"stytch_project"` + FromIssuer string `toml:"from_issuer"` +} diff --git a/data/account.go b/data/account.go index 29601f4f..cafcbcfa 100644 --- a/data/account.go +++ b/data/account.go @@ -63,6 +63,28 @@ func NewAccountTable(db DB, tableARN string, indices AccountIndices) *AccountTab } } +// Create creates a new Account or fails if it already exists. +func (t *AccountTable) Create(ctx context.Context, acct *Account) error { + acct.CreatedAt = time.Now() + + av, err := attributevalue.MarshalMap(acct) + if err != nil { + return fmt.Errorf("marshal input: %w", err) + } + input := &dynamodb.PutItemInput{ + TableName: &t.tableARN, + Item: av, + ConditionExpression: aws.String("attribute_not_exists(#I)"), + ExpressionAttributeNames: map[string]string{ + "#I": "Identity", + }, + } + if _, err := t.db.PutItem(ctx, input); err != nil { + return fmt.Errorf("PutItem: %w", err) + } + return nil +} + // Put updates an Account by ProjectID or creates one if it doesn't exist yet. func (t *AccountTable) Put(ctx context.Context, acct *Account) error { acct.CreatedAt = time.Now() @@ -175,3 +197,83 @@ func (t *AccountTable) Delete(ctx context.Context, projectID uint64, identity pr } return nil } + +func (t *AccountTable) ListByProjectAndIdentity(ctx context.Context, page Page, projectID uint64, identityType proto.IdentityType, issuer string) ([]*Account, Page, error) { + if page.Limit <= 0 { + page.Limit = 25 + } + if page.Limit > 100 { + page.Limit = 100 + } + + input := &dynamodb.QueryInput{ + TableName: &t.tableARN, + KeyConditionExpression: aws.String("#P = :projectID"), + ExpressionAttributeNames: map[string]string{ + "#P": "ProjectID", + }, + ExpressionAttributeValues: map[string]types.AttributeValue{ + ":projectID": &types.AttributeValueMemberN{Value: fmt.Sprintf("%d", projectID)}, + }, + Limit: &page.Limit, + ExclusiveStartKey: page.NextKey, + } + + var identCond string + if identityType != proto.IdentityType_None { + identCond = string(identityType) + ":" + if issuer != "" { + identCond += issuer + "#" + } + + *input.KeyConditionExpression += " and begins_with(#I, :identCond)" + input.ExpressionAttributeNames["#I"] = "Identity" + input.ExpressionAttributeValues[":identCond"] = &types.AttributeValueMemberS{Value: identCond} + } + + out, err := t.db.Query(ctx, input) + if err != nil { + return nil, page, fmt.Errorf("Query: %w", err) + } + + accounts := make([]*Account, len(out.Items)) + for i, item := range out.Items { + if err := attributevalue.UnmarshalMap(item, &accounts[i]); err != nil { + return nil, page, fmt.Errorf("unmarshal result: %w", err) + } + } + + page.NextKey = out.LastEvaluatedKey + return accounts, page, nil +} + +func (t *AccountTable) GetBatch(ctx context.Context, projectID uint64, identities []proto.Identity) ([]*Account, error) { + keys := make([]map[string]types.AttributeValue, len(identities)) + for i, identity := range identities { + acct := Account{ProjectID: projectID, Identity: Identity(identity)} + keys[i] = acct.Key() + } + + input := &dynamodb.BatchGetItemInput{ + RequestItems: map[string]types.KeysAndAttributes{ + t.tableARN: {Keys: keys}, + }, + } + + out, err := t.db.BatchGetItem(ctx, input) + if err != nil { + return nil, fmt.Errorf("BatchGetItem: %w", err) + } + + for _, results := range out.Responses { + accounts := make([]*Account, len(results)) + for i, item := range results { + if err := attributevalue.UnmarshalMap(item, &accounts[i]); err != nil { + return nil, fmt.Errorf("unmarshal result: %w", err) + } + } + return accounts, nil + } + + return make([]*Account, 0), nil +} diff --git a/data/interfaces.go b/data/interfaces.go index d05b4fa8..2ec58af4 100644 --- a/data/interfaces.go +++ b/data/interfaces.go @@ -9,6 +9,7 @@ import ( // DB is an abstraction over *dynamodb.Client defining the methods that we need for DynamoDB access type DB interface { GetItem(ctx context.Context, params *dynamodb.GetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.GetItemOutput, error) + BatchGetItem(ctx context.Context, params *dynamodb.BatchGetItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.BatchGetItemOutput, error) PutItem(ctx context.Context, params *dynamodb.PutItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.PutItemOutput, error) DeleteItem(ctx context.Context, params *dynamodb.DeleteItemInput, optFns ...func(*dynamodb.Options)) (*dynamodb.DeleteItemOutput, error) Query(ctx context.Context, params *dynamodb.QueryInput, optFns ...func(*dynamodb.Options)) (*dynamodb.QueryOutput, error) diff --git a/data/page.go b/data/page.go new file mode 100644 index 00000000..ed234be0 --- /dev/null +++ b/data/page.go @@ -0,0 +1,61 @@ +package data + +import ( + "encoding/base64" + "encoding/json" + + "github.com/0xsequence/waas-authenticator/proto" + "github.com/aws/aws-sdk-go-v2/feature/dynamodb/attributevalue" + "github.com/aws/aws-sdk-go-v2/service/dynamodb/types" +) + +type Page struct { + NextKey map[string]types.AttributeValue + Limit int32 +} + +func PageFromProto(protoPage *proto.Page) (Page, error) { + page := Page{ + Limit: 25, + } + + if protoPage != nil { + if protoPage.Limit > 0 { + page.Limit = int32(protoPage.Limit) + } + + if protoPage.After != "" { + nextKeyString, err := base64.StdEncoding.DecodeString(protoPage.After) + if err != nil { + return page, err + } + var nextKeyMap map[string]any + if err := json.Unmarshal(nextKeyString, &nextKeyMap); err != nil { + return page, err + } + avMap, err := attributevalue.MarshalMap(nextKeyMap) + if err != nil { + return page, err + } + page.NextKey = avMap + } + } + + return page, nil +} + +func (p *Page) ToProto() (*proto.Page, error) { + protoPage := &proto.Page{Limit: uint32(p.Limit)} + if p.NextKey != nil { + nextKeyMap := make(map[string]any, len(p.NextKey)) + if err := attributevalue.UnmarshalMap(p.NextKey, nextKeyMap); err != nil { + return nil, err + } + b, err := json.Marshal(nextKeyMap) + if err != nil { + return nil, err + } + protoPage.After = base64.StdEncoding.EncodeToString(b) + } + return protoPage, nil +} diff --git a/etc/waas-auth.dev.conf b/etc/waas-auth.dev.conf index 2aa2fbb8..02f9c8dd 100644 --- a/etc/waas-auth.dev.conf +++ b/etc/waas-auth.dev.conf @@ -53,3 +53,8 @@ QwIDAQAB [signing] issuer = "https://dev-waas.sequence.app" audience_prefix = "https://dev.sequence.build/project/" + +[[migrations.oidc_to_stytch]] + sequence_project = 694 + stytch_project = "project-test-c6241c64-de15-412a-a843-09966c98de57" + from_issuer = "https://oidc-wrapper.sequence.info" diff --git a/proto/authenticator.gen.go b/proto/authenticator.gen.go index ba1f2da8..f9dc9842 100644 --- a/proto/authenticator.gen.go +++ b/proto/authenticator.gen.go @@ -1,4 +1,4 @@ -// sequence-waas-authenticator v0.1.0 35f86317a98af91896d1114ad52dd22102d9de9f +// sequence-waas-authenticator v0.1.0 2434bf308eeece8d32c65c08f787c7d152d5d199 // -- // Code generated by webrpc-gen@v0.18.8 with golang generator. DO NOT EDIT. // @@ -34,7 +34,7 @@ func WebRPCSchemaVersion() string { // Schema hash generated from your RIDL schema func WebRPCSchemaHash() string { - return "35f86317a98af91896d1114ad52dd22102d9de9f" + return "2434bf308eeece8d32c65c08f787c7d152d5d199" } // @@ -177,6 +177,33 @@ type IntentResponse struct { Data interface{} `json:"data"` } +type Migration string + +const ( + Migration_OIDCToStytch Migration = "OIDCToStytch" +) + +func (x Migration) MarshalText() ([]byte, error) { + return []byte(x), nil +} + +func (x *Migration) UnmarshalText(b []byte) error { + *x = Migration(string(b)) + return nil +} + +func (x *Migration) Is(values ...Migration) bool { + if x == nil { + return false + } + for _, v := range values { + if *x == v { + return true + } + } + return false +} + type Version struct { WebrpcVersion string `json:"webrpcVersion"` SchemaVersion string `json:"schemaVersion"` @@ -303,6 +330,11 @@ type VerificationContext struct { ExpiresAt time.Time `json:"expiresAt"` } +type Page struct { + Limit uint32 `json:"limit,omitempty"` + After string `json:"after,omitempty"` +} + var WebRPCServices = map[string][]string{ "WaasAuthenticator": { "RegisterSession", @@ -316,6 +348,8 @@ var WebRPCServices = map[string][]string{ "GetTenant", "CreateTenant", "UpdateTenant", + "NextMigrationBatch", + "ProcessMigrationBatch", }, } @@ -336,6 +370,8 @@ type WaasAuthenticatorAdmin interface { GetTenant(ctx context.Context, projectId uint64) (*Tenant, error) CreateTenant(ctx context.Context, projectId uint64, waasAccessToken string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string, password *string) (*Tenant, string, error) UpdateTenant(ctx context.Context, projectId uint64, upgradeCode string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string) (*Tenant, error) + NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) + ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) } // @@ -355,6 +391,8 @@ type WaasAuthenticatorAdminClient interface { GetTenant(ctx context.Context, projectId uint64) (*Tenant, error) CreateTenant(ctx context.Context, projectId uint64, waasAccessToken string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string, password *string) (*Tenant, string, error) UpdateTenant(ctx context.Context, projectId uint64, upgradeCode string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string) (*Tenant, error) + NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) + ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) } // @@ -593,6 +631,10 @@ func (s *waasAuthenticatorAdminServer) ServeHTTP(w http.ResponseWriter, r *http. handler = s.serveCreateTenantJSON case "/rpc/WaasAuthenticatorAdmin/UpdateTenant": handler = s.serveUpdateTenantJSON + case "/rpc/WaasAuthenticatorAdmin/NextMigrationBatch": + handler = s.serveNextMigrationBatchJSON + case "/rpc/WaasAuthenticatorAdmin/ProcessMigrationBatch": + handler = s.serveProcessMigrationBatchJSON default: err := ErrWebrpcBadRoute.WithCause(fmt.Errorf("no handler for path %q", r.URL.Path)) s.sendErrorJSON(w, r, err) @@ -844,6 +886,98 @@ func (s *waasAuthenticatorAdminServer) serveUpdateTenantJSON(ctx context.Context w.Write(respBody) } +func (s *waasAuthenticatorAdminServer) serveNextMigrationBatchJSON(ctx context.Context, w http.ResponseWriter, r *http.Request) { + ctx = context.WithValue(ctx, MethodNameCtxKey, "NextMigrationBatch") + + reqBody, err := io.ReadAll(r.Body) + if err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to read request data: %w", err))) + return + } + defer r.Body.Close() + + reqPayload := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 *Page `json:"page"` + }{} + if err := json.Unmarshal(reqBody, &reqPayload); err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to unmarshal request data: %w", err))) + return + } + + // Call service method implementation. + ret0, ret1, err := s.WaasAuthenticatorAdmin.NextMigrationBatch(ctx, reqPayload.Arg0, reqPayload.Arg1, reqPayload.Arg2) + if err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + s.sendErrorJSON(w, r, rpcErr) + return + } + + respPayload := struct { + Ret0 *Page `json:"page"` + Ret1 []string `json:"items"` + }{ret0, ret1} + respBody, err := json.Marshal(respPayload) + if err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to marshal json response: %w", err))) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(respBody) +} + +func (s *waasAuthenticatorAdminServer) serveProcessMigrationBatchJSON(ctx context.Context, w http.ResponseWriter, r *http.Request) { + ctx = context.WithValue(ctx, MethodNameCtxKey, "ProcessMigrationBatch") + + reqBody, err := io.ReadAll(r.Body) + if err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to read request data: %w", err))) + return + } + defer r.Body.Close() + + reqPayload := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 []string `json:"items"` + }{} + if err := json.Unmarshal(reqBody, &reqPayload); err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadRequest.WithCause(fmt.Errorf("failed to unmarshal request data: %w", err))) + return + } + + // Call service method implementation. + ret0, ret1, err := s.WaasAuthenticatorAdmin.ProcessMigrationBatch(ctx, reqPayload.Arg0, reqPayload.Arg1, reqPayload.Arg2) + if err != nil { + rpcErr, ok := err.(WebRPCError) + if !ok { + rpcErr = ErrWebrpcEndpoint.WithCause(err) + } + s.sendErrorJSON(w, r, rpcErr) + return + } + + respPayload := struct { + Ret0 map[string][]string `json:"logs"` + Ret1 map[string]string `json:"errors"` + }{ret0, ret1} + respBody, err := json.Marshal(respPayload) + if err != nil { + s.sendErrorJSON(w, r, ErrWebrpcBadResponse.WithCause(fmt.Errorf("failed to marshal json response: %w", err))) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(respBody) +} + func (s *waasAuthenticatorAdminServer) sendErrorJSON(w http.ResponseWriter, r *http.Request, rpcErr WebRPCError) { if s.OnError != nil { s.OnError(r, &rpcErr) @@ -952,18 +1086,20 @@ func (c *waasAuthenticatorClient) ChainList(ctx context.Context) ([]*Chain, erro type waasAuthenticatorAdminClient struct { client HTTPClient - urls [6]string + urls [8]string } func NewWaasAuthenticatorAdminClient(addr string, client HTTPClient) WaasAuthenticatorAdminClient { prefix := urlBase(addr) + WaasAuthenticatorAdminPathPrefix - urls := [6]string{ + urls := [8]string{ prefix + "Version", prefix + "RuntimeStatus", prefix + "Clock", prefix + "GetTenant", prefix + "CreateTenant", prefix + "UpdateTenant", + prefix + "NextMigrationBatch", + prefix + "ProcessMigrationBatch", } return &waasAuthenticatorAdminClient{ client: client, @@ -1086,6 +1222,50 @@ func (c *waasAuthenticatorAdminClient) UpdateTenant(ctx context.Context, project return out.Ret0, err } +func (c *waasAuthenticatorAdminClient) NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) { + in := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 *Page `json:"page"` + }{migration, projectId, page} + out := struct { + Ret0 *Page `json:"page"` + Ret1 []string `json:"items"` + }{} + + resp, err := doHTTPRequest(ctx, c.client, c.urls[6], in, &out) + if resp != nil { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) + } + } + + return out.Ret0, out.Ret1, err +} + +func (c *waasAuthenticatorAdminClient) ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) { + in := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 []string `json:"items"` + }{migration, projectId, items} + out := struct { + Ret0 map[string][]string `json:"logs"` + Ret1 map[string]string `json:"errors"` + }{} + + resp, err := doHTTPRequest(ctx, c.client, c.urls[7], in, &out) + if resp != nil { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) + } + } + + return out.Ret0, out.Ret1, err +} + // HTTPClient is the interface used by generated clients to send HTTP requests. // It is fulfilled by *(net/http).Client, which is sufficient for most users. // Users can provide their own implementation for special retry policies. diff --git a/proto/authenticator.ridl b/proto/authenticator.ridl index f9faf3a1..6a6b1ded 100644 --- a/proto/authenticator.ridl +++ b/proto/authenticator.ridl @@ -189,6 +189,19 @@ struct VerificationContext - expiresAt: timestamp +struct Page + - limit?: uint32 + + go.field.type = uint32 + + go.tag.json = limit,omitempty + - after?: string + + go.field.type = string + + go.tag.json = after,omitempty + + +enum Migration: string + - OIDCToStytch + + ## ## Errors ## @@ -221,3 +234,6 @@ service WaasAuthenticatorAdmin - GetTenant(projectId: uint64) => (tenant: Tenant) - CreateTenant(projectId: uint64, waasAccessToken: string, authConfig: AuthConfig, oidcProviders: []OpenIdProvider, allowedOrigins: []string, password?: string) => (tenant: Tenant, upgradeCode: string) - UpdateTenant(projectId: uint64, upgradeCode: string, authConfig: AuthConfig, oidcProviders: []OpenIdProvider, allowedOrigins: []string) => (tenant: Tenant) + + - NextMigrationBatch(migration: Migration, projectId: uint64, page: Page) => (page: Page, items: []string) + - ProcessMigrationBatch(migration: Migration, projectId: uint64, items: []string) => (logs: map, errors: map) diff --git a/proto/clients/authenticator.gen.go b/proto/clients/authenticator.gen.go index 35d535cd..9428db37 100644 --- a/proto/clients/authenticator.gen.go +++ b/proto/clients/authenticator.gen.go @@ -1,4 +1,4 @@ -// sequence-waas-authenticator v0.1.0 35f86317a98af91896d1114ad52dd22102d9de9f +// sequence-waas-authenticator v0.1.0 2434bf308eeece8d32c65c08f787c7d152d5d199 // -- // Code generated by webrpc-gen@v0.18.8 with golang generator. DO NOT EDIT. // @@ -33,7 +33,7 @@ func WebRPCSchemaVersion() string { // Schema hash generated from your RIDL schema func WebRPCSchemaHash() string { - return "35f86317a98af91896d1114ad52dd22102d9de9f" + return "2434bf308eeece8d32c65c08f787c7d152d5d199" } // @@ -176,6 +176,33 @@ type IntentResponse struct { Data interface{} `json:"data"` } +type Migration string + +const ( + Migration_OIDCToStytch Migration = "OIDCToStytch" +) + +func (x Migration) MarshalText() ([]byte, error) { + return []byte(x), nil +} + +func (x *Migration) UnmarshalText(b []byte) error { + *x = Migration(string(b)) + return nil +} + +func (x *Migration) Is(values ...Migration) bool { + if x == nil { + return false + } + for _, v := range values { + if *x == v { + return true + } + } + return false +} + type Version struct { WebrpcVersion string `json:"webrpcVersion"` SchemaVersion string `json:"schemaVersion"` @@ -302,6 +329,11 @@ type VerificationContext struct { ExpiresAt time.Time `json:"expiresAt"` } +type Page struct { + Limit uint32 `json:"limit,omitempty"` + After string `json:"after,omitempty"` +} + var WebRPCServices = map[string][]string{ "WaasAuthenticator": { "RegisterSession", @@ -315,6 +347,8 @@ var WebRPCServices = map[string][]string{ "GetTenant", "CreateTenant", "UpdateTenant", + "NextMigrationBatch", + "ProcessMigrationBatch", }, } @@ -335,6 +369,8 @@ type WaasAuthenticatorAdmin interface { GetTenant(ctx context.Context, projectId uint64) (*Tenant, error) CreateTenant(ctx context.Context, projectId uint64, waasAccessToken string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string, password *string) (*Tenant, string, error) UpdateTenant(ctx context.Context, projectId uint64, upgradeCode string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string) (*Tenant, error) + NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) + ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) } // @@ -354,6 +390,8 @@ type WaasAuthenticatorAdminClient interface { GetTenant(ctx context.Context, projectId uint64) (*Tenant, error) CreateTenant(ctx context.Context, projectId uint64, waasAccessToken string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string, password *string) (*Tenant, string, error) UpdateTenant(ctx context.Context, projectId uint64, upgradeCode string, authConfig *AuthConfig, oidcProviders []*OpenIdProvider, allowedOrigins []string) (*Tenant, error) + NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) + ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) } // @@ -439,18 +477,20 @@ func (c *waasAuthenticatorClient) ChainList(ctx context.Context) ([]*Chain, erro type waasAuthenticatorAdminClient struct { client HTTPClient - urls [6]string + urls [8]string } func NewWaasAuthenticatorAdminClient(addr string, client HTTPClient) WaasAuthenticatorAdminClient { prefix := urlBase(addr) + WaasAuthenticatorAdminPathPrefix - urls := [6]string{ + urls := [8]string{ prefix + "Version", prefix + "RuntimeStatus", prefix + "Clock", prefix + "GetTenant", prefix + "CreateTenant", prefix + "UpdateTenant", + prefix + "NextMigrationBatch", + prefix + "ProcessMigrationBatch", } return &waasAuthenticatorAdminClient{ client: client, @@ -573,6 +613,50 @@ func (c *waasAuthenticatorAdminClient) UpdateTenant(ctx context.Context, project return out.Ret0, err } +func (c *waasAuthenticatorAdminClient) NextMigrationBatch(ctx context.Context, migration Migration, projectId uint64, page *Page) (*Page, []string, error) { + in := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 *Page `json:"page"` + }{migration, projectId, page} + out := struct { + Ret0 *Page `json:"page"` + Ret1 []string `json:"items"` + }{} + + resp, err := doHTTPRequest(ctx, c.client, c.urls[6], in, &out) + if resp != nil { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) + } + } + + return out.Ret0, out.Ret1, err +} + +func (c *waasAuthenticatorAdminClient) ProcessMigrationBatch(ctx context.Context, migration Migration, projectId uint64, items []string) (map[string][]string, map[string]string, error) { + in := struct { + Arg0 Migration `json:"migration"` + Arg1 uint64 `json:"projectId"` + Arg2 []string `json:"items"` + }{migration, projectId, items} + out := struct { + Ret0 map[string][]string `json:"logs"` + Ret1 map[string]string `json:"errors"` + }{} + + resp, err := doHTTPRequest(ctx, c.client, c.urls[7], in, &out) + if resp != nil { + cerr := resp.Body.Close() + if err == nil && cerr != nil { + err = ErrWebrpcRequestFailed.WithCause(fmt.Errorf("failed to close response body: %w", cerr)) + } + } + + return out.Ret0, out.Ret1, err +} + // HTTPClient is the interface used by generated clients to send HTTP requests. // It is fulfilled by *(net/http).Client, which is sufficient for most users. // Users can provide their own implementation for special retry policies. diff --git a/proto/clients/authenticator.gen.ts b/proto/clients/authenticator.gen.ts index fba2ae43..31f49012 100644 --- a/proto/clients/authenticator.gen.ts +++ b/proto/clients/authenticator.gen.ts @@ -1,5 +1,5 @@ /* eslint-disable */ -// sequence-waas-authenticator v0.1.0 35f86317a98af91896d1114ad52dd22102d9de9f +// sequence-waas-authenticator v0.1.0 2434bf308eeece8d32c65c08f787c7d152d5d199 // -- // Code generated by webrpc-gen@v0.18.8 with typescript generator. DO NOT EDIT. // @@ -12,7 +12,7 @@ export const WebRPCVersion = "v1" export const WebRPCSchemaVersion = "v0.1.0" // Schema hash generated from your RIDL schema -export const WebRPCSchemaHash = "35f86317a98af91896d1114ad52dd22102d9de9f" +export const WebRPCSchemaHash = "2434bf308eeece8d32c65c08f787c7d152d5d199" // // Types @@ -86,6 +86,10 @@ export interface IntentResponse { data: any } +export enum Migration { + OIDCToStytch = 'OIDCToStytch' +} + export interface Version { webrpcVersion: string schemaVersion: string @@ -211,6 +215,11 @@ export interface VerificationContext { expiresAt: string } +export interface Page { + limit?: number + after?: string +} + export interface WaasAuthenticator { registerSession(args: RegisterSessionArgs, headers?: object, signal?: AbortSignal): Promise sendIntent(args: SendIntentArgs, headers?: object, signal?: AbortSignal): Promise @@ -247,6 +256,8 @@ export interface WaasAuthenticatorAdmin { getTenant(args: GetTenantArgs, headers?: object, signal?: AbortSignal): Promise createTenant(args: CreateTenantArgs, headers?: object, signal?: AbortSignal): Promise updateTenant(args: UpdateTenantArgs, headers?: object, signal?: AbortSignal): Promise + nextMigrationBatch(args: NextMigrationBatchArgs, headers?: object, signal?: AbortSignal): Promise + processMigrationBatch(args: ProcessMigrationBatchArgs, headers?: object, signal?: AbortSignal): Promise } export interface VersionArgs { @@ -298,6 +309,26 @@ export interface UpdateTenantArgs { export interface UpdateTenantReturn { tenant: Tenant } +export interface NextMigrationBatchArgs { + migration: Migration + projectId: number + page: Page +} + +export interface NextMigrationBatchReturn { + page: Page + items: Array +} +export interface ProcessMigrationBatchArgs { + migration: Migration + projectId: number + items: Array +} + +export interface ProcessMigrationBatchReturn { + logs: {[key: string]: Array} + errors: {[key: string]: string} +} @@ -465,6 +496,36 @@ export class WaasAuthenticatorAdmin implements WaasAuthenticatorAdmin { }) } + nextMigrationBatch = (args: NextMigrationBatchArgs, headers?: object, signal?: AbortSignal): Promise => { + return this.fetch( + this.url('NextMigrationBatch'), + createHTTPRequest(args, headers, signal)).then((res) => { + return buildResponse(res).then(_data => { + return { + page: (_data.page), + items: >(_data.items), + } + }) + }, (error) => { + throw WebrpcRequestFailedError.new({ cause: `fetch(): ${error.message || ''}` }) + }) + } + + processMigrationBatch = (args: ProcessMigrationBatchArgs, headers?: object, signal?: AbortSignal): Promise => { + return this.fetch( + this.url('ProcessMigrationBatch'), + createHTTPRequest(args, headers, signal)).then((res) => { + return buildResponse(res).then(_data => { + return { + logs: <{[key: string]: Array}>(_data.logs), + errors: <{[key: string]: string}>(_data.errors), + } + }) + }, (error) => { + throw WebrpcRequestFailedError.new({ cause: `fetch(): ${error.message || ''}` }) + }) + } + } const createHTTPRequest = (body: object = {}, headers: object = {}, signal: AbortSignal | null = null): object => { diff --git a/rpc/auth/oidc/stytch.go b/rpc/auth/oidc/stytch.go index 3bf92941..1943de81 100644 --- a/rpc/auth/oidc/stytch.go +++ b/rpc/auth/oidc/stytch.go @@ -142,12 +142,16 @@ func (p *StytchAuthProvider) getEmailFromToken(tok jwt.Token) string { if !ok { return "" } - authFactors, ok := sessionMap["auth_factors"].([]map[string]any) + authFactors, ok := sessionMap["authentication_factors"].([]any) if !ok || len(authFactors) == 0 { return "" } for _, authFactor := range authFactors { - emailFactor, ok := authFactor["email_factor"].(map[string]any) + factorMap, ok := authFactor.(map[string]any) + if !ok { + continue + } + emailFactor, ok := factorMap["email_factor"].(map[string]any) if !ok { continue } diff --git a/rpc/helpers_test.go b/rpc/helpers_test.go index 612ef9c8..23ac351c 100644 --- a/rpc/helpers_test.go +++ b/rpc/helpers_test.go @@ -378,12 +378,13 @@ func newAccount(t *testing.T, tnt *data.Tenant, enc *enclave.Enclave, identity p encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(context.Background(), att, "27ebbde0-49d2-4cb6-ad78-4f2c24fe7b79", payload) require.NoError(t, err) + email := "user@example.com" return &data.Account{ ProjectID: tnt.ProjectID, Identity: data.Identity(identity), UserID: payload.UserID, - Email: "user@example.com", - ProjectScopedEmail: fmt.Sprintf("%d|user@example.com", tnt.ProjectID), + Email: email, + ProjectScopedEmail: fmt.Sprintf("%d|%s", tnt.ProjectID, email), EncryptedKey: encryptedKey, Algorithm: algorithm, Ciphertext: ciphertext, @@ -391,11 +392,15 @@ func newAccount(t *testing.T, tnt *data.Tenant, enc *enclave.Enclave, identity p } } -func newOIDCIdentity(issuer string) proto.Identity { +func newOIDCIdentity(issuer string, optSubject ...string) proto.Identity { + subject := "SUBJECT" + if len(optSubject) > 0 { + subject = optSubject[0] + } return proto.Identity{ Type: proto.IdentityType_OIDC, Issuer: issuer, - Subject: "SUBJECT", + Subject: subject, } } @@ -407,6 +412,18 @@ func newEmailIdentity(email string) proto.Identity { } } +func newStytchIdentity(stytchProjectID string, optSubject ...string) proto.Identity { + subject := "SUBJECT" + if len(optSubject) > 0 { + subject = optSubject[0] + } + return proto.Identity{ + Type: proto.IdentityType_Stytch, + Issuer: stytchProjectID, + Subject: subject, + } +} + func newSessionFromData(t *testing.T, tnt *data.Tenant, enc *enclave.Enclave, payload *proto.SessionData) *data.Session { att, err := enc.GetAttestation(context.Background(), nil) require.NoError(t, err) diff --git a/rpc/migration/oidc_to_stytch.go b/rpc/migration/oidc_to_stytch.go new file mode 100644 index 00000000..ba806666 --- /dev/null +++ b/rpc/migration/oidc_to_stytch.go @@ -0,0 +1,187 @@ +package migration + +import ( + "context" + "errors" + "fmt" + + "github.com/0xsequence/waas-authenticator/config" + "github.com/0xsequence/waas-authenticator/data" + "github.com/0xsequence/waas-authenticator/proto" + "github.com/0xsequence/waas-authenticator/rpc/attestation" + "github.com/0xsequence/waas-authenticator/rpc/crypto" + "github.com/0xsequence/waas-authenticator/rpc/tenant" +) + +type OIDCToStytch struct { + accounts *data.AccountTable + tenants *data.TenantTable + configs map[uint64]config.OIDCToStytchConfig +} + +func (m *OIDCToStytch) OnRegisterSession(ctx context.Context, originalAccount *data.Account) error { + att := attestation.FromContext(ctx) + tntData := tenant.FromContext(ctx) + + if originalAccount.ProjectID != tntData.ProjectID { + return errors.New("project id does not match") + } + if originalAccount.Identity.Type != proto.IdentityType_OIDC { + return nil + } + + cfg, ok := m.configs[tntData.ProjectID] + if !ok { + return nil + } + if originalAccount.Identity.Issuer != cfg.FromIssuer { + return nil + } + + migratedIdentity := proto.Identity{ + Type: proto.IdentityType_Stytch, + Issuer: cfg.StytchProject, + Subject: originalAccount.Identity.Subject, + Email: originalAccount.Email, + } + _, accountFound, err := m.accounts.Get(ctx, tntData.ProjectID, migratedIdentity) + if err != nil { + return fmt.Errorf("failed to retrieve account: %w", err) + } + if accountFound { + return nil + } + + accData := &proto.AccountData{ + ProjectID: tntData.ProjectID, + UserID: originalAccount.UserID, + Identity: migratedIdentity.String(), + CreatedAt: originalAccount.CreatedAt, + } + encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tntData.KMSKeys[0], accData) + if err != nil { + return fmt.Errorf("encrypting account data: %w", err) + } + + account := &data.Account{ + ProjectID: tntData.ProjectID, + Identity: data.Identity(migratedIdentity), + UserID: accData.UserID, + Email: migratedIdentity.Email, + ProjectScopedEmail: fmt.Sprintf("%d|%s", tntData.ProjectID, migratedIdentity.Email), + EncryptedKey: encryptedKey, + Algorithm: algorithm, + Ciphertext: ciphertext, + CreatedAt: accData.CreatedAt, + } + if err := m.accounts.Create(ctx, account); err != nil { + return fmt.Errorf("saving account: %w", err) + } + return nil +} + +func (m *OIDCToStytch) NextBatch(ctx context.Context, projectID uint64, page data.Page) ([]string, data.Page, error) { + cfg, ok := m.configs[projectID] + if !ok { + return nil, page, fmt.Errorf("project %d not found", projectID) + } + + items := make([]string, 0, page.Limit) + for { + accounts, page, err := m.accounts.ListByProjectAndIdentity(ctx, page, projectID, proto.IdentityType_OIDC, cfg.FromIssuer) + if err != nil { + return nil, page, err + } + + for _, acc := range accounts { + migratedIdentity := proto.Identity{ + Type: proto.IdentityType_Stytch, + Issuer: cfg.StytchProject, + Subject: acc.Identity.Subject, + } + _, found, err := m.accounts.Get(ctx, acc.ProjectID, migratedIdentity) + if err != nil { + return nil, page, err + } + if !found { + items = append(items, acc.Identity.String()) + } + } + + if len(accounts) < int(page.Limit) || len(items) >= int(page.Limit) { + return items, page, nil + } + } +} + +func (m *OIDCToStytch) ProcessItems(ctx context.Context, tenant *proto.TenantData, items []string) (*Result, error) { + if len(items) > 100 { + return nil, fmt.Errorf("can only process 100 items at a time") + } + + att := attestation.FromContext(ctx) + cfg, ok := m.configs[tenant.ProjectID] + if !ok { + return nil, fmt.Errorf("project not configured for migration") + } + + res := NewResult() + + identities := make([]proto.Identity, len(items)) + for i, item := range items { + if err := identities[i].FromString(item); err != nil { + res.Errorf(item, "parsing identity: %w", err) + continue + } + if identities[i].Type != proto.IdentityType_OIDC || identities[i].Issuer != cfg.FromIssuer { + res.Errorf(item, "incorrect identity: %s", identities[i].String()) + continue + } + } + + originalAccounts, err := m.accounts.GetBatch(ctx, tenant.ProjectID, identities) + if err != nil { + return nil, fmt.Errorf("getting accounts: %w", err) + } + + for _, originalAccount := range originalAccounts { + item := originalAccount.Identity.String() + migratedIdentity := proto.Identity{ + Type: proto.IdentityType_Stytch, + Issuer: cfg.StytchProject, + Subject: originalAccount.Identity.Subject, + Email: originalAccount.Email, + } + accData := &proto.AccountData{ + ProjectID: tenant.ProjectID, + UserID: originalAccount.UserID, + Identity: migratedIdentity.String(), + CreatedAt: originalAccount.CreatedAt, + } + encryptedKey, algorithm, ciphertext, err := crypto.EncryptData(ctx, att, tenant.KMSKeys[0], accData) + if err != nil { + res.Errorf(item, "encrypting account data: %w", err) + continue + } + + account := &data.Account{ + ProjectID: tenant.ProjectID, + Identity: data.Identity(migratedIdentity), + UserID: accData.UserID, + Email: migratedIdentity.Email, + ProjectScopedEmail: fmt.Sprintf("%d|%s", tenant.ProjectID, migratedIdentity.Email), + EncryptedKey: encryptedKey, + Algorithm: algorithm, + Ciphertext: ciphertext, + CreatedAt: accData.CreatedAt, + } + if err := m.accounts.Create(ctx, account); err != nil { + res.Errorf(item, "saving account: %w", err) + continue + } + + res.AddItem(item) + } + + return res, nil +} diff --git a/rpc/migration/result.go b/rpc/migration/result.go new file mode 100644 index 00000000..a9f52f9d --- /dev/null +++ b/rpc/migration/result.go @@ -0,0 +1,36 @@ +package migration + +import ( + "fmt" +) + +type Result struct { + RowsAffected int + Logs map[string][]string + ItemsProcessed []string + Errors map[string]error +} + +func NewResult() *Result { + return &Result{ + Logs: make(map[string][]string), + Errors: make(map[string]error), + } +} + +func (r *Result) AddItem(item string) { + r.ItemsProcessed = append(r.ItemsProcessed, item) + r.RowsAffected++ +} + +func (r *Result) Errorf(item string, format string, args ...interface{}) { + r.Errors[item] = fmt.Errorf(format, args...) +} + +func (r *Result) Log(item string, logEntry string) { + r.Logs[item] = append(r.Logs[item], logEntry) +} + +func (r *Result) Logf(item string, format string, a ...interface{}) { + r.Logs[item] = append(r.Logs[item], fmt.Sprintf(format, a...)) +} diff --git a/rpc/migration/runner.go b/rpc/migration/runner.go new file mode 100644 index 00000000..abfee6a5 --- /dev/null +++ b/rpc/migration/runner.go @@ -0,0 +1,54 @@ +package migration + +import ( + "context" + "fmt" + + "github.com/0xsequence/waas-authenticator/config" + "github.com/0xsequence/waas-authenticator/data" + "github.com/0xsequence/waas-authenticator/proto" +) + +type Migration interface { + OnRegisterSession(ctx context.Context, account *data.Account) error + NextBatch(ctx context.Context, projectID uint64, page data.Page) ([]string, data.Page, error) + ProcessItems(ctx context.Context, tenant *proto.TenantData, items []string) (*Result, error) +} + +type Runner struct { + migrations map[proto.Migration]Migration +} + +func NewRunner(cfg config.MigrationsConfig, accounts *data.AccountTable) *Runner { + r := &Runner{ + migrations: make(map[proto.Migration]Migration), + } + if len(cfg.OIDCToStytch) > 0 { + m := &OIDCToStytch{ + accounts: accounts, + configs: make(map[uint64]config.OIDCToStytchConfig), + } + for _, mCfg := range cfg.OIDCToStytch { + m.configs[mCfg.SequenceProject] = mCfg + } + r.migrations[proto.Migration_OIDCToStytch] = m + } + return r +} + +func (r *Runner) OnRegisterSession(ctx context.Context, account *data.Account) error { + for _, m := range r.migrations { + if err := m.OnRegisterSession(ctx, account); err != nil { + return err + } + } + return nil +} + +func (r *Runner) Get(migration proto.Migration) (Migration, error) { + m, ok := r.migrations[migration] + if !ok { + return nil, fmt.Errorf("no migration found for %s", migration) + } + return m, nil +} diff --git a/rpc/migrations.go b/rpc/migrations.go new file mode 100644 index 00000000..05fbbeb3 --- /dev/null +++ b/rpc/migrations.go @@ -0,0 +1,66 @@ +package rpc + +import ( + "context" + "fmt" + + "github.com/0xsequence/waas-authenticator/data" + "github.com/0xsequence/waas-authenticator/proto" + "github.com/0xsequence/waas-authenticator/rpc/crypto" +) + +func (s *RPC) NextMigrationBatch(ctx context.Context, migration proto.Migration, projectId uint64, page *proto.Page) (*proto.Page, []string, error) { + m, err := s.Migrations.Get(migration) + if err != nil { + return nil, nil, err + } + + dbPage, err := data.PageFromProto(page) + if err != nil { + return nil, nil, err + } + + items, dbPage, err := m.NextBatch(ctx, projectId, dbPage) + if err != nil { + return nil, nil, err + } + + page, err = dbPage.ToProto() + if err != nil { + return nil, nil, err + } + + return page, items, nil +} + +func (s *RPC) ProcessMigrationBatch(ctx context.Context, migration proto.Migration, projectID uint64, items []string) (map[string][]string, map[string]string, error) { + m, err := s.Migrations.Get(migration) + if err != nil { + return nil, nil, err + } + + tenant, found, err := s.Tenants.GetLatest(ctx, projectID) + if err != nil { + return nil, nil, fmt.Errorf("could not retrieve tenant: %w", err) + } + if !found { + return nil, nil, fmt.Errorf("invalid tenant: %v", projectID) + } + + tntData, _, err := crypto.DecryptData[*proto.TenantData](ctx, tenant.EncryptedKey, tenant.Ciphertext, s.Config.KMS.TenantKeys) + if err != nil { + return nil, nil, fmt.Errorf("could not decrypt tenant data: %v", projectID) + } + + res, err := m.ProcessItems(ctx, tntData, items) + if err != nil { + return nil, nil, fmt.Errorf("could not process items: %w", err) + } + + itemErrors := make(map[string]string) + for item, err := range res.Errors { + itemErrors[item] = err.Error() + } + + return res.Logs, itemErrors, nil +} diff --git a/rpc/migrations_test.go b/rpc/migrations_test.go new file mode 100644 index 00000000..a9520a41 --- /dev/null +++ b/rpc/migrations_test.go @@ -0,0 +1,233 @@ +package rpc_test + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "strconv" + "testing" + "time" + + "github.com/0xsequence/ethkit/ethwallet" + "github.com/0xsequence/ethkit/go-ethereum/common/hexutil" + "github.com/0xsequence/ethkit/go-ethereum/crypto" + "github.com/0xsequence/go-sequence/intents" + "github.com/0xsequence/waas-authenticator/config" + "github.com/0xsequence/waas-authenticator/data" + "github.com/0xsequence/waas-authenticator/proto" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMigrationOIDCToStytch(t *testing.T) { + t.Run("WithoutConfig", func(t *testing.T) { + t.Run("NoContinuousMigration", func(t *testing.T) { + ctx := context.Background() + + exp := time.Now().Add(120 * time.Second) + tokBuilderFn := func(b *jwt.Builder, url string) { + b.Expiration(exp) + } + + issuer, tok, closeJWKS := issueAccessTokenAndRunJwksServer(t, tokBuilderFn) + defer closeJWKS() + + svc := initRPC(t, func(cfg *config.Config) { + cfg.Migrations.OIDCToStytch = []config.OIDCToStytchConfig{ + { + SequenceProject: currentProjectID.Load() + 1, + StytchProject: "TEST", + FromIssuer: "FAKE_ISSUER", + }, + } + }) + tenant, _ := newTenant(t, svc.Enclave, issuer) + require.NoError(t, svc.Tenants.Add(ctx, tenant)) + + sessWallet, err := ethwallet.NewWalletFromRandomEntropy() + require.NoError(t, err) + signingSession := intents.NewSessionP256K1(sessWallet) + + srv := httptest.NewServer(svc.Handler()) + defer srv.Close() + + c := proto.NewWaasAuthenticatorClient(srv.URL, http.DefaultClient) + header := make(http.Header) + header.Set("X-Sequence-Project", strconv.Itoa(int(tenant.ProjectID))) + ctx, err = proto.WithHTTPRequestHeaders(context.Background(), header) + require.NoError(t, err) + + hashedToken := hexutil.Encode(crypto.Keccak256([]byte(tok))) + verifier := hashedToken + ";" + strconv.Itoa(int(exp.Unix())) + initiateAuth := generateSignedIntent(t, intents.IntentName_initiateAuth, intents.IntentDataInitiateAuth{ + SessionID: signingSession.SessionID(), + IdentityType: intents.IdentityType_OIDC, + Verifier: verifier, + }, signingSession) + initRes, err := c.SendIntent(ctx, initiateAuth) + require.NoError(t, err) + assert.Equal(t, proto.IntentResponseCode_authInitiated, initRes.Code) + + b, err := json.Marshal(initRes.Data) + require.NoError(t, err) + var initResData intents.IntentResponseAuthInitiated + require.NoError(t, json.Unmarshal(b, &initResData)) + + registerSession := generateSignedIntent(t, intents.IntentName_openSession, intents.IntentDataOpenSession{ + SessionID: signingSession.SessionID(), + IdentityType: intents.IdentityType_OIDC, + Verifier: verifier, + Answer: tok, + }, signingSession) + sess, registerRes, err := c.RegisterSession(ctx, registerSession, "Friendly name") + require.NoError(t, err) + assert.Equal(t, "OIDC:"+issuer+"#subject", sess.Identity.String()) + assert.Equal(t, proto.IntentResponseCode_sessionOpened, registerRes.Code) + + accs, _, err := svc.Accounts.ListByProjectAndIdentity(ctx, data.Page{}, tenant.ProjectID, proto.IdentityType_Stytch, "") + require.NoError(t, err) + assert.Len(t, accs, 0) + }) + }) + + t.Run("ContinuousMigration", func(t *testing.T) { + ctx := context.Background() + + exp := time.Now().Add(120 * time.Second) + tokBuilderFn := func(b *jwt.Builder, url string) { + b.Expiration(exp) + } + + issuer, tok, closeJWKS := issueAccessTokenAndRunJwksServer(t, tokBuilderFn) + defer closeJWKS() + + svc := initRPC(t, func(cfg *config.Config) { + cfg.Migrations.OIDCToStytch = []config.OIDCToStytchConfig{ + { + SequenceProject: currentProjectID.Load() + 1, + StytchProject: "TEST", + FromIssuer: issuer, + }, + } + }) + tenant, _ := newTenant(t, svc.Enclave, issuer) + require.NoError(t, svc.Tenants.Add(ctx, tenant)) + + sessWallet, err := ethwallet.NewWalletFromRandomEntropy() + require.NoError(t, err) + signingSession := intents.NewSessionP256K1(sessWallet) + + srv := httptest.NewServer(svc.Handler()) + defer srv.Close() + + c := proto.NewWaasAuthenticatorClient(srv.URL, http.DefaultClient) + header := make(http.Header) + header.Set("X-Sequence-Project", strconv.Itoa(int(tenant.ProjectID))) + ctx, err = proto.WithHTTPRequestHeaders(context.Background(), header) + require.NoError(t, err) + + hashedToken := hexutil.Encode(crypto.Keccak256([]byte(tok))) + verifier := hashedToken + ";" + strconv.Itoa(int(exp.Unix())) + initiateAuth := generateSignedIntent(t, intents.IntentName_initiateAuth, intents.IntentDataInitiateAuth{ + SessionID: signingSession.SessionID(), + IdentityType: intents.IdentityType_OIDC, + Verifier: verifier, + }, signingSession) + initRes, err := c.SendIntent(ctx, initiateAuth) + require.NoError(t, err) + assert.Equal(t, proto.IntentResponseCode_authInitiated, initRes.Code) + + b, err := json.Marshal(initRes.Data) + require.NoError(t, err) + var initResData intents.IntentResponseAuthInitiated + require.NoError(t, json.Unmarshal(b, &initResData)) + + registerSession := generateSignedIntent(t, intents.IntentName_openSession, intents.IntentDataOpenSession{ + SessionID: signingSession.SessionID(), + IdentityType: intents.IdentityType_OIDC, + Verifier: verifier, + Answer: tok, + }, signingSession) + sess, registerRes, err := c.RegisterSession(ctx, registerSession, "Friendly name") + require.NoError(t, err) + assert.Equal(t, "OIDC:"+issuer+"#subject", sess.Identity.String()) + assert.Equal(t, proto.IntentResponseCode_sessionOpened, registerRes.Code) + + expectedIdentity := proto.Identity{ + Type: proto.IdentityType_Stytch, + Issuer: "TEST", + Subject: "subject", + } + accs, _, err := svc.Accounts.ListByProjectAndIdentity(ctx, data.Page{}, tenant.ProjectID, proto.IdentityType_Stytch, "") + require.NoError(t, err) + require.Len(t, accs, 1) + assert.Equal(t, expectedIdentity, proto.Identity(accs[0].Identity)) + assert.Equal(t, sess.UserID, accs[0].UserID) + }) + + t.Run("OneTimeMigration", func(t *testing.T) { + ctx := context.Background() + + issuer, _, closeJWKS := issueAccessTokenAndRunJwksServer(t) + defer closeJWKS() + + projectID := currentProjectID.Load() + 1 + svc := initRPC(t, func(cfg *config.Config) { + cfg.Migrations.OIDCToStytch = []config.OIDCToStytchConfig{ + { + SequenceProject: projectID, + StytchProject: "TEST", + FromIssuer: issuer, + }, + } + }) + tenant, _ := newTenant(t, svc.Enclave, issuer) + require.NoError(t, svc.Tenants.Add(ctx, tenant)) + require.Equal(t, projectID, tenant.ProjectID) + account := newAccount(t, tenant, svc.Enclave, newOIDCIdentity(issuer), nil) + require.NoError(t, svc.Accounts.Put(ctx, account)) + + // Add more accounts + for i := 0; i < 10; i++ { + acc := newAccount(t, tenant, svc.Enclave, newOIDCIdentity(issuer, fmt.Sprintf("acc-%d", i)), nil) + require.NoError(t, svc.Accounts.Put(ctx, acc)) + } + + srv := httptest.NewServer(svc.Handler()) + defer srv.Close() + + c := proto.NewWaasAuthenticatorAdminClient(srv.URL, http.DefaultClient) + header := make(http.Header) + header.Set("Authorization", "Bearer "+adminJWT) + ctx, err := proto.WithHTTPRequestHeaders(context.Background(), header) + require.NoError(t, err) + + _, items, err := c.NextMigrationBatch(ctx, proto.Migration_OIDCToStytch, tenant.ProjectID, nil) + require.NoError(t, err) + require.Len(t, items, 11) + + itemLogs, itemErrors, err := c.ProcessMigrationBatch(ctx, proto.Migration_OIDCToStytch, tenant.ProjectID, items) + require.NoError(t, err) + assert.Empty(t, itemLogs) + assert.Empty(t, itemErrors) + + // There should be now 2 accounts of the original user: original + stytch + resultAccounts, err := svc.Accounts.ListByUserID(ctx, account.UserID) + require.NoError(t, err) + require.Len(t, resultAccounts, 2) + + identities := make([]proto.Identity, len(resultAccounts)) + for i, acc := range resultAccounts { + identities[i] = proto.Identity(acc.Identity) + } + assert.Contains(t, identities, newStytchIdentity("TEST")) + + // 2 accounts of original user + 10 doubled (migrated) additional users + allAccounts, _, err := svc.Accounts.ListByProjectAndIdentity(ctx, data.Page{}, tenant.ProjectID, proto.IdentityType_None, "") + require.NoError(t, err) + require.Len(t, allAccounts, 22) + }) +} diff --git a/rpc/rpc.go b/rpc/rpc.go index a86a37ec..92986561 100644 --- a/rpc/rpc.go +++ b/rpc/rpc.go @@ -24,6 +24,7 @@ import ( "github.com/0xsequence/waas-authenticator/rpc/auth/oidc" "github.com/0xsequence/waas-authenticator/rpc/auth/playfab" "github.com/0xsequence/waas-authenticator/rpc/awscreds" + "github.com/0xsequence/waas-authenticator/rpc/migration" "github.com/0xsequence/waas-authenticator/rpc/signing" "github.com/0xsequence/waas-authenticator/rpc/tenant" "github.com/0xsequence/waas-authenticator/rpc/tracing" @@ -67,6 +68,8 @@ type RPC struct { AuthProviders map[intents.IdentityType]auth.Provider Signer signing.Signer + Migrations *migration.Runner + measurements *enclave.Measurements startTime time.Time running int32 @@ -161,6 +164,7 @@ func New(cfg *config.Config, client *http.Client) (*RPC, error) { startTime: time.Now(), measurements: m, } + s.Migrations = migration.NewRunner(cfg.Migrations, s.Accounts) return s, nil } diff --git a/rpc/sessions.go b/rpc/sessions.go index 6b37fab1..4f986947 100644 --- a/rpc/sessions.go +++ b/rpc/sessions.go @@ -162,7 +162,8 @@ func (s *RPC) RegisterSession( return nil, nil, proto.ErrWebrpcInternalError.WithCausef("registering session with WaaS API: %w", err) } - if !accountFound { + if !accountFound || (ident.Email != "" && account.Email != ident.Email) { + account.Email = ident.Email if err := s.Accounts.Put(ctx, account); err != nil { return nil, nil, proto.ErrWebrpcInternalError.WithCausef("save account: %w", err) } @@ -200,6 +201,10 @@ func (s *RPC) RegisterSession( return nil, convertIntentResponse(res), proto.ErrWebrpcInternalError.WithCausef("save session: %w", err) } + if err := s.Migrations.OnRegisterSession(ctx, account); err != nil { + return nil, convertIntentResponse(res), proto.ErrWebrpcInternalError.WithCausef("migrate account: %w", err) + } + retSess := &proto.Session{ ID: dbSess.ID, UserID: dbSess.UserID,