diff --git a/crates/cli/src/commands/manage.rs b/crates/cli/src/commands/manage.rs index 4e40ff7ab..b8d6c6a52 100644 --- a/crates/cli/src/commands/manage.rs +++ b/crates/cli/src/commands/manage.rs @@ -28,7 +28,7 @@ use mas_matrix_synapse::SynapseConnection; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, job::{ - DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, + DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob, }, user::{UserEmailRepository, UserPasswordRepository, UserRepository}, Clock, RepositoryAccess, SystemClock, @@ -370,10 +370,6 @@ impl Options { if dry_run { continue; } - - let job = DeleteDeviceJob::new(&user, &compat_session.device); - repo.job().schedule_job(job).await?; - repo.compat_session().finish(&clock, compat_session).await?; } let oauth2_sessions_ids: Vec = sqlx::query_scalar( @@ -400,16 +396,6 @@ impl Options { if dry_run { continue; } - - for scope in &*oauth2_session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } - repo.oauth2_session().finish(&clock, oauth2_session).await?; } @@ -441,6 +427,10 @@ impl Options { .await?; } + // Schedule a job to sync the devices of the user with the homeserver + warn!("Scheduling job to sync devices for the user"); + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + let txn = repo.into_inner(); if dry_run { info!("Dry run, not saving"); diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index 749314ee4..3e83a260c 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -26,7 +26,6 @@ use mas_storage::{ CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository, CompatSsoLoginRepository, }, - job::{JobRepositoryExt, ProvisionDeviceJob}, user::{UserPasswordRepository, UserRepository}, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; @@ -168,6 +167,9 @@ pub enum RouteError { #[error("invalid login token")] InvalidLoginToken, + + #[error("failed to provision device")] + ProvisionDeviceFailed(#[source] anyhow::Error), } impl_from_error_for_route!(mas_storage::RepositoryError); @@ -176,11 +178,13 @@ impl IntoResponse for RouteError { fn into_response(self) -> axum::response::Response { let event_id = sentry::capture_error(&self); let response = match self { - Self::Internal(_) | Self::SessionNotFound => MatrixError { - errcode: "M_UNKNOWN", - error: "Internal server error", - status: StatusCode::INTERNAL_SERVER_ERROR, - }, + Self::Internal(_) | Self::SessionNotFound | Self::ProvisionDeviceFailed(_) => { + MatrixError { + errcode: "M_UNKNOWN", + error: "Internal server error", + status: StatusCode::INTERNAL_SERVER_ERROR, + } + } Self::Unsupported => MatrixError { errcode: "M_UNRECOGNIZED", error: "Invalid login type", @@ -235,6 +239,7 @@ pub(crate) async fn post( &clock, &password_manager, &mut repo, + &homeserver, user, password, ) @@ -368,6 +373,7 @@ async fn user_password_login( clock: &impl Clock, password_manager: &PasswordManager, repo: &mut BoxRepository, + homeserver: &BoxHomeserverConnection, username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { @@ -413,11 +419,16 @@ async fn user_password_login( .await?; } + // Lock the user sync to make sure we don't get into a race condition + 
repo.user().acquire_lock_for_sync(&user).await?; + // Now that the user credentials have been verified, start a new compat session let device = Device::generate(&mut rng); - repo.job() - .schedule_job(ProvisionDeviceJob::new(&user, &device)) - .await?; + let mxid = homeserver.mxid(&user.username); + homeserver + .create_device(&mxid, device.as_str()) + .await + .map_err(RouteError::ProvisionDeviceFailed)?; let session = repo .compat_session() @@ -430,6 +441,7 @@ async fn user_password_login( #[cfg(test)] mod tests { use hyper::Request; + use mas_matrix::{HomeserverConnection, ProvisionRequest}; use rand::distributions::{Alphanumeric, DistString}; use sqlx::PgPool; @@ -532,6 +544,13 @@ mod tests { .await .unwrap(); + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, &user.sub)) + .await + .unwrap(); + let (version, hashed_password) = state .password_manager .hash( @@ -664,6 +683,13 @@ mod tests { .unwrap(); repo.save().await.unwrap(); + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, &user.sub)) + .await + .unwrap(); + // First try with an invalid token let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ "type": "m.login.token", diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 728ed073f..79eaa2994 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -27,10 +27,10 @@ use mas_axum_utils::{ FancyError, SessionInfoExt, }; use mas_data_model::Device; +use mas_matrix::BoxHomeserverConnection; use mas_router::{CompatLoginSsoAction, PostAuthAction, UrlBuilder}; use mas_storage::{ compat::{CompatSessionRepository, CompatSsoLoginRepository}, - job::{JobRepositoryExt, ProvisionDeviceJob}, BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess, }; use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates}; @@ -136,6 +136,7 @@ pub async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + State(homeserver): State, cookie_jar: CookieJar, Path(id): Path, Query(params): Query, @@ -201,10 +202,15 @@ pub async fn post( redirect_uri }; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&session.user).await?; + let device = Device::generate(&mut rng); - repo.job() - .schedule_job(ProvisionDeviceJob::new(&session.user, &device)) - .await?; + let mxid = homeserver.mxid(&session.user.username); + homeserver + .create_device(&mxid, device.as_str()) + .await + .context("Failed to provision device")?; let compat_session = repo .compat_session() diff --git a/crates/handlers/src/compat/logout.rs b/crates/handlers/src/compat/logout.rs index 55fb20664..76000aef2 100644 --- a/crates/handlers/src/compat/logout.rs +++ b/crates/handlers/src/compat/logout.rs @@ -20,7 +20,7 @@ use mas_axum_utils::sentry::SentryEventID; use mas_data_model::TokenType; use mas_storage::{ compat::{CompatAccessTokenRepository, CompatSessionRepository}, - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, BoxClock, BoxRepository, Clock, RepositoryAccess, }; use thiserror::Error; @@ -111,9 +111,8 @@ pub(crate) async fn post( // XXX: this is probably not the right error .ok_or(RouteError::InvalidAuthorization)?; - repo.job() - 
.schedule_job(DeleteDeviceJob::new(&user, &session.device)) - .await?; + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; repo.compat_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/compat_session.rs b/crates/handlers/src/graphql/mutations/compat_session.rs index 49c57735a..9c8034b98 100644 --- a/crates/handlers/src/graphql/mutations/compat_session.rs +++ b/crates/handlers/src/graphql/mutations/compat_session.rs @@ -16,7 +16,7 @@ use anyhow::Context as _; use async_graphql::{Context, Enum, InputObject, Object, ID}; use mas_storage::{ compat::CompatSessionRepository, - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, RepositoryAccess, }; @@ -101,10 +101,8 @@ impl CompatSessionMutations { .await? .context("Could not load user")?; - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &session.device)) - .await?; + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; let session = repo.compat_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/mutations/oauth2_session.rs b/crates/handlers/src/graphql/mutations/oauth2_session.rs index 0acc137ba..f7fc6c3af 100644 --- a/crates/handlers/src/graphql/mutations/oauth2_session.rs +++ b/crates/handlers/src/graphql/mutations/oauth2_session.rs @@ -17,7 +17,7 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID}; use chrono::Duration; use mas_data_model::{Device, TokenType}; use mas_storage::{ - job::{DeleteDeviceJob, JobRepositoryExt, ProvisionDeviceJob}, + job::{JobRepositoryExt, SyncDevicesJob}, oauth2::{ OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, @@ -129,6 +129,7 @@ impl OAuth2SessionMutations { input: CreateOAuth2SessionInput, ) -> Result { let state = ctx.state(); + let homeserver = state.homeserver_connection(); let user_id = NodeType::User.extract_ulid(&input.user_id)?; let scope: Scope = input.scope.parse().context("Invalid scope")?; let permanent = input.permanent.unwrap_or(false); @@ -167,12 +168,17 @@ impl OAuth2SessionMutations { .add(&mut rng, &clock, &client, Some(&user), None, scope) .await?; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + // Look for devices to provision + let mxid = homeserver.mxid(&user.username); for scope in &*session.scope { if let Some(device) = Device::from_scope_token(scope) { - repo.job() - .schedule_job(ProvisionDeviceJob::new(&user, &device)) - .await?; + homeserver + .create_device(&mxid, device.as_str()) + .await + .context("Failed to provision device")?; } } @@ -236,20 +242,8 @@ impl OAuth2SessionMutations { .await? .context("Could not load user")?; - // Scan the scopes of the session to find if there is any device that should be - // deleted from the Matrix server. - // TODO: this should be moved in a higher level "end oauth session" method. - // XXX: this might not be the right semantic, but it's the best we - // can do for now, since we're not explicitly storing devices for OAuth2 - // sessions. - for scope in &*session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. 
- repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; } let session = repo.oauth2_session().finish(&clock, session).await?; diff --git a/crates/handlers/src/graphql/tests.rs b/crates/handlers/src/graphql/tests.rs index 463abb61d..23e083ac1 100644 --- a/crates/handlers/src/graphql/tests.rs +++ b/crates/handlers/src/graphql/tests.rs @@ -15,6 +15,7 @@ use axum::http::Request; use hyper::StatusCode; use mas_data_model::{AccessToken, Client, TokenType, User}; +use mas_matrix::{HomeserverConnection, ProvisionRequest}; use mas_router::SimpleRoute; use mas_storage::{ oauth2::{OAuth2AccessTokenRepository, OAuth2ClientRepository}, @@ -517,7 +518,7 @@ async fn test_oauth2_client_credentials(pool: PgPool) { response.assert_status(StatusCode::OK); let response: GraphQLResponse = response.json(); assert!(response.errors.is_empty(), "{:?}", response.errors); - let user_id = &response.data["addUser"]["user"]["id"]; + let user_id = response.data["addUser"]["user"]["id"].as_str().unwrap(); assert_eq!( response.data, @@ -531,6 +532,16 @@ async fn test_oauth2_client_credentials(pool: PgPool) { }) ); + // XXX: we don't run the task worker here, so even though the addUser mutation + // should have scheduled a job to provision the user, it won't run in the test, + // so we need to do it manually + let mxid = state.homeserver_connection.mxid("alice"); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, user_id)) + .await + .unwrap(); + // We should now be able to create an arbitrary access token for the user let request = Request::post("/graphql") .bearer(&access_token) diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index bc7def815..74c681504 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -186,6 +186,7 @@ where Encrypter: FromRef, HttpClientFactory: FromRef, SiteConfig: FromRef, + BoxHomeserverConnection: FromRef, BoxClock: FromRequestParts, BoxRng: FromRequestParts, Policy: FromRequestParts, diff --git a/crates/handlers/src/oauth2/introspection.rs b/crates/handlers/src/oauth2/introspection.rs index 44eff0cdc..e468314cd 100644 --- a/crates/handlers/src/oauth2/introspection.rs +++ b/crates/handlers/src/oauth2/introspection.rs @@ -461,6 +461,7 @@ mod tests { use hyper::{Request, StatusCode}; use mas_data_model::{AccessToken, RefreshToken}; use mas_iana::oauth::OAuthTokenTypeHint; + use mas_matrix::{HomeserverConnection, ProvisionRequest}; use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute}; use mas_storage::Clock; use oauth2_types::{ @@ -518,6 +519,13 @@ mod tests { .await .unwrap(); + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, &user.sub)) + .await + .unwrap(); + let client = repo .oauth2_client() .find_by_client_id(&client_id) @@ -703,6 +711,13 @@ mod tests { .await .unwrap(); + let mxid = state.homeserver_connection.mxid(&user.username); + state + .homeserver_connection + .provision_user(&ProvisionRequest::new(mxid, &user.sub)) + .await + .unwrap(); + let (version, hashed_password) = state .password_manager .hash(&mut state.rng(), Zeroizing::new(b"password".to_vec())) diff --git a/crates/handlers/src/oauth2/revoke.rs b/crates/handlers/src/oauth2/revoke.rs index ff725df32..21d330232 100644 --- a/crates/handlers/src/oauth2/revoke.rs +++ 
b/crates/handlers/src/oauth2/revoke.rs @@ -19,11 +19,11 @@ use mas_axum_utils::{ http_client_factory::HttpClientFactory, sentry::SentryEventID, }; -use mas_data_model::{Device, TokenType}; +use mas_data_model::TokenType; use mas_iana::oauth::OAuthTokenTypeHint; use mas_keystore::Encrypter; use mas_storage::{ - job::{DeleteDeviceJob, JobRepositoryExt}, + job::{JobRepositoryExt, SyncDevicesJob}, BoxClock, BoxRepository, RepositoryAccess, }; use oauth2_types::{ @@ -217,20 +217,8 @@ pub(crate) async fn post( .await? .ok_or(RouteError::UnknownToken)?; - // Scan the scopes of the session to find if there is any device that should be - // deleted from the Matrix server. - // TODO: this should be moved in a higher level "end oauth session" method. - // XXX: this might not be the right semantic, but it's the best we - // can do for now, since we're not explicitly storing devices for OAuth2 - // sessions. - for scope in &*session.scope { - if let Some(device) = Device::from_scope_token(scope) { - // Schedule a job to delete the device. - repo.job() - .schedule_job(DeleteDeviceJob::new(&user, &device)) - .await?; - } - } + // Schedule a job to sync the devices of the user with the homeserver + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; } // Now that we checked everything, we can end the session. diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 79e95a4b8..8ef50837f 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -26,11 +26,11 @@ use mas_data_model::{ AuthorizationGrantStage, Client, Device, DeviceCodeGrantState, SiteConfig, TokenType, UserAgent, }; use mas_keystore::{Encrypter, Keystore}; +use mas_matrix::BoxHomeserverConnection; use mas_oidc_client::types::scope::ScopeToken; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{ - job::{JobRepositoryExt, ProvisionDeviceJob}, oauth2::{ OAuth2AccessTokenRepository, OAuth2AuthorizationGrantRepository, OAuth2RefreshTokenRepository, OAuth2SessionRepository, @@ -118,6 +118,9 @@ pub(crate) enum RouteError { #[error("device code grant was already exchanged")] DeviceCodeExchanged, + + #[error("failed to provision device")] + ProvisionDeviceFailed(#[source] anyhow::Error), } impl IntoResponse for RouteError { @@ -125,7 +128,10 @@ impl IntoResponse for RouteError { let event_id = sentry::capture_error(&self); let response = match self { - Self::Internal(_) | Self::NoSuchBrowserSession | Self::NoSuchOAuthSession => ( + Self::Internal(_) + | Self::NoSuchBrowserSession + | Self::NoSuchOAuthSession + | Self::ProvisionDeviceFailed(_) => ( StatusCode::INTERNAL_SERVER_ERROR, Json(ClientError::from(ClientErrorCode::ServerError)), ), @@ -210,6 +216,7 @@ pub(crate) async fn post( State(url_builder): State, activity_tracker: BoundActivityTracker, mut repo: BoxRepository, + State(homeserver): State, State(site_config): State, State(encrypter): State, policy: Policy, @@ -247,6 +254,7 @@ pub(crate) async fn post( &url_builder, &site_config, repo, + &homeserver, user_agent, ) .await? @@ -289,6 +297,7 @@ pub(crate) async fn post( &url_builder, &site_config, repo, + &homeserver, user_agent, ) .await? 
@@ -318,6 +327,7 @@ async fn authorization_code_grant( url_builder: &UrlBuilder, site_config: &SiteConfig, mut repo: BoxRepository, + homeserver: &BoxHomeserverConnection, user_agent: Option, ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { // Check that the client is allowed to use this grant type @@ -451,17 +461,19 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } + // Lock the user sync to make sure we don't get into a race condition + repo.user() + .acquire_lock_for_sync(&browser_session.user) + .await?; + // Look for device to provision + let mxid = homeserver.mxid(&browser_session.user.username); for scope in &*session.scope { if let Some(device) = Device::from_scope_token(scope) { - // Note that we're not waiting for the job to finish, we just schedule it. We - // might get in a situation where the provisioning job is not finished when the - // client does its first request to the Homeserver. This is fine for now, since - // Synapse still provision devices on-the-fly if it doesn't find them in the - // database. - repo.job() - .schedule_job(ProvisionDeviceJob::new(&browser_session.user, &device)) - .await?; + homeserver + .create_device(&mxid, device.as_str()) + .await + .map_err(RouteError::ProvisionDeviceFailed)?; } } @@ -639,6 +651,7 @@ async fn device_code_grant( url_builder: &UrlBuilder, site_config: &SiteConfig, mut repo: BoxRepository, + homeserver: &BoxHomeserverConnection, user_agent: Option, ) -> Result<(AccessTokenResponse, BoxRepository), RouteError> { // Check that the client is allowed to use this grant type @@ -740,17 +753,19 @@ async fn device_code_grant( params = params.with_id_token(id_token); } + // Lock the user sync to make sure we don't get into a race condition + repo.user() + .acquire_lock_for_sync(&browser_session.user) + .await?; + // Look for device to provision + let mxid = homeserver.mxid(&browser_session.user.username); for scope in &*session.scope { if let Some(device) = Device::from_scope_token(scope) { - // Note that we're not waiting for the job to finish, we just schedule it. We - // might get in a situation where the provisioning job is not finished when the - // client does its first request to the Homeserver. This is fine for now, since - // Synapse still provision devices on-the-fly if it doesn't find them in the - // database. 
- repo.job() - .schedule_job(ProvisionDeviceJob::new(&browser_session.user, &device)) - .await?; + homeserver + .create_device(&mxid, device.as_str()) + .await + .map_err(RouteError::ProvisionDeviceFailed)?; } } diff --git a/crates/matrix-synapse/src/lib.rs b/crates/matrix-synapse/src/lib.rs index d35bd9b3b..ecc01e2b7 100644 --- a/crates/matrix-synapse/src/lib.rs +++ b/crates/matrix-synapse/src/lib.rs @@ -14,6 +14,8 @@ #![allow(clippy::blocks_in_conditions)] +use std::collections::HashSet; + use anyhow::{bail, Context}; use http::{header::AUTHORIZATION, request::Builder, Method, Request, StatusCode}; use mas_axum_utils::http_client_factory::HttpClientFactory; @@ -134,9 +136,19 @@ struct SynapseUser { deactivated: Option, } +#[derive(Deserialize)] +struct SynapseDeviceListResponse { + devices: Vec, +} + +#[derive(Serialize, Deserialize)] +struct SynapseDevice { + device_id: String, +} + #[derive(Serialize)] -struct SynapseDevice<'a> { - device_id: &'a str, +struct SynapseDeleteDevicesRequest { + devices: Vec, } #[derive(Serialize)] @@ -360,7 +372,9 @@ impl HomeserverConnection for SynapseConnection { let request = self .post(&format!("_synapse/admin/v2/users/{mxid}/devices")) - .body(SynapseDevice { device_id })?; + .body(SynapseDevice { + device_id: device_id.to_owned(), + })?; let response = client .ready() @@ -415,6 +429,82 @@ impl HomeserverConnection for SynapseConnection { Ok(()) } + #[tracing::instrument( + name = "homeserver.sync_devices", + skip_all, + fields( + matrix.homeserver = self.homeserver, + matrix.mxid = mxid, + ), + err(Debug), + )] + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + // Get the list of current devices + let mxid_url = urlencoding::encode(mxid); + let mut client = self + .http_client_factory + .client("homeserver.sync_devices.query") + .response_body_to_bytes() + .catch_http_errors(catch_homeserver_error) + .json_response(); + + let request = self + .get(&format!("_synapse/admin/v2/users/{mxid_url}/devices")) + .body(EmptyBody::new())?; + + let response = client + .ready() + .await? + .call(request) + .await + .context("Failed to query user from Synapse")?; + + if response.status() != StatusCode::OK { + return Err(anyhow::anyhow!("Failed to query user devices from Synapse")); + } + + let body: SynapseDeviceListResponse = response.into_body(); + + let existing_devices: HashSet = + body.devices.into_iter().map(|d| d.device_id).collect(); + + // First, delete all the devices that are not needed anymore + let to_delete = existing_devices.difference(&devices).cloned().collect(); + + let mut client = self + .http_client_factory + .client("homeserver.sync_devices.delete") + .response_body_to_bytes() + .catch_http_errors(catch_homeserver_error) + .request_bytes_to_body() + .json_request(); + + let request = self + .post(&format!( + "_synapse/admin/v2/users/{mxid_url}/delete_devices" + )) + .body(SynapseDeleteDevicesRequest { devices: to_delete })?; + + let response = client + .ready() + .await? + .call(request) + .await + .context("Failed to query user from Synapse")?; + + if response.status() != StatusCode::OK { + return Err(anyhow::anyhow!("Failed to delete devices from Synapse")); + } + + // Then, create the devices that are missing. There is no batching API to do + // this, so we do this sequentially, which is fine as the API is idempotent. 
+ for device_id in devices.difference(&existing_devices) { + self.create_device(mxid, device_id).await?; + } + + Ok(()) + } + #[tracing::instrument( name = "homeserver.delete_user", skip_all, diff --git a/crates/matrix/src/lib.rs b/crates/matrix/src/lib.rs index 0cdea5030..779c263f6 100644 --- a/crates/matrix/src/lib.rs +++ b/crates/matrix/src/lib.rs @@ -14,7 +14,7 @@ mod mock; -use std::sync::Arc; +use std::{collections::HashSet, sync::Arc}; pub use self::mock::HomeserverConnection as MockHomeserverConnection; @@ -263,6 +263,19 @@ pub trait HomeserverConnection: Send + Sync { /// not be deleted. async fn delete_device(&self, mxid: &str, device_id: &str) -> Result<(), Self::Error>; + /// Sync the list of devices of a user with the homeserver. + /// + /// # Parameters + /// + /// * `mxid` - The Matrix ID of the user to sync the devices for. + /// * `devices` - The list of devices to sync. + /// + /// # Errors + /// + /// Returns an error if the homeserver is unreachable or the devices could + /// not be synced. + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error>; + /// Delete a user on the homeserver. /// /// # Parameters @@ -354,6 +367,10 @@ impl HomeserverConnection for &T (**self).delete_device(mxid, device_id).await } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + (**self).sync_devices(mxid, devices).await + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { (**self).delete_user(mxid, erase).await } @@ -404,6 +421,10 @@ impl HomeserverConnection for Arc { (**self).delete_device(mxid, device_id).await } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + (**self).sync_devices(mxid, devices).await + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { (**self).delete_user(mxid, erase).await } diff --git a/crates/matrix/src/mock.rs b/crates/matrix/src/mock.rs index 10389b5ea..d834dec01 100644 --- a/crates/matrix/src/mock.rs +++ b/crates/matrix/src/mock.rs @@ -131,6 +131,13 @@ impl crate::HomeserverConnection for HomeserverConnection { Ok(()) } + async fn sync_devices(&self, mxid: &str, devices: HashSet) -> Result<(), Self::Error> { + let mut users = self.users.write().await; + let user = users.get_mut(mxid).context("User not found")?; + user.devices = devices; + Ok(()) + } + async fn delete_user(&self, mxid: &str, erase: bool) -> Result<(), Self::Error> { let mut users = self.users.write().await; let user = users.get_mut(mxid).context("User not found")?; diff --git a/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json b/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json new file mode 100644 index 000000000..aa1737850 --- /dev/null +++ b/crates/storage-pg/.sqlx/query-e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT pg_advisory_xact_lock($1)\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "pg_advisory_xact_lock", + "type_info": "Void" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + null + ] + }, + "hash": "e68a7084d44462d19f30902d7e6c1bd60bb771c6f075df15ab0137a7ffc896da" +} diff --git a/crates/storage-pg/src/user/mod.rs b/crates/storage-pg/src/user/mod.rs index a4675f2f2..377182822 100644 --- a/crates/storage-pg/src/user/mod.rs +++ b/crates/storage-pg/src/user/mod.rs @@ -437,4 
+437,38 @@ impl<'c> UserRepository for PgUserRepository<'c> { .try_into() .map_err(DatabaseError::to_invalid_operation) } + + #[tracing::instrument( + name = "db.user.acquire_lock_for_sync", + skip_all, + fields( + db.statement, + user.id = %user.id, + ), + err, + )] + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error> { + // XXX: this lock isn't strictly scoped to users, but as we don't use many + // postgres advisory locks, it's fine for now. Later on, we could use row-level + // locks to make sure we don't get into trouble + + // Convert the user ID to a u128 and grab the lower 64 bits + // As this includes 64 bits of the random part of the ULID, it should be random + // enough to not collide + let lock_id = (u128::from(user.id) & 0xffff_ffff_ffff_ffff) as i64; + + // Use a PG advisory lock, which will be released when the transaction is + // committed or rolled back + sqlx::query!( + r#" + SELECT pg_advisory_xact_lock($1) + "#, + lock_id, + ) + .traced() + .execute(&mut *self.conn) + .await?; + + Ok(()) + } } diff --git a/crates/storage/src/job.rs b/crates/storage/src/job.rs index ddc805ff4..1ac66fb94 100644 --- a/crates/storage/src/job.rs +++ b/crates/storage/src/job.rs @@ -327,6 +327,9 @@ mod jobs { } /// A job to provision a device for a user on the homeserver. + /// + /// This job is deprecated; use the `SyncDevicesJob` instead. It is kept to + /// not break existing jobs in the database. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ProvisionDeviceJob { user_id: Ulid, } impl ProvisionDeviceJob { - /// Create a new job to provision a device for a user on the homeserver. - #[must_use] - pub fn new(user: &User, device: &Device) -> Self { - Self { - user_id: user.id, - device_id: device.as_str().to_owned(), - } - } - /// The ID of the user to provision the device for. #[must_use] pub fn user_id(&self) -> Ulid { @@ -361,6 +355,9 @@ mod jobs { } /// A job to delete a device for a user on the homeserver. + /// + /// This job is deprecated; use the `SyncDevicesJob` instead. It is kept to + /// not break existing jobs in the database.
#[derive(Serialize, Deserialize, Debug, Clone)] pub struct DeleteDeviceJob { user_id: Ulid, @@ -394,6 +391,31 @@ mod jobs { const NAME: &'static str = "delete-device"; } + /// A job which syncs the list of devices of a user with the homeserver + #[derive(Serialize, Deserialize, Debug, Clone)] + pub struct SyncDevicesJob { + user_id: Ulid, + } + + impl SyncDevicesJob { + /// Create a new job to sync the list of devices of a user with the + /// homeserver + #[must_use] + pub fn new(user: &User) -> Self { + Self { user_id: user.id } + } + + /// The ID of the user to sync the devices for + #[must_use] + pub fn user_id(&self) -> Ulid { + self.user_id + } + } + + impl Job for SyncDevicesJob { + const NAME: &'static str = "sync-devices"; + } + /// A job to deactivate and lock a user #[derive(Serialize, Deserialize, Debug, Clone)] pub struct DeactivateUserJob { @@ -496,5 +518,5 @@ mod jobs { pub use self::jobs::{ DeactivateUserJob, DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, ReactivateUserJob, - SendAccountRecoveryEmailsJob, VerifyEmailJob, + SendAccountRecoveryEmailsJob, SyncDevicesJob, VerifyEmailJob, }; diff --git a/crates/storage/src/user/mod.rs b/crates/storage/src/user/mod.rs index 9763afedc..6e696c06c 100644 --- a/crates/storage/src/user/mod.rs +++ b/crates/storage/src/user/mod.rs @@ -259,6 +259,19 @@ pub trait UserRepository: Send + Sync { /// /// Returns [`Self::Error`] if the underlying repository fails async fn count(&mut self, filter: UserFilter<'_>) -> Result; + + /// Acquire a lock on the user to make sure device operations are done in a + /// sequential way. The lock is released when the repository is saved or + /// rolled back. + /// + /// # Parameters + /// + /// * `user`: The user to lock + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error>; } repository_impl!(UserRepository: @@ -284,4 +297,5 @@ repository_impl!(UserRepository: pagination: Pagination, ) -> Result, Self::Error>; async fn count(&mut self, filter: UserFilter<'_>) -> Result; + async fn acquire_lock_for_sync(&mut self, user: &User) -> Result<(), Self::Error>; ); diff --git a/crates/tasks/src/matrix.rs b/crates/tasks/src/matrix.rs index a3c993094..7f10ed096 100644 --- a/crates/tasks/src/matrix.rs +++ b/crates/tasks/src/matrix.rs @@ -1,4 +1,4 @@ -// Copyright 2023 The Matrix.org Foundation C.I.C. +// Copyright 2023, 2024 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,13 +12,21 @@ // See the License for the specific language governing permissions and // limitations under the License. 
+use std::collections::HashSet; + use anyhow::Context; use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor}; +use mas_data_model::Device; use mas_matrix::ProvisionRequest; use mas_storage::{ - job::{DeleteDeviceJob, JobWithSpanContext, ProvisionDeviceJob, ProvisionUserJob}, + compat::CompatSessionFilter, + job::{ + DeleteDeviceJob, JobRepositoryExt as _, JobWithSpanContext, ProvisionDeviceJob, + ProvisionUserJob, SyncDevicesJob, + }, + oauth2::OAuth2SessionFilter, user::{UserEmailRepository, UserRepository}, - RepositoryAccess, + Pagination, RepositoryAccess, }; use tracing::info; @@ -56,9 +64,6 @@ async fn provision_user( .filter(|email| email.confirmed_at.is_some()) .map(|email| email.email) .collect(); - - repo.cancel().await?; - let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails); if let Some(display_name) = job.display_name_to_set() { @@ -73,12 +78,18 @@ async fn provision_user( info!(%user.id, %mxid, "User updated"); } + // Schedule a device sync job + let sync_device_job = SyncDevicesJob::new(&user); + repo.job().schedule_job(sync_device_job).await?; + + repo.save().await?; + Ok(()) } /// Job to provision a device on the Matrix homeserver. -/// This works by doing a POST request to the -/// /_synapse/admin/v2/users/{user_id}/devices endpoint. +/// +/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`] #[tracing::instrument( name = "job.provision_device" fields( @@ -93,7 +104,6 @@ async fn provision_device( ctx: JobContext, ) -> Result<(), anyhow::Error> { let state = ctx.state(); - let matrix = state.matrix_connection(); let mut repo = state.repository().await?; let user = repo @@ -102,17 +112,15 @@ async fn provision_device( .await? .context("User not found")?; - let mxid = matrix.mxid(&user.username); - - matrix.create_device(&mxid, job.device_id()).await?; - info!(%user.id, %mxid, device.id = job.device_id(), "Device created"); + // Schedule a device sync job + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; Ok(()) } /// Job to delete a device from a user's account. -/// This works by doing a DELETE request to the -/// /_synapse/admin/v2/users/{user_id}/devices/{device_id} endpoint. +/// +/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`] #[tracing::instrument( name = "job.delete_device" fields( @@ -125,6 +133,32 @@ async fn provision_device( async fn delete_device( job: JobWithSpanContext, ctx: JobContext, +) -> Result<(), anyhow::Error> { + let state = ctx.state(); + let mut repo = state.repository().await?; + + let user = repo + .user() + .lookup(job.user_id()) + .await? + .context("User not found")?; + + // Schedule a device sync job + repo.job().schedule_job(SyncDevicesJob::new(&user)).await?; + + Ok(()) +} + +/// Job to sync the list of devices of a user with the homeserver. +#[tracing::instrument( + name = "job.sync_devices", + fields(user.id = %job.user_id()), + skip_all, + err(Debug), +)] +async fn sync_devices( + job: JobWithSpanContext, + ctx: JobContext, ) -> Result<(), anyhow::Error> { let state = ctx.state(); let matrix = state.matrix_connection(); @@ -136,10 +170,64 @@ async fn delete_device( .await? 
.context("User not found")?; + // Lock the user sync to make sure we don't get into a race condition + repo.user().acquire_lock_for_sync(&user).await?; + + let mut devices = HashSet::new(); + + // Cycle through all the compat sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .compat_session() + .list( + CompatSessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for (compat_session, _) in page.edges { + devices.insert(compat_session.device.as_str().to_owned()); + cursor = cursor.after(compat_session.id); + } + + if !page.has_next_page { + break; + } + } + + // Cycle though all the oauth2 sessions of the user, and grab the devices + let mut cursor = Pagination::first(100); + loop { + let page = repo + .oauth2_session() + .list( + OAuth2SessionFilter::new().for_user(&user).active_only(), + cursor, + ) + .await?; + + for oauth2_session in page.edges { + for scope in &*oauth2_session.scope { + if let Some(device) = Device::from_scope_token(scope) { + devices.insert(device.as_str().to_owned()); + } + } + + cursor = cursor.after(oauth2_session.id); + } + + if !page.has_next_page { + break; + } + } + let mxid = matrix.mxid(&user.username); + matrix.sync_devices(&mxid, devices).await?; - matrix.delete_device(&mxid, job.device_id()).await?; - info!(%user.id, %mxid, device.id = job.device_id(), "Device deleted"); + // We kept the connection until now, so that we still hold the lock on the user + // throughout the sync + repo.save().await?; Ok(()) } @@ -156,9 +244,12 @@ pub(crate) fn register( crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory); let delete_device_worker = crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory); + let sync_devices_worker = + crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory); monitor .register(provision_user_worker) .register(provision_device_worker) .register(delete_device_worker) + .register(sync_devices_worker) }