diff --git a/src/api/admin.rs b/src/api/admin.rs index b3dc588c22..603af4fb9b 100644 --- a/src/api/admin.rs +++ b/src/api/admin.rs @@ -12,6 +12,7 @@ use rocket::{ Catcher, Route, }; +use crate::auth::HostInfo; use crate::{ api::{ core::{log_event, two_factor}, @@ -97,10 +98,6 @@ const BASE_TEMPLATE: &str = "admin/base"; const ACTING_ADMIN_USER: &str = "vaultwarden-admin-00000-000000000000"; -fn admin_path() -> String { - format!("{}{}", CONFIG.domain_path(), ADMIN_PATH) -} - #[derive(Debug)] struct IpHeader(Option); @@ -123,8 +120,12 @@ impl<'r> FromRequest<'r> for IpHeader { } } -fn admin_url() -> String { - format!("{}{}", CONFIG.domain_origin(), admin_path()) +fn admin_path() -> String { + format!("{}{}", CONFIG.domain_path(), ADMIN_PATH) +} + +fn admin_url(origin: &str) -> String { + format!("{}{}", origin, admin_path()) } #[derive(Responder)] @@ -668,7 +669,12 @@ async fn get_ntp_time(has_http_access: bool) -> String { } #[get("/diagnostics")] -async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) -> ApiResult> { +async fn diagnostics( + _token: AdminToken, + ip_header: IpHeader, + host_info: HostInfo, + mut conn: DbConn, +) -> ApiResult> { use chrono::prelude::*; use std::net::ToSocketAddrs; @@ -721,7 +727,7 @@ async fn diagnostics(_token: AdminToken, ip_header: IpHeader, mut conn: DbConn) "uses_proxy": uses_proxy, "db_type": *DB_TYPE, "db_version": get_sql_server_version(&mut conn).await, - "admin_url": format!("{}/diagnostics", admin_url()), + "admin_url": format!("{}/diagnostics", admin_url(&host_info.origin)), "overrides": &CONFIG.get_overrides().join(", "), "host_arch": std::env::consts::ARCH, "host_os": std::env::consts::OS, diff --git a/src/api/core/accounts.rs b/src/api/core/accounts.rs index 5b947b85ac..bb41624bbc 100644 --- a/src/api/core/accounts.rs +++ b/src/api/core/accounts.rs @@ -9,7 +9,7 @@ use crate::{ register_push_device, unregister_push_device, AnonymousNotify, EmptyResult, JsonResult, JsonUpcase, Notify, PasswordOrOtpData, UpdateType, }, - auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers}, + auth::{decode_delete, decode_invite, decode_verify_email, ClientHeaders, Headers, HostInfo}, crypto, db::{models::*, DbConn}, mail, @@ -1118,6 +1118,7 @@ struct AuthRequestRequest { async fn post_auth_request( data: Json, headers: ClientHeaders, + host_info: HostInfo, mut conn: DbConn, nt: Notify<'_>, ) -> JsonResult { @@ -1152,13 +1153,13 @@ async fn post_auth_request( "creationDate": auth_request.creation_date.and_utc(), "responseDate": null, "requestApproved": false, - "origin": CONFIG.domain_origin(), + "origin": host_info.origin, "object": "auth-request" }))) } #[get("/auth-requests/")] -async fn get_auth_request(uuid: &str, mut conn: DbConn) -> JsonResult { +async fn get_auth_request(uuid: &str, host_info: HostInfo, mut conn: DbConn) -> JsonResult { let auth_request = match AuthRequest::find_by_uuid(uuid, &mut conn).await { Some(auth_request) => auth_request, None => { @@ -1179,7 +1180,7 @@ async fn get_auth_request(uuid: &str, mut conn: DbConn) -> JsonResult { "creationDate": auth_request.creation_date.and_utc(), "responseDate": response_date_utc, "requestApproved": auth_request.approved, - "origin": CONFIG.domain_origin(), + "origin": host_info.origin, "object":"auth-request" } ))) @@ -1198,6 +1199,7 @@ struct AuthResponseRequest { async fn put_auth_request( uuid: &str, data: Json, + host_info: HostInfo, mut conn: DbConn, ant: AnonymousNotify<'_>, nt: Notify<'_>, @@ -1234,14 +1236,14 @@ async fn put_auth_request( "creationDate": auth_request.creation_date.and_utc(), "responseDate": response_date_utc, "requestApproved": auth_request.approved, - "origin": CONFIG.domain_origin(), + "origin": host_info.origin, "object":"auth-request" } ))) } #[get("/auth-requests//response?")] -async fn get_auth_request_response(uuid: &str, code: &str, mut conn: DbConn) -> JsonResult { +async fn get_auth_request_response(uuid: &str, code: &str, host_info: HostInfo, mut conn: DbConn) -> JsonResult { let auth_request = match AuthRequest::find_by_uuid(uuid, &mut conn).await { Some(auth_request) => auth_request, None => { @@ -1266,14 +1268,14 @@ async fn get_auth_request_response(uuid: &str, code: &str, mut conn: DbConn) -> "creationDate": auth_request.creation_date.and_utc(), "responseDate": response_date_utc, "requestApproved": auth_request.approved, - "origin": CONFIG.domain_origin(), + "origin": host_info.origin, "object":"auth-request" } ))) } #[get("/auth-requests")] -async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult { +async fn get_auth_requests(headers: Headers, host_info: HostInfo, mut conn: DbConn) -> JsonResult { let auth_requests = AuthRequest::find_by_user(&headers.user.uuid, &mut conn).await; Ok(Json(json!({ @@ -1293,7 +1295,7 @@ async fn get_auth_requests(headers: Headers, mut conn: DbConn) -> JsonResult { "creationDate": request.creation_date.and_utc(), "responseDate": response_date_utc, "requestApproved": request.approved, - "origin": CONFIG.domain_origin(), + "origin": host_info.origin, "object":"auth-request" }) }).collect::>(), diff --git a/src/api/core/ciphers.rs b/src/api/core/ciphers.rs index 18d1b9980c..7cc9640076 100644 --- a/src/api/core/ciphers.rs +++ b/src/api/core/ciphers.rs @@ -114,7 +114,7 @@ async fn sync(data: SyncData, headers: Headers, mut conn: DbConn) -> Json let mut ciphers_json = Vec::with_capacity(ciphers.len()); for c in ciphers { ciphers_json.push( - c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) + c.to_json(&headers.base_url, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) .await, ); } @@ -161,7 +161,7 @@ async fn get_ciphers(headers: Headers, mut conn: DbConn) -> Json { let mut ciphers_json = Vec::with_capacity(ciphers.len()); for c in ciphers { ciphers_json.push( - c.to_json(&headers.host, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) + c.to_json(&headers.base_url, &headers.user.uuid, Some(&cipher_sync_data), CipherSyncType::User, &mut conn) .await, ); } @@ -184,7 +184,7 @@ async fn get_cipher(uuid: &str, headers: Headers, mut conn: DbConn) -> JsonResul err!("Cipher is not owned by user") } - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) } #[get("/ciphers//admin")] @@ -324,7 +324,7 @@ async fn post_ciphers(data: JsonUpcase, headers: Headers, mut conn: let mut cipher = Cipher::new(data.Type, data.Name.clone()); update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherCreate).await?; - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) } /// Enforces the personal ownership policy on user-owned ciphers, if applicable. @@ -658,7 +658,7 @@ async fn put_cipher( update_cipher_from_data(&mut cipher, data, &headers, None, &mut conn, &nt, UpdateType::SyncCipherUpdate).await?; - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) } #[post("/ciphers//partial", data = "")] @@ -702,7 +702,7 @@ async fn put_cipher_partial( // Update favorite cipher.set_favorite(Some(data.Favorite), &headers.user.uuid, &mut conn).await?; - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) } #[derive(Deserialize)] @@ -933,7 +933,7 @@ async fn share_cipher_by_uuid( update_cipher_from_data(&mut cipher, data.Cipher, headers, Some(shared_to_collections), conn, nt, ut).await?; - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, conn).await)) } /// v2 API for downloading an attachment. This just redirects the client to @@ -954,7 +954,7 @@ async fn get_attachment(uuid: &str, attachment_id: &str, headers: Headers, mut c } match Attachment::find_by_id(attachment_id, &mut conn).await { - Some(attachment) if uuid == attachment.cipher_uuid => Ok(Json(attachment.to_json(&headers.host))), + Some(attachment) if uuid == attachment.cipher_uuid => Ok(Json(attachment.to_json(&headers.base_url))), Some(_) => err!("Attachment doesn't belong to cipher"), None => err!("Attachment doesn't exist"), } @@ -1016,7 +1016,7 @@ async fn post_attachment_v2( "AttachmentId": attachment_id, "Url": url, "FileUploadType": FileUploadType::Direct as i32, - response_key: cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await, + response_key: cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await, }))) } @@ -1243,7 +1243,7 @@ async fn post_attachment( let (cipher, mut conn) = save_attachment(attachment, uuid, data, &headers, conn, nt).await?; - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, &mut conn).await)) } #[post("/ciphers//attachment-admin", format = "multipart/form-data", data = "")] @@ -1677,7 +1677,7 @@ async fn _restore_cipher_by_uuid(uuid: &str, headers: &Headers, conn: &mut DbCon .await; } - Ok(Json(cipher.to_json(&headers.host, &headers.user.uuid, None, CipherSyncType::User, conn).await)) + Ok(Json(cipher.to_json(&headers.base_url, &headers.user.uuid, None, CipherSyncType::User, conn).await)) } async fn _restore_multiple_ciphers( diff --git a/src/api/core/emergency_access.rs b/src/api/core/emergency_access.rs index 5d522c61c2..abcff8ae5c 100644 --- a/src/api/core/emergency_access.rs +++ b/src/api/core/emergency_access.rs @@ -574,7 +574,7 @@ async fn view_emergency_access(emer_id: &str, headers: Headers, mut conn: DbConn for c in ciphers { ciphers_json.push( c.to_json( - &headers.host, + &headers.base_url, &emergency_access.grantor_uuid, Some(&cipher_sync_data), CipherSyncType::User, diff --git a/src/api/core/mod.rs b/src/api/core/mod.rs index 1d31b27c85..537d9e4ff6 100644 --- a/src/api/core/mod.rs +++ b/src/api/core/mod.rs @@ -190,7 +190,8 @@ fn version() -> Json<&'static str> { #[get("/config")] fn config() -> Json { - let domain = crate::CONFIG.domain(); + // TODO: maybe this should be extracted from the current request params + let domain = crate::CONFIG.main_domain(); let mut feature_states = parse_experimental_client_feature_flags(&crate::CONFIG.experimental_client_feature_flags()); // Force the new key rotation feature diff --git a/src/api/core/organizations.rs b/src/api/core/organizations.rs index c6556f774d..664f0213b5 100644 --- a/src/api/core/organizations.rs +++ b/src/api/core/organizations.rs @@ -749,20 +749,20 @@ struct OrgIdData { #[get("/ciphers/organization-details?")] async fn get_org_details(data: OrgIdData, headers: Headers, mut conn: DbConn) -> Json { Json(json!({ - "Data": _get_org_details(&data.organization_id, &headers.host, &headers.user.uuid, &mut conn).await, + "Data": _get_org_details(&data.organization_id, &headers.base_url, &headers.user.uuid, &mut conn).await, "Object": "list", "ContinuationToken": null, })) } -async fn _get_org_details(org_id: &str, host: &str, user_uuid: &str, conn: &mut DbConn) -> Value { +async fn _get_org_details(org_id: &str, base_url: &str, user_uuid: &str, conn: &mut DbConn) -> Value { let ciphers = Cipher::find_by_org(org_id, conn).await; let cipher_sync_data = CipherSyncData::new(user_uuid, CipherSyncType::Organization, conn).await; let mut ciphers_json = Vec::with_capacity(ciphers.len()); for c in ciphers { ciphers_json - .push(c.to_json(host, user_uuid, Some(&cipher_sync_data), CipherSyncType::Organization, conn).await); + .push(c.to_json(base_url, user_uuid, Some(&cipher_sync_data), CipherSyncType::Organization, conn).await); } json!(ciphers_json) } @@ -2906,7 +2906,7 @@ async fn get_org_export(org_id: &str, headers: AdminHeaders, mut conn: DbConn) - "continuationToken": null, }, "ciphers": { - "data": convert_json_key_lcase_first(_get_org_details(org_id, &headers.host, &headers.user.uuid, &mut conn).await), + "data": convert_json_key_lcase_first(_get_org_details(org_id, &headers.base_url, &headers.user.uuid, &mut conn).await), "object": "list", "continuationToken": null, } @@ -2915,7 +2915,7 @@ async fn get_org_export(org_id: &str, headers: AdminHeaders, mut conn: DbConn) - // v2023.1.0 and newer response Json(json!({ "collections": convert_json_key_lcase_first(_get_org_collections(org_id, &mut conn).await), - "ciphers": convert_json_key_lcase_first(_get_org_details(org_id, &headers.host, &headers.user.uuid, &mut conn).await), + "ciphers": convert_json_key_lcase_first(_get_org_details(org_id, &headers.base_url, &headers.user.uuid, &mut conn).await), })) } } diff --git a/src/api/core/public.rs b/src/api/core/public.rs index 085ac55289..2726ab2fcd 100644 --- a/src/api/core/public.rs +++ b/src/api/core/public.rs @@ -217,11 +217,13 @@ impl<'r> FromRequest<'r> for PublicToken { err_handler!("Token expired"); } // Check if claims.iss is host|claims.scope[0] - let host = match auth::Host::from_request(request).await { - Outcome::Success(host) => host, + let host_info = match auth::HostInfo::from_request(request).await { + Outcome::Success(host_info) => host_info, _ => err_handler!("Error getting Host"), }; - let complete_host = format!("{}|{}", host.host, claims.scope[0]); + // TODO check if this is fine + // using origin, because that's what they're generated with in auth.rs + let complete_host = format!("{}|{}", host_info.origin, claims.scope[0]); if complete_host != claims.iss { err_handler!("Token not issued by this server"); } diff --git a/src/api/core/sends.rs b/src/api/core/sends.rs index 338510c6ed..98e141bc70 100644 --- a/src/api/core/sends.rs +++ b/src/api/core/sends.rs @@ -10,7 +10,7 @@ use serde_json::Value; use crate::{ api::{ApiResult, EmptyResult, JsonResult, JsonUpcase, Notify, UpdateType}, - auth::{ClientIp, Headers, Host}, + auth::{ClientIp, Headers, HostInfo}, db::{models::*, DbConn, DbPool}, util::{NumberOrString, SafeString}, CONFIG, @@ -465,7 +465,7 @@ async fn post_access_file( send_id: &str, file_id: &str, data: JsonUpcase, - host: Host, + host_info: HostInfo, mut conn: DbConn, nt: Notify<'_>, ) -> JsonResult { @@ -520,7 +520,7 @@ async fn post_access_file( Ok(Json(json!({ "Object": "send-fileDownload", "Id": file_id, - "Url": format!("{}/api/sends/{}/{}?t={}", &host.host, send_id, file_id, token) + "Url": format!("{}/api/sends/{}/{}?t={}", &host_info.base_url, send_id, file_id, token) }))) } diff --git a/src/api/core/two_factor/webauthn.rs b/src/api/core/two_factor/webauthn.rs index 14ba851413..b99d9a90c3 100644 --- a/src/api/core/two_factor/webauthn.rs +++ b/src/api/core/two_factor/webauthn.rs @@ -9,7 +9,7 @@ use crate::{ core::{log_user_event, two_factor::_generate_recover_code}, EmptyResult, JsonResult, JsonUpcase, PasswordOrOtpData, }, - auth::Headers, + auth::{Headers, HostInfo}, db::{ models::{EventType, TwoFactor, TwoFactorType}, DbConn, @@ -52,13 +52,11 @@ struct WebauthnConfig { } impl WebauthnConfig { - fn load() -> Webauthn { - let domain = CONFIG.domain(); - let domain_origin = CONFIG.domain_origin(); + fn load(domain: &str, domain_origin: &str) -> Webauthn { Webauthn::new(Self { - rpid: Url::parse(&domain).map(|u| u.domain().map(str::to_owned)).ok().flatten().unwrap_or_default(), - url: domain, - origin: Url::parse(&domain_origin).unwrap(), + rpid: Url::parse(domain).map(|u| u.domain().map(str::to_owned)).ok().flatten().unwrap_or_default(), + url: domain.to_string(), + origin: Url::parse(domain_origin).unwrap(), }) } } @@ -128,6 +126,7 @@ async fn get_webauthn(data: JsonUpcase, headers: Headers, mut async fn generate_webauthn_challenge( data: JsonUpcase, headers: Headers, + host_info: HostInfo, mut conn: DbConn, ) -> JsonResult { let data: PasswordOrOtpData = data.into_inner().data; @@ -142,14 +141,15 @@ async fn generate_webauthn_challenge( .map(|r| r.credential.cred_id) // We return the credentialIds to the clients to avoid double registering .collect(); - let (challenge, state) = WebauthnConfig::load().generate_challenge_register_options( - user.uuid.as_bytes().to_vec(), - user.email, - user.name, - Some(registrations), - None, - None, - )?; + let (challenge, state) = WebauthnConfig::load(&host_info.base_url, &host_info.origin) + .generate_challenge_register_options( + user.uuid.as_bytes().to_vec(), + user.email, + user.name, + Some(registrations), + None, + None, + )?; let type_ = TwoFactorType::WebauthnRegisterChallenge; TwoFactor::new(user.uuid, type_, serde_json::to_string(&state)?).save(&mut conn).await?; @@ -250,7 +250,12 @@ impl From for PublicKeyCredential { } #[post("/two-factor/webauthn", data = "")] -async fn activate_webauthn(data: JsonUpcase, headers: Headers, mut conn: DbConn) -> JsonResult { +async fn activate_webauthn( + data: JsonUpcase, + headers: Headers, + host_info: HostInfo, + mut conn: DbConn, +) -> JsonResult { let data: EnableWebauthnData = data.into_inner().data; let mut user = headers.user; @@ -273,8 +278,11 @@ async fn activate_webauthn(data: JsonUpcase, headers: Header }; // Verify the credentials with the saved state - let (credential, _data) = - WebauthnConfig::load().register_credential(&data.DeviceResponse.into(), &state, |_| Ok(false))?; + let (credential, _data) = WebauthnConfig::load(&host_info.base_url, &host_info.origin).register_credential( + &data.DeviceResponse.into(), + &state, + |_| Ok(false), + )?; let mut registrations: Vec<_> = get_webauthn_registrations(&user.uuid, &mut conn).await?.1; // TODO: Check for repeated ID's @@ -303,8 +311,13 @@ async fn activate_webauthn(data: JsonUpcase, headers: Header } #[put("/two-factor/webauthn", data = "")] -async fn activate_webauthn_put(data: JsonUpcase, headers: Headers, conn: DbConn) -> JsonResult { - activate_webauthn(data, headers, conn).await +async fn activate_webauthn_put( + data: JsonUpcase, + headers: Headers, + host_info: HostInfo, + conn: DbConn, +) -> JsonResult { + activate_webauthn(data, headers, host_info, conn).await } #[derive(Deserialize, Debug)] @@ -375,7 +388,7 @@ pub async fn get_webauthn_registrations( } } -pub async fn generate_webauthn_login(user_uuid: &str, conn: &mut DbConn) -> JsonResult { +pub async fn generate_webauthn_login(user_uuid: &str, base_url: &str, origin: &str, conn: &mut DbConn) -> JsonResult { // Load saved credentials let creds: Vec = get_webauthn_registrations(user_uuid, conn).await?.1.into_iter().map(|r| r.credential).collect(); @@ -385,8 +398,9 @@ pub async fn generate_webauthn_login(user_uuid: &str, conn: &mut DbConn) -> Json } // Generate a challenge based on the credentials - let ext = RequestAuthenticationExtensions::builder().appid(format!("{}/app-id.json", &CONFIG.domain())).build(); - let (response, state) = WebauthnConfig::load().generate_challenge_authenticate_options(creds, Some(ext))?; + let ext = RequestAuthenticationExtensions::builder().appid(format!("{}/app-id.json", base_url)).build(); + let (response, state) = + WebauthnConfig::load(base_url, origin).generate_challenge_authenticate_options(creds, Some(ext))?; // Save the challenge state for later validation TwoFactor::new(user_uuid.into(), TwoFactorType::WebauthnLoginChallenge, serde_json::to_string(&state)?) @@ -397,7 +411,13 @@ pub async fn generate_webauthn_login(user_uuid: &str, conn: &mut DbConn) -> Json Ok(Json(serde_json::to_value(response.public_key)?)) } -pub async fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &mut DbConn) -> EmptyResult { +pub async fn validate_webauthn_login( + user_uuid: &str, + response: &str, + base_url: &str, + origin: &str, + conn: &mut DbConn, +) -> EmptyResult { let type_ = TwoFactorType::WebauthnLoginChallenge as i32; let state = match TwoFactor::find_by_user_and_type(user_uuid, type_, conn).await { Some(tf) => { @@ -420,7 +440,7 @@ pub async fn validate_webauthn_login(user_uuid: &str, response: &str, conn: &mut // If the credential we received is migrated from U2F, enable the U2F compatibility //let use_u2f = registrations.iter().any(|r| r.migrated && r.credential.cred_id == rsp.raw_id.0); - let (cred_id, auth_data) = WebauthnConfig::load().authenticate_credential(&rsp, &state)?; + let (cred_id, auth_data) = WebauthnConfig::load(base_url, origin).authenticate_credential(&rsp, &state)?; for reg in &mut registrations { if ®.credential.cred_id == cred_id { diff --git a/src/api/identity.rs b/src/api/identity.rs index ad51d664f0..faa7503b82 100644 --- a/src/api/identity.rs +++ b/src/api/identity.rs @@ -17,7 +17,7 @@ use crate::{ push::register_push_device, ApiResult, EmptyResult, JsonResult, JsonUpcase, }, - auth::{generate_organization_api_key_login_claims, ClientHeaders, ClientIp}, + auth::{generate_organization_api_key_login_claims, ClientHeaders, ClientIp, HostInfo}, db::{models::*, DbConn}, error::MapResult, mail, util, CONFIG, @@ -28,7 +28,12 @@ pub fn routes() -> Vec { } #[post("/connect/token", data = "")] -async fn login(data: Form, client_header: ClientHeaders, mut conn: DbConn) -> JsonResult { +async fn login( + data: Form, + client_header: ClientHeaders, + host_info: HostInfo, + mut conn: DbConn, +) -> JsonResult { let data: ConnectData = data.into_inner(); let mut user_uuid: Option = None; @@ -48,7 +53,8 @@ async fn login(data: Form, client_header: ClientHeaders, mut conn: _check_is_some(&data.device_name, "device_name cannot be blank")?; _check_is_some(&data.device_type, "device_type cannot be blank")?; - _password_login(data, &mut user_uuid, &mut conn, &client_header.ip).await + _password_login(data, &mut user_uuid, &mut conn, &client_header.ip, &host_info.base_url, &host_info.origin) + .await } "client_credentials" => { _check_is_some(&data.client_id, "client_id cannot be blank")?; @@ -140,6 +146,8 @@ async fn _password_login( user_uuid: &mut Option, conn: &mut DbConn, ip: &ClientIp, + base_url: &str, + origin: &str, ) -> JsonResult { // Validate scope let scope = data.scope.as_ref().unwrap(); @@ -250,7 +258,7 @@ async fn _password_login( let (mut device, new_device) = get_device(&data, conn, &user).await; - let twofactor_token = twofactor_auth(&user, &data, &mut device, ip, conn).await?; + let twofactor_token = twofactor_auth(&user, &data, &mut device, ip, base_url, origin, conn).await?; if CONFIG.mail_enabled() && new_device { if let Err(e) = mail::send_new_device_logged_in(&user.email, &ip.ip.to_string(), &now, &device.name).await { @@ -485,6 +493,8 @@ async fn twofactor_auth( data: &ConnectData, device: &mut Device, ip: &ClientIp, + base_url: &str, + origin: &str, conn: &mut DbConn, ) -> ApiResult> { let twofactors = TwoFactor::find_by_user(&user.uuid, conn).await; @@ -502,7 +512,10 @@ async fn twofactor_auth( let twofactor_code = match data.two_factor_token { Some(ref code) => code, - None => err_json!(_json_err_twofactor(&twofactor_ids, &user.uuid, conn).await?, "2FA token not provided"), + None => err_json!( + _json_err_twofactor(&twofactor_ids, &user.uuid, base_url, origin, conn).await?, + "2FA token not provided" + ), }; let selected_twofactor = twofactors.into_iter().find(|tf| tf.atype == selected_id && tf.enabled); @@ -516,7 +529,9 @@ async fn twofactor_auth( Some(TwoFactorType::Authenticator) => { authenticator::validate_totp_code_str(&user.uuid, twofactor_code, &selected_data?, ip, conn).await? } - Some(TwoFactorType::Webauthn) => webauthn::validate_webauthn_login(&user.uuid, twofactor_code, conn).await?, + Some(TwoFactorType::Webauthn) => { + webauthn::validate_webauthn_login(&user.uuid, twofactor_code, base_url, origin, conn).await? + } Some(TwoFactorType::YubiKey) => yubikey::validate_yubikey_login(twofactor_code, &selected_data?).await?, Some(TwoFactorType::Duo) => { duo::validate_duo_login(data.username.as_ref().unwrap().trim(), twofactor_code, conn).await? @@ -532,7 +547,7 @@ async fn twofactor_auth( } _ => { err_json!( - _json_err_twofactor(&twofactor_ids, &user.uuid, conn).await?, + _json_err_twofactor(&twofactor_ids, &user.uuid, base_url, origin, conn).await?, "2FA Remember token not provided" ) } @@ -560,7 +575,13 @@ fn _selected_data(tf: Option) -> ApiResult { tf.map(|t| t.data).map_res("Two factor doesn't exist") } -async fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &mut DbConn) -> ApiResult { +async fn _json_err_twofactor( + providers: &[i32], + user_uuid: &str, + base_url: &str, + origin: &str, + conn: &mut DbConn, +) -> ApiResult { let mut result = json!({ "error" : "invalid_grant", "error_description" : "Two factor required.", @@ -575,7 +596,7 @@ async fn _json_err_twofactor(providers: &[i32], user_uuid: &str, conn: &mut DbCo Some(TwoFactorType::Authenticator) => { /* Nothing to do for TOTP */ } Some(TwoFactorType::Webauthn) if CONFIG.domain_set() => { - let request = webauthn::generate_webauthn_login(user_uuid, conn).await?; + let request = webauthn::generate_webauthn_login(user_uuid, base_url, origin, conn).await?; result["TwoFactorProviders2"][provider.to_string()] = request.0; } diff --git a/src/api/web.rs b/src/api/web.rs index 67248c835e..8dc1c64938 100644 --- a/src/api/web.rs +++ b/src/api/web.rs @@ -5,7 +5,7 @@ use serde_json::Value; use crate::{ api::{core::now, ApiResult, EmptyResult}, - auth::decode_file_download, + auth::{decode_file_download, HostInfo}, error::Error, util::{Cached, SafeString}, CONFIG, @@ -62,9 +62,12 @@ fn web_index_head() -> EmptyResult { } #[get("/app-id.json")] -fn app_id() -> Cached<(ContentType, Json)> { +fn app_id(host_info: HostInfo) -> Cached<(ContentType, Json)> { let content_type = ContentType::new("application", "fido.trusted-apps+json"); + // TODO Maybe return all available origins. + let origin = host_info.origin; + Cached::long( ( content_type, @@ -83,7 +86,7 @@ fn app_id() -> Cached<(ContentType, Json)> { // This leaves it unclear as to whether the path must be empty, // or whether it can be non-empty and will be ignored. To be on // the safe side, use a proper web origin (with empty path). - &CONFIG.domain_origin(), + &origin, "ios:bundle-id:com.8bit.bitwarden", "android:apk-key-hash:dUGFzUzf3lmHSLBDBIv+WaFyZMI" ] }] diff --git a/src/auth.rs b/src/auth.rs index f05eba6548..544fbc38a9 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -9,6 +9,7 @@ use openssl::rsa::Rsa; use serde::de::DeserializeOwned; use serde::ser::Serialize; +use crate::config::{extract_url_host, extract_url_origin}; use crate::{error::Error, CONFIG}; const JWT_ALGORITHM: Algorithm = Algorithm::RS256; @@ -16,16 +17,20 @@ const JWT_ALGORITHM: Algorithm = Algorithm::RS256; pub static DEFAULT_VALIDITY: Lazy = Lazy::new(|| TimeDelta::try_hours(2).unwrap()); static JWT_HEADER: Lazy
= Lazy::new(|| Header::new(JWT_ALGORITHM)); -pub static JWT_LOGIN_ISSUER: Lazy = Lazy::new(|| format!("{}|login", CONFIG.domain_origin())); -static JWT_INVITE_ISSUER: Lazy = Lazy::new(|| format!("{}|invite", CONFIG.domain_origin())); +fn jwt_origin() -> String { + extract_url_origin(&CONFIG.main_domain()) +} + +pub static JWT_LOGIN_ISSUER: Lazy = Lazy::new(|| format!("{}|login", jwt_origin())); +static JWT_INVITE_ISSUER: Lazy = Lazy::new(|| format!("{}|invite", jwt_origin())); static JWT_EMERGENCY_ACCESS_INVITE_ISSUER: Lazy = - Lazy::new(|| format!("{}|emergencyaccessinvite", CONFIG.domain_origin())); -static JWT_DELETE_ISSUER: Lazy = Lazy::new(|| format!("{}|delete", CONFIG.domain_origin())); -static JWT_VERIFYEMAIL_ISSUER: Lazy = Lazy::new(|| format!("{}|verifyemail", CONFIG.domain_origin())); -static JWT_ADMIN_ISSUER: Lazy = Lazy::new(|| format!("{}|admin", CONFIG.domain_origin())); -static JWT_SEND_ISSUER: Lazy = Lazy::new(|| format!("{}|send", CONFIG.domain_origin())); -static JWT_ORG_API_KEY_ISSUER: Lazy = Lazy::new(|| format!("{}|api.organization", CONFIG.domain_origin())); -static JWT_FILE_DOWNLOAD_ISSUER: Lazy = Lazy::new(|| format!("{}|file_download", CONFIG.domain_origin())); + Lazy::new(|| format!("{}|emergencyaccessinvite", jwt_origin())); +static JWT_DELETE_ISSUER: Lazy = Lazy::new(|| format!("{}|delete", jwt_origin())); +static JWT_VERIFYEMAIL_ISSUER: Lazy = Lazy::new(|| format!("{}|verifyemail", jwt_origin())); +static JWT_ADMIN_ISSUER: Lazy = Lazy::new(|| format!("{}|admin", jwt_origin())); +static JWT_SEND_ISSUER: Lazy = Lazy::new(|| format!("{}|send", jwt_origin())); +static JWT_ORG_API_KEY_ISSUER: Lazy = Lazy::new(|| format!("{}|api.organization", jwt_origin())); +static JWT_FILE_DOWNLOAD_ISSUER: Lazy = Lazy::new(|| format!("{}|file_download", jwt_origin())); static PRIVATE_RSA_KEY: OnceCell = OnceCell::new(); static PUBLIC_RSA_KEY: OnceCell = OnceCell::new(); @@ -355,29 +360,64 @@ use rocket::{ outcome::try_outcome, request::{FromRequest, Outcome, Request}, }; +use std::borrow::Cow; use crate::db::{ models::{Collection, Device, User, UserOrgStatus, UserOrgType, UserOrganization, UserStampException}, DbConn, }; -pub struct Host { - pub host: String, +#[derive(Clone, Debug)] +pub struct HostInfo { + pub base_url: String, + pub origin: String, +} + +fn get_host_info(host: &str) -> Option { + CONFIG.host_to_domain(host).and_then(|base_url| Some((base_url, CONFIG.host_to_origin(host)?))).map( + |(base_url, origin)| HostInfo { + base_url, + origin, + }, + ) +} + +fn get_main_host() -> String { + extract_url_host(&CONFIG.main_domain()) } #[rocket::async_trait] -impl<'r> FromRequest<'r> for Host { +impl<'r> FromRequest<'r> for HostInfo { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = request.headers(); // Get host - let host = if CONFIG.domain_set() { - CONFIG.domain() + let host_info = if CONFIG.domain_set() { + log::debug!("Using configured host info"); + let host: Cow<'_, str> = if let Some(host) = headers.get_one("X-Forwarded-Host") { + host.into() + } else if let Some(host) = headers.get_one("Host") { + host.into() + } else { + get_main_host().into() + }; + + let host_info = get_host_info(host.as_ref()).unwrap_or_else(|| { + log::debug!("Falling back to default domain, because {host} was not in domains."); + get_host_info(&get_main_host()).expect("Main domain doesn't have entry!") + }); + + host_info } else if let Some(referer) = headers.get_one("Referer") { - referer.to_string() + log::debug!("Using referer host info"); + HostInfo { + base_url: referer.to_string(), + origin: extract_url_origin(referer), + } } else { + log::debug!("Guessing host info with headers"); // Try to guess from the headers use std::env; @@ -395,17 +435,22 @@ impl<'r> FromRequest<'r> for Host { headers.get_one("Host").unwrap_or_default() }; - format!("{protocol}://{host}") + let base_url_origin = format!("{protocol}://{host}"); + + HostInfo { + base_url: base_url_origin.clone(), + origin: base_url_origin, + } }; - Outcome::Success(Host { - host, - }) + log::debug!("Using host_info: {:?}", host_info); + + Outcome::Success(host_info) } } pub struct ClientHeaders { - pub host: String, + pub base_url: String, pub device_type: i32, pub ip: ClientIp, } @@ -415,7 +460,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { type Error = &'static str; async fn from_request(request: &'r Request<'_>) -> Outcome { - let host = try_outcome!(Host::from_request(request).await).host; + let base_url = try_outcome!(HostInfo::from_request(request).await).base_url; let ip = match ClientIp::from_request(request).await { Outcome::Success(ip) => ip, _ => err_handler!("Error getting Client IP"), @@ -425,7 +470,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { request.headers().get_one("device-type").map(|d| d.parse().unwrap_or(14)).unwrap_or_else(|| 14); Outcome::Success(ClientHeaders { - host, + base_url, device_type, ip, }) @@ -433,7 +478,7 @@ impl<'r> FromRequest<'r> for ClientHeaders { } pub struct Headers { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub ip: ClientIp, @@ -446,7 +491,7 @@ impl<'r> FromRequest<'r> for Headers { async fn from_request(request: &'r Request<'_>) -> Outcome { let headers = request.headers(); - let host = try_outcome!(Host::from_request(request).await).host; + let base_url = try_outcome!(HostInfo::from_request(request).await).base_url; let ip = match ClientIp::from_request(request).await { Outcome::Success(ip) => ip, _ => err_handler!("Error getting Client IP"), @@ -517,7 +562,7 @@ impl<'r> FromRequest<'r> for Headers { } Outcome::Success(Headers { - host, + base_url, device, user, ip, @@ -526,7 +571,7 @@ impl<'r> FromRequest<'r> for Headers { } pub struct OrgHeaders { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -582,7 +627,7 @@ impl<'r> FromRequest<'r> for OrgHeaders { }; Outcome::Success(Self { - host: headers.host, + base_url: headers.base_url, device: headers.device, user, org_user_type: { @@ -604,7 +649,7 @@ impl<'r> FromRequest<'r> for OrgHeaders { } pub struct AdminHeaders { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -621,7 +666,7 @@ impl<'r> FromRequest<'r> for AdminHeaders { let client_version = request.headers().get_one("Bitwarden-Client-Version").map(String::from); if headers.org_user_type >= UserOrgType::Admin { Outcome::Success(Self { - host: headers.host, + base_url: headers.base_url, device: headers.device, user: headers.user, org_user_type: headers.org_user_type, @@ -637,7 +682,7 @@ impl<'r> FromRequest<'r> for AdminHeaders { impl From for Headers { fn from(h: AdminHeaders) -> Headers { Headers { - host: h.host, + base_url: h.base_url, device: h.device, user: h.user, ip: h.ip, @@ -668,7 +713,7 @@ fn get_col_id(request: &Request<'_>) -> Option { /// and have access to the specific collection provided via the /collections/collectionId. /// This does strict checking on the collection_id, ManagerHeadersLoose does not. pub struct ManagerHeaders { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub org_user_type: UserOrgType, @@ -697,7 +742,7 @@ impl<'r> FromRequest<'r> for ManagerHeaders { } Outcome::Success(Self { - host: headers.host, + base_url: headers.base_url, device: headers.device, user: headers.user, org_user_type: headers.org_user_type, @@ -712,7 +757,7 @@ impl<'r> FromRequest<'r> for ManagerHeaders { impl From for Headers { fn from(h: ManagerHeaders) -> Headers { Headers { - host: h.host, + base_url: h.base_url, device: h.device, user: h.user, ip: h.ip, @@ -723,7 +768,7 @@ impl From for Headers { /// The ManagerHeadersLoose is used when you at least need to be a Manager, /// but there is no collection_id sent with the request (either in the path or as form data). pub struct ManagerHeadersLoose { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub org_user: UserOrganization, @@ -739,7 +784,7 @@ impl<'r> FromRequest<'r> for ManagerHeadersLoose { let headers = try_outcome!(OrgHeaders::from_request(request).await); if headers.org_user_type >= UserOrgType::Manager { Outcome::Success(Self { - host: headers.host, + base_url: headers.base_url, device: headers.device, user: headers.user, org_user: headers.org_user, @@ -755,7 +800,7 @@ impl<'r> FromRequest<'r> for ManagerHeadersLoose { impl From for Headers { fn from(h: ManagerHeadersLoose) -> Headers { Headers { - host: h.host, + base_url: h.base_url, device: h.device, user: h.user, ip: h.ip, @@ -779,7 +824,7 @@ impl ManagerHeaders { } Ok(ManagerHeaders { - host: h.host, + base_url: h.base_url, device: h.device, user: h.user, org_user_type: h.org_user_type, @@ -789,7 +834,7 @@ impl ManagerHeaders { } pub struct OwnerHeaders { - pub host: String, + pub base_url: String, pub device: Device, pub user: User, pub ip: ClientIp, @@ -803,7 +848,7 @@ impl<'r> FromRequest<'r> for OwnerHeaders { let headers = try_outcome!(OrgHeaders::from_request(request).await); if headers.org_user_type == UserOrgType::Owner { Outcome::Success(Self { - host: headers.host, + base_url: headers.base_url, device: headers.device, user: headers.user, ip: headers.ip, diff --git a/src/config.rs b/src/config.rs index 489a229d6a..97eb4b9a62 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,12 +1,14 @@ -use std::env::consts::EXE_SUFFIX; use std::process::exit; +use std::sync::OnceLock; use std::sync::RwLock; +use std::{collections::HashMap, env::consts::EXE_SUFFIX}; use job_scheduler_ng::Schedule; use once_cell::sync::Lazy; use reqwest::Url; use crate::{ + auth::HostInfo, db::DbConnType, error::Error, util::{get_env, get_env_bool, parse_experimental_client_feature_flags}, @@ -47,6 +49,8 @@ macro_rules! make_config { _usr: ConfigBuilder, _overrides: Vec, + + domain_hostmap: OnceLock>, } #[derive(Clone, Default, Deserialize, Serialize)] @@ -141,7 +145,15 @@ macro_rules! make_config { )+)+ config.domain_set = _domain_set; - config.domain = config.domain.trim_end_matches('/').to_string(); + // Remove slash from every domain + config.domain = config.domain.split(',').map(|d| d.trim_end_matches('/')).fold(String::new(), |mut acc, d| { + acc.push_str(d); + acc.push(','); + acc + }); + + // Remove trailing comma + config.domain.pop(); config.signups_domains_whitelist = config.signups_domains_whitelist.trim().to_lowercase(); config.org_creation_users = config.org_creation_users.trim().to_lowercase(); @@ -414,15 +426,17 @@ make_config! { /// General settings settings { - /// Domain URL |> This needs to be set to the URL used to access the server, including 'http[s]://' - /// and port, if it's different than the default. Some server functions don't work correctly without this value + /// Comma seperated list of Domain URLs |> This needs to be set to the URL used to access the server, including + /// 'http[s]://' and port, if it's different than the default. Some server functions don't work correctly without this value domain: String, true, def, "http://localhost".to_string(); /// Domain Set |> Indicates if the domain is set by the admin. Otherwise the default will be used. domain_set: bool, false, def, false; - /// Domain origin |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin) - domain_origin: String, false, auto, |c| extract_url_origin(&c.domain); + /// Comma seperated list of domain origins |> Domain URL origin (in https://example.com:8443/path, https://example.com:8443 is the origin) + /// If specified manually, one entry needs to exist for every url in domain. + domain_origin: String, false, auto, |c| extract_origins(&c.domain); /// Domain path |> Domain URL path (in https://example.com:8443/path, /path is the path) - domain_path: String, false, auto, |c| extract_url_path(&c.domain); + /// MUST be the same for all domains. + domain_path: String, false, auto, |c| extract_url_path(c.domain.split(',').next().expect("Missing domain")); /// Enable web vault web_vault_enabled: bool, false, def, true; @@ -720,11 +734,17 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { } } - let dom = cfg.domain.to_lowercase(); - if !dom.starts_with("http://") && !dom.starts_with("https://") { - err!( - "DOMAIN variable needs to contain the protocol (http, https). Use 'http[s]://bw.example.com' instead of 'bw.example.com'" - ); + let domains = cfg.domain.split(',').map(|d| d.to_string().to_lowercase()); + for dom in domains { + if !dom.starts_with("http://") && !dom.starts_with("https://") { + err!( + "DOMAIN variable needs to contain the protocol (http, https). Use 'http[s]://bw.example.com' instead of 'bw.example.com'" + ); + } + } + + if cfg.domain.split(',').count() != cfg.domain_origin.split(',').count() { + err!("Each DOMAIN_ORIGIN entry corresponds to exactly one entry in DOMAIN."); } let whitelist = &cfg.signups_domains_whitelist; @@ -988,7 +1008,7 @@ fn validate_config(cfg: &ConfigItems) -> Result<(), Error> { } /// Extracts an RFC 6454 web origin from a URL. -fn extract_url_origin(url: &str) -> String { +pub fn extract_url_origin(url: &str) -> String { match Url::parse(url) { Ok(u) => u.origin().ascii_serialization(), Err(e) => { @@ -998,6 +1018,24 @@ fn extract_url_origin(url: &str) -> String { } } +// urls should be comma-seperated +fn extract_origins(urls: &str) -> String { + let mut origins = urls + .split(',') + .map(extract_url_origin) + // TODO add itertools as dependency maybe + .fold(String::new(), |mut acc, origin| { + acc.push_str(&origin); + acc.push(','); + acc + }); + + // Pop trailing comma + origins.pop(); + + origins +} + /// Extracts the path from a URL. /// All trailing '/' chars are trimmed, even if the path is a lone '/'. fn extract_url_path(url: &str) -> String { @@ -1010,10 +1048,34 @@ fn extract_url_path(url: &str) -> String { } } -fn generate_smtp_img_src(embed_images: bool, domain: &str) -> String { +/// Extracts host part from a URL. +pub fn extract_url_host(url: &str) -> String { + match Url::parse(url) { + Ok(u) => { + let Some(mut host) = u.host_str().map(|s| s.to_string()) else { + println!("Domain does not contain host!"); + return String::new(); + }; + + if let Some(port) = u.port().map(|p| p.to_string()) { + host.push(':'); + host.push_str(&port); + } + + host + } + Err(_) => { + // we already print it in the method above, no need to do it again + String::new() + } + } +} + +fn generate_smtp_img_src(embed_images: bool, domains: &str) -> String { if embed_images { "cid:".to_string() } else { + let domain = domains.split(',').next().expect("Domain missing"); format!("{domain}/vw_static/") } } @@ -1082,6 +1144,7 @@ impl Config { _env, _usr, _overrides, + domain_hostmap: OnceLock::new(), }), }) } @@ -1249,6 +1312,45 @@ impl Config { } } } + + fn get_domain_hostmap(&self, host: &str) -> Option { + // This is done to prevent deadlock, when read-locking an rwlock twice + let domains = self.domain(); + + self.inner + .read() + .unwrap() + .domain_hostmap + .get_or_init(|| { + domains + .split(',') + .map(|d| { + let host_info = HostInfo { + base_url: d.to_string(), + origin: extract_url_origin(d), + }; + + (extract_url_host(d), host_info) + }) + .collect() + }) + .get(host) + .cloned() + } + + pub fn host_to_origin(&self, host: &str) -> Option { + self.get_domain_hostmap(host).map(|v| v.origin) + } + + pub fn host_to_domain(&self, host: &str) -> Option { + self.get_domain_hostmap(host).map(|v| v.base_url) + } + + // Yes this is a base_url + // But the configuration precedent says, that we call this a domain. + pub fn main_domain(&self) -> String { + self.domain().split(',').nth(0).expect("Missing domain").to_string() + } } use handlebars::{ diff --git a/src/db/models/attachment.rs b/src/db/models/attachment.rs index f8eca72f68..6d3df3693d 100644 --- a/src/db/models/attachment.rs +++ b/src/db/models/attachment.rs @@ -35,15 +35,15 @@ impl Attachment { format!("{}/{}/{}", CONFIG.attachments_folder(), self.cipher_uuid, self.id) } - pub fn get_url(&self, host: &str) -> String { + pub fn get_url(&self, base_url: &str) -> String { let token = encode_jwt(&generate_file_download_claims(self.cipher_uuid.clone(), self.id.clone())); - format!("{}/attachments/{}/{}?token={}", host, self.cipher_uuid, self.id, token) + format!("{}/attachments/{}/{}?token={}", base_url, self.cipher_uuid, self.id, token) } - pub fn to_json(&self, host: &str) -> Value { + pub fn to_json(&self, base_url: &str) -> Value { json!({ "Id": self.id, - "Url": self.get_url(host), + "Url": self.get_url(base_url), "FileName": self.file_name, "Size": self.file_size.to_string(), "SizeName": crate::util::get_display_size(self.file_size), diff --git a/src/db/models/cipher.rs b/src/db/models/cipher.rs index 3ed3401a5a..38a7c60ddd 100644 --- a/src/db/models/cipher.rs +++ b/src/db/models/cipher.rs @@ -115,7 +115,7 @@ use crate::error::MapResult; impl Cipher { pub async fn to_json( &self, - host: &str, + base_url: &str, user_uuid: &str, cipher_sync_data: Option<&CipherSyncData>, sync_type: CipherSyncType, @@ -126,12 +126,12 @@ impl Cipher { let mut attachments_json: Value = Value::Null; if let Some(cipher_sync_data) = cipher_sync_data { if let Some(attachments) = cipher_sync_data.cipher_attachments.get(&self.uuid) { - attachments_json = attachments.iter().map(|c| c.to_json(host)).collect(); + attachments_json = attachments.iter().map(|c| c.to_json(base_url)).collect(); } } else { let attachments = Attachment::find_by_cipher(&self.uuid, conn).await; if !attachments.is_empty() { - attachments_json = attachments.iter().map(|c| c.to_json(host)).collect() + attachments_json = attachments.iter().map(|c| c.to_json(base_url)).collect() } } diff --git a/src/mail.rs b/src/mail.rs index 151554a1fd..6ac1b0c7fc 100644 --- a/src/mail.rs +++ b/src/mail.rs @@ -118,6 +118,10 @@ fn get_template(template_name: &str, data: &serde_json::Value) -> Result<(String Ok((subject, body)) } +fn mail_domain() -> String { + CONFIG.main_domain() +} + pub async fn send_password_hint(address: &str, hint: Option) -> EmptyResult { let template_name = if hint.is_some() { "email/pw_hint_some" @@ -128,7 +132,7 @@ pub async fn send_password_hint(address: &str, hint: Option) -> EmptyRes let (subject, body_html, body_text) = get_text( template_name, json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "hint": hint, }), @@ -144,7 +148,7 @@ pub async fn send_delete_account(address: &str, uuid: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/delete_account", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "user_id": uuid, "email": percent_encode(address.as_bytes(), NON_ALPHANUMERIC).to_string(), @@ -162,7 +166,7 @@ pub async fn send_verify_email(address: &str, uuid: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/verify_email", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "user_id": uuid, "email": percent_encode(address.as_bytes(), NON_ALPHANUMERIC).to_string(), @@ -177,7 +181,7 @@ pub async fn send_welcome(address: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/welcome", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), }), )?; @@ -192,7 +196,7 @@ pub async fn send_welcome_must_verify(address: &str, uuid: &str) -> EmptyResult let (subject, body_html, body_text) = get_text( "email/welcome_must_verify", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "user_id": uuid, "token": verify_email_token, @@ -206,7 +210,7 @@ pub async fn send_2fa_removed_from_org(address: &str, org_name: &str) -> EmptyRe let (subject, body_html, body_text) = get_text( "email/send_2fa_removed_from_org", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "org_name": org_name, }), @@ -219,7 +223,7 @@ pub async fn send_single_org_removed_from_org(address: &str, org_name: &str) -> let (subject, body_html, body_text) = get_text( "email/send_single_org_removed_from_org", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "org_name": org_name, }), @@ -248,7 +252,7 @@ pub async fn send_invite( let (subject, body_html, body_text) = get_text( "email/send_org_invite", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "org_id": org_id.as_deref().unwrap_or("_"), "org_user_id": org_user_id.as_deref().unwrap_or("_"), @@ -282,7 +286,7 @@ pub async fn send_emergency_access_invite( let (subject, body_html, body_text) = get_text( "email/send_emergency_access_invite", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "emer_id": emer_id, "email": percent_encode(address.as_bytes(), NON_ALPHANUMERIC).to_string(), @@ -298,7 +302,7 @@ pub async fn send_emergency_access_invite_accepted(address: &str, grantee_email: let (subject, body_html, body_text) = get_text( "email/emergency_access_invite_accepted", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantee_email": grantee_email, }), @@ -311,7 +315,7 @@ pub async fn send_emergency_access_invite_confirmed(address: &str, grantor_name: let (subject, body_html, body_text) = get_text( "email/emergency_access_invite_confirmed", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantor_name": grantor_name, }), @@ -324,7 +328,7 @@ pub async fn send_emergency_access_recovery_approved(address: &str, grantor_name let (subject, body_html, body_text) = get_text( "email/emergency_access_recovery_approved", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantor_name": grantor_name, }), @@ -342,7 +346,7 @@ pub async fn send_emergency_access_recovery_initiated( let (subject, body_html, body_text) = get_text( "email/emergency_access_recovery_initiated", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantee_name": grantee_name, "atype": atype, @@ -362,7 +366,7 @@ pub async fn send_emergency_access_recovery_reminder( let (subject, body_html, body_text) = get_text( "email/emergency_access_recovery_reminder", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantee_name": grantee_name, "atype": atype, @@ -377,7 +381,7 @@ pub async fn send_emergency_access_recovery_rejected(address: &str, grantor_name let (subject, body_html, body_text) = get_text( "email/emergency_access_recovery_rejected", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantor_name": grantor_name, }), @@ -390,7 +394,7 @@ pub async fn send_emergency_access_recovery_timed_out(address: &str, grantee_nam let (subject, body_html, body_text) = get_text( "email/emergency_access_recovery_timed_out", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "grantee_name": grantee_name, "atype": atype, @@ -404,7 +408,7 @@ pub async fn send_invite_accepted(new_user_email: &str, address: &str, org_name: let (subject, body_html, body_text) = get_text( "email/invite_accepted", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "email": new_user_email, "org_name": org_name, @@ -418,7 +422,7 @@ pub async fn send_invite_confirmed(address: &str, org_name: &str) -> EmptyResult let (subject, body_html, body_text) = get_text( "email/invite_confirmed", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "org_name": org_name, }), @@ -435,7 +439,7 @@ pub async fn send_new_device_logged_in(address: &str, ip: &str, dt: &NaiveDateTi let (subject, body_html, body_text) = get_text( "email/new_device_logged_in", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "ip": ip, "device": device, @@ -454,7 +458,7 @@ pub async fn send_incomplete_2fa_login(address: &str, ip: &str, dt: &NaiveDateTi let (subject, body_html, body_text) = get_text( "email/incomplete_2fa_login", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "ip": ip, "device": device, @@ -470,7 +474,7 @@ pub async fn send_token(address: &str, token: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/twofactor_email", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "token": token, }), @@ -483,7 +487,7 @@ pub async fn send_change_email(address: &str, token: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/change_email", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "token": token, }), @@ -496,7 +500,7 @@ pub async fn send_test(address: &str) -> EmptyResult { let (subject, body_html, body_text) = get_text( "email/smtp_test", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), }), )?; @@ -508,7 +512,7 @@ pub async fn send_admin_reset_password(address: &str, user_name: &str, org_name: let (subject, body_html, body_text) = get_text( "email/admin_reset_password", json!({ - "url": CONFIG.domain(), + "url": mail_domain(), "img_src": CONFIG._smtp_img_src(), "user_name": user_name, "org_name": org_name, diff --git a/src/util.rs b/src/util.rs index e96a1741a4..e97da17af3 100644 --- a/src/util.rs +++ b/src/util.rs @@ -18,7 +18,7 @@ use tokio::{ time::{sleep, Duration}, }; -use crate::CONFIG; +use crate::{config::extract_url_host, CONFIG}; pub struct AppHeaders(); @@ -130,9 +130,19 @@ impl Cors { // If a match exists, return it. Otherwise, return None. fn get_allowed_origin(headers: &HeaderMap<'_>) -> Option { let origin = Cors::get_header(headers, "Origin"); - let domain_origin = CONFIG.domain_origin(); + + let domain_origin_opt = CONFIG.host_to_origin(&extract_url_host(&origin)); let safari_extension_origin = "file://"; - if origin == domain_origin || origin == safari_extension_origin { + + let found_origin = { + if let Some(domain_origin) = domain_origin_opt { + origin == domain_origin + } else { + false + } + }; + + if found_origin || origin == safari_extension_origin { Some(origin) } else { None