Skip to content
This repository has been archived by the owner on Sep 10, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/main' into quenting/user-reactiv…
Browse files Browse the repository at this point in the history
…ation
  • Loading branch information
sandhose committed Jul 16, 2024
2 parents a319976 + 3eab106 commit 3f30498
Show file tree
Hide file tree
Showing 19 changed files with 465 additions and 121 deletions.
20 changes: 5 additions & 15 deletions crates/cli/src/commands/manage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use mas_matrix_synapse::SynapseConnection;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{
DeactivateUserJob, DeleteDeviceJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob,
DeactivateUserJob, JobRepositoryExt, ProvisionUserJob, ReactivateUserJob, SyncDevicesJob,
},
user::{UserEmailRepository, UserPasswordRepository, UserRepository},
Clock, RepositoryAccess, SystemClock,
Expand Down Expand Up @@ -370,10 +370,6 @@ impl Options {
if dry_run {
continue;
}

let job = DeleteDeviceJob::new(&user, &compat_session.device);
repo.job().schedule_job(job).await?;
repo.compat_session().finish(&clock, compat_session).await?;
}

let oauth2_sessions_ids: Vec<Uuid> = sqlx::query_scalar(
Expand All @@ -400,16 +396,6 @@ impl Options {
if dry_run {
continue;
}

for scope in &*oauth2_session.scope {
if let Some(device) = Device::from_scope_token(scope) {
// Schedule a job to delete the device.
repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &device))
.await?;
}
}

repo.oauth2_session().finish(&clock, oauth2_session).await?;
}

Expand Down Expand Up @@ -441,6 +427,10 @@ impl Options {
.await?;
}

// Schedule a job to sync the devices of the user with the homeserver
warn!("Scheduling job to sync devices for the user");
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;

let txn = repo.into_inner();
if dry_run {
info!("Dry run, not saving");
Expand Down
44 changes: 35 additions & 9 deletions crates/handlers/src/compat/login.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ use mas_storage::{
CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository,
CompatSsoLoginRepository,
},
job::{JobRepositoryExt, ProvisionDeviceJob},
user::{UserPasswordRepository, UserRepository},
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
Expand Down Expand Up @@ -168,6 +167,9 @@ pub enum RouteError {

#[error("invalid login token")]
InvalidLoginToken,

#[error("failed to provision device")]
ProvisionDeviceFailed(#[source] anyhow::Error),
}

impl_from_error_for_route!(mas_storage::RepositoryError);
Expand All @@ -176,11 +178,13 @@ impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let event_id = sentry::capture_error(&self);
let response = match self {
Self::Internal(_) | Self::SessionNotFound => MatrixError {
errcode: "M_UNKNOWN",
error: "Internal server error",
status: StatusCode::INTERNAL_SERVER_ERROR,
},
Self::Internal(_) | Self::SessionNotFound | Self::ProvisionDeviceFailed(_) => {
MatrixError {
errcode: "M_UNKNOWN",
error: "Internal server error",
status: StatusCode::INTERNAL_SERVER_ERROR,
}
}
Self::Unsupported => MatrixError {
errcode: "M_UNRECOGNIZED",
error: "Invalid login type",
Expand Down Expand Up @@ -235,6 +239,7 @@ pub(crate) async fn post(
&clock,
&password_manager,
&mut repo,
&homeserver,
user,
password,
)
Expand Down Expand Up @@ -368,6 +373,7 @@ async fn user_password_login(
clock: &impl Clock,
password_manager: &PasswordManager,
repo: &mut BoxRepository,
homeserver: &BoxHomeserverConnection,
username: String,
password: String,
) -> Result<(CompatSession, User), RouteError> {
Expand Down Expand Up @@ -413,11 +419,16 @@ async fn user_password_login(
.await?;
}

// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&user).await?;

// Now that the user credentials have been verified, start a new compat session
let device = Device::generate(&mut rng);
repo.job()
.schedule_job(ProvisionDeviceJob::new(&user, &device))
.await?;
let mxid = homeserver.mxid(&user.username);
homeserver
.create_device(&mxid, device.as_str())
.await
.map_err(RouteError::ProvisionDeviceFailed)?;

let session = repo
.compat_session()
Expand All @@ -430,6 +441,7 @@ async fn user_password_login(
#[cfg(test)]
mod tests {
use hyper::Request;
use mas_matrix::{HomeserverConnection, ProvisionRequest};
use rand::distributions::{Alphanumeric, DistString};
use sqlx::PgPool;

Expand Down Expand Up @@ -532,6 +544,13 @@ mod tests {
.await
.unwrap();

let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

let (version, hashed_password) = state
.password_manager
.hash(
Expand Down Expand Up @@ -664,6 +683,13 @@ mod tests {
.unwrap();
repo.save().await.unwrap();

let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

// First try with an invalid token
let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({
"type": "m.login.token",
Expand Down
14 changes: 10 additions & 4 deletions crates/handlers/src/compat/login_sso_complete.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ use mas_axum_utils::{
FancyError, SessionInfoExt,
};
use mas_data_model::Device;
use mas_matrix::BoxHomeserverConnection;
use mas_router::{CompatLoginSsoAction, PostAuthAction, UrlBuilder};
use mas_storage::{
compat::{CompatSessionRepository, CompatSsoLoginRepository},
job::{JobRepositoryExt, ProvisionDeviceJob},
BoxClock, BoxRepository, BoxRng, Clock, RepositoryAccess,
};
use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
Expand Down Expand Up @@ -136,6 +136,7 @@ pub async fn post(
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(url_builder): State<UrlBuilder>,
State(homeserver): State<BoxHomeserverConnection>,
cookie_jar: CookieJar,
Path(id): Path<Ulid>,
Query(params): Query<Params>,
Expand Down Expand Up @@ -201,10 +202,15 @@ pub async fn post(
redirect_uri
};

// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&session.user).await?;

let device = Device::generate(&mut rng);
repo.job()
.schedule_job(ProvisionDeviceJob::new(&session.user, &device))
.await?;
let mxid = homeserver.mxid(&session.user.username);
homeserver
.create_device(&mxid, device.as_str())
.await
.context("Failed to provision device")?;

let compat_session = repo
.compat_session()
Expand Down
7 changes: 3 additions & 4 deletions crates/handlers/src/compat/logout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use mas_axum_utils::sentry::SentryEventID;
use mas_data_model::TokenType;
use mas_storage::{
compat::{CompatAccessTokenRepository, CompatSessionRepository},
job::{DeleteDeviceJob, JobRepositoryExt},
job::{JobRepositoryExt, SyncDevicesJob},
BoxClock, BoxRepository, Clock, RepositoryAccess,
};
use thiserror::Error;
Expand Down Expand Up @@ -111,9 +111,8 @@ pub(crate) async fn post(
// XXX: this is probably not the right error
.ok_or(RouteError::InvalidAuthorization)?;

repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &session.device))
.await?;
// Schedule a job to sync the devices of the user with the homeserver
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;

repo.compat_session().finish(&clock, session).await?;

Expand Down
8 changes: 3 additions & 5 deletions crates/handlers/src/graphql/mutations/compat_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use anyhow::Context as _;
use async_graphql::{Context, Enum, InputObject, Object, ID};
use mas_storage::{
compat::CompatSessionRepository,
job::{DeleteDeviceJob, JobRepositoryExt},
job::{JobRepositoryExt, SyncDevicesJob},
RepositoryAccess,
};

Expand Down Expand Up @@ -101,10 +101,8 @@ impl CompatSessionMutations {
.await?
.context("Could not load user")?;

// Schedule a job to delete the device.
repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &session.device))
.await?;
// Schedule a job to sync the devices of the user with the homeserver
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;

let session = repo.compat_session().finish(&clock, session).await?;

Expand Down
30 changes: 12 additions & 18 deletions crates/handlers/src/graphql/mutations/oauth2_session.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use async_graphql::{Context, Description, Enum, InputObject, Object, ID};
use chrono::Duration;
use mas_data_model::{Device, TokenType};
use mas_storage::{
job::{DeleteDeviceJob, JobRepositoryExt, ProvisionDeviceJob},
job::{JobRepositoryExt, SyncDevicesJob},
oauth2::{
OAuth2AccessTokenRepository, OAuth2ClientRepository, OAuth2RefreshTokenRepository,
OAuth2SessionRepository,
Expand Down Expand Up @@ -129,6 +129,7 @@ impl OAuth2SessionMutations {
input: CreateOAuth2SessionInput,
) -> Result<CreateOAuth2SessionPayload, async_graphql::Error> {
let state = ctx.state();
let homeserver = state.homeserver_connection();
let user_id = NodeType::User.extract_ulid(&input.user_id)?;
let scope: Scope = input.scope.parse().context("Invalid scope")?;
let permanent = input.permanent.unwrap_or(false);
Expand Down Expand Up @@ -167,12 +168,17 @@ impl OAuth2SessionMutations {
.add(&mut rng, &clock, &client, Some(&user), None, scope)
.await?;

// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&user).await?;

// Look for devices to provision
let mxid = homeserver.mxid(&user.username);
for scope in &*session.scope {
if let Some(device) = Device::from_scope_token(scope) {
repo.job()
.schedule_job(ProvisionDeviceJob::new(&user, &device))
.await?;
homeserver
.create_device(&mxid, device.as_str())
.await
.context("Failed to provision device")?;
}
}

Expand Down Expand Up @@ -236,20 +242,8 @@ impl OAuth2SessionMutations {
.await?
.context("Could not load user")?;

// Scan the scopes of the session to find if there is any device that should be
// deleted from the Matrix server.
// TODO: this should be moved in a higher level "end oauth session" method.
// XXX: this might not be the right semantic, but it's the best we
// can do for now, since we're not explicitly storing devices for OAuth2
// sessions.
for scope in &*session.scope {
if let Some(device) = Device::from_scope_token(scope) {
// Schedule a job to delete the device.
repo.job()
.schedule_job(DeleteDeviceJob::new(&user, &device))
.await?;
}
}
// Schedule a job to sync the devices of the user with the homeserver
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
}

let session = repo.oauth2_session().finish(&clock, session).await?;
Expand Down
13 changes: 12 additions & 1 deletion crates/handlers/src/graphql/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use axum::http::Request;
use hyper::StatusCode;
use mas_data_model::{AccessToken, Client, TokenType, User};
use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_router::SimpleRoute;
use mas_storage::{
oauth2::{OAuth2AccessTokenRepository, OAuth2ClientRepository},
Expand Down Expand Up @@ -517,7 +518,7 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty(), "{:?}", response.errors);
let user_id = &response.data["addUser"]["user"]["id"];
let user_id = response.data["addUser"]["user"]["id"].as_str().unwrap();

assert_eq!(
response.data,
Expand All @@ -531,6 +532,16 @@ async fn test_oauth2_client_credentials(pool: PgPool) {
})
);

// XXX: we don't run the task worker here, so even though the addUser mutation
// should have scheduled a job to provision the user, it won't run in the test,
// so we need to do it manually
let mxid = state.homeserver_connection.mxid("alice");
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, user_id))
.await
.unwrap();

// We should now be able to create an arbitrary access token for the user
let request = Request::post("/graphql")
.bearer(&access_token)
Expand Down
1 change: 1 addition & 0 deletions crates/handlers/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ where
Encrypter: FromRef<S>,
HttpClientFactory: FromRef<S>,
SiteConfig: FromRef<S>,
BoxHomeserverConnection: FromRef<S>,
BoxClock: FromRequestParts<S>,
BoxRng: FromRequestParts<S>,
Policy: FromRequestParts<S>,
Expand Down
15 changes: 15 additions & 0 deletions crates/handlers/src/oauth2/introspection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,7 @@ mod tests {
use hyper::{Request, StatusCode};
use mas_data_model::{AccessToken, RefreshToken};
use mas_iana::oauth::OAuthTokenTypeHint;
use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
use mas_storage::Clock;
use oauth2_types::{
Expand Down Expand Up @@ -518,6 +519,13 @@ mod tests {
.await
.unwrap();

let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

let client = repo
.oauth2_client()
.find_by_client_id(&client_id)
Expand Down Expand Up @@ -703,6 +711,13 @@ mod tests {
.await
.unwrap();

let mxid = state.homeserver_connection.mxid(&user.username);
state
.homeserver_connection
.provision_user(&ProvisionRequest::new(mxid, &user.sub))
.await
.unwrap();

let (version, hashed_password) = state
.password_manager
.hash(&mut state.rng(), Zeroizing::new(b"password".to_vec()))
Expand Down
Loading

0 comments on commit 3f30498

Please sign in to comment.