Skip to content

Commit

Permalink
add send_id newtype
Browse files Browse the repository at this point in the history
  • Loading branch information
stefan0xC committed Dec 23, 2024
1 parent e361d24 commit eff2ea0
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/api/core/accounts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -525,8 +525,8 @@ fn validate_keydata(
}

// Check that we're correctly rotating all the user's sends
let existing_send_ids = existing_sends.iter().map(|s| s.uuid.as_str()).collect::<HashSet<_>>();
let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_deref()).collect::<HashSet<_>>();
let existing_send_ids = existing_sends.iter().map(|s| &s.uuid).collect::<HashSet<&SendId>>();
let provided_send_ids = data.sends.iter().filter_map(|s| s.id.as_ref()).collect::<HashSet<&SendId>>();
if !provided_send_ids.is_superset(&existing_send_ids) {
err!("All existing sends must be included in the rotation")
}
Expand Down
54 changes: 30 additions & 24 deletions src/api/core/sends.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub struct SendData {
file_length: Option<NumberOrString>,

// Used for key rotations
pub id: Option<String>,
pub id: Option<SendId>,
}

/// Enforces the `Disable Send` policy. A non-owner/admin user belonging to
Expand Down Expand Up @@ -158,8 +158,8 @@ async fn get_sends(headers: Headers, mut conn: DbConn) -> Json<Value> {
}

#[get("/sends/<uuid>")]
async fn get_send(uuid: &str, headers: Headers, mut conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await {
async fn get_send(uuid: SendId, headers: Headers, mut conn: DbConn) -> JsonResult {
match Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await {
Some(send) => Ok(Json(send.to_json())),
None => err!("Send not found", "Invalid uuid or does not belong to user"),
}
Expand Down Expand Up @@ -249,7 +249,7 @@ async fn post_send_file(data: Form<UploadData<'_>>, headers: Headers, mut conn:
err!("Send content is not a file");
}

let file_id = crate::crypto::generate_send_id();
let file_id = crate::crypto::generate_send_file_id();
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(&send.uuid);
let file_path = folder_path.join(&file_id);
tokio::fs::create_dir_all(&folder_path).await?;
Expand Down Expand Up @@ -324,7 +324,7 @@ async fn post_send_file_v2(data: Json<SendData>, headers: Headers, mut conn: DbC

let mut send = create_send(data, headers.user.uuid)?;

let file_id = crate::crypto::generate_send_id();
let file_id = crate::crypto::generate_send_file_id();

let mut data_value: Value = serde_json::from_str(&send.data)?;
if let Some(o) = data_value.as_object_mut() {
Expand Down Expand Up @@ -352,9 +352,9 @@ pub struct SendFileData {
}

// https://github.com/bitwarden/server/blob/66f95d1c443490b653e5a15d32977e2f5a3f9e32/src/Api/Tools/Controllers/SendsController.cs#L250
#[post("/sends/<send_uuid>/file/<file_id>", format = "multipart/form-data", data = "<data>")]
#[post("/sends/<uuid>/file/<file_id>", format = "multipart/form-data", data = "<data>")]
async fn post_send_file_v2_data(
send_uuid: &str,
uuid: SendId,
file_id: &str,
data: Form<UploadDataV2<'_>>,
headers: Headers,
Expand All @@ -365,7 +365,7 @@ async fn post_send_file_v2_data(

let mut data = data.into_inner();

let Some(send) = Send::find_by_uuid_and_user(send_uuid, &headers.user.uuid, &mut conn).await else {
let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found. Unable to save the file.", "Invalid uuid or does not belong to user.")
};

Expand Down Expand Up @@ -402,7 +402,7 @@ async fn post_send_file_v2_data(
err!("Send file size does not match.", format!("Expected a file size of {} got {size}", send_data.size));
}

let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(send_uuid);
let folder_path = tokio::fs::canonicalize(&CONFIG.sends_folder()).await?.join(uuid);
let file_path = folder_path.join(file_id);

// Check if the file already exists, if that is the case do not overwrite it
Expand Down Expand Up @@ -493,16 +493,16 @@ async fn post_access(
Ok(Json(send.to_json_access(&mut conn).await))
}

#[post("/sends/<send_id>/access/file/<file_id>", data = "<data>")]
#[post("/sends/<uuid>/access/file/<file_id>", data = "<data>")]
async fn post_access_file(
send_id: &str,
uuid: SendId,
file_id: &str,
data: Json<SendAccessData>,
host: Host,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
let Some(mut send) = Send::find_by_uuid(send_id, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid(&uuid, &mut conn).await else {
err_code!(SEND_INACCESSIBLE_MSG, 404)
};

Expand Down Expand Up @@ -547,33 +547,39 @@ async fn post_access_file(
)
.await;

let token_claims = crate::auth::generate_send_claims(send_id, file_id);
let token_claims = crate::auth::generate_send_claims(&uuid, file_id);
let token = crate::auth::encode_jwt(&token_claims);
Ok(Json(json!({
"object": "send-fileDownload",
"id": file_id,
"url": format!("{}/api/sends/{}/{}?t={}", &host.host, send_id, file_id, token)
"url": format!("{}/api/sends/{}/{}?t={}", &host.host, uuid, file_id, token)
})))
}

#[get("/sends/<send_id>/<file_id>?<t>")]
async fn download_send(send_id: SafeString, file_id: SafeString, t: &str) -> Option<NamedFile> {
#[get("/sends/<uuid>/<file_id>?<t>")]
async fn download_send(uuid: SendId, file_id: SafeString, t: &str) -> Option<NamedFile> {
if let Ok(claims) = crate::auth::decode_send(t) {
if claims.sub == format!("{send_id}/{file_id}") {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(send_id).join(file_id)).await.ok();
if claims.sub == format!("{uuid}/{file_id}") {
return NamedFile::open(Path::new(&CONFIG.sends_folder()).join(uuid).join(file_id)).await.ok();
}
}
None
}

#[put("/sends/<uuid>", data = "<data>")]
async fn put_send(uuid: &str, data: Json<SendData>, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
async fn put_send(
uuid: SendId,
data: Json<SendData>,
headers: Headers,
mut conn: DbConn,
nt: Notify<'_>,
) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;

let data: SendData = data.into_inner();
enforce_disable_hide_email_policy(&data, &headers, &mut conn).await?;

let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Send uuid is invalid or does not belong to user")
};

Expand Down Expand Up @@ -641,8 +647,8 @@ pub async fn update_send_from_data(
}

#[delete("/sends/<uuid>")]
async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else {
async fn delete_send(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> EmptyResult {
let Some(send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user")
};

Expand All @@ -660,10 +666,10 @@ async fn delete_send(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<
}

#[put("/sends/<uuid>/remove-password")]
async fn put_remove_password(uuid: &str, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
async fn put_remove_password(uuid: SendId, headers: Headers, mut conn: DbConn, nt: Notify<'_>) -> JsonResult {
enforce_disable_send_policy(&headers, &mut conn).await?;

let Some(mut send) = Send::find_by_uuid_and_user(uuid, &headers.user.uuid, &mut conn).await else {
let Some(mut send) = Send::find_by_uuid_and_user(&uuid, &headers.user.uuid, &mut conn).await else {
err!("Send not found", "Invalid send uuid, or does not belong to user")
};

Expand Down
2 changes: 1 addition & 1 deletion src/api/notifications.rs
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ impl WebSocketUsers {

let data = create_update(
vec![
("Id".into(), send.uuid.clone().into()),
("Id".into(), send.uuid.to_string().into()),
("UserId".into(), user_uuid),
("RevisionDate".into(), serialize_date(send.revision_date)),
],
Expand Down
6 changes: 3 additions & 3 deletions src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use std::{
};

use crate::db::models::{
AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, UserId,
AttachmentId, CipherId, CollectionId, DeviceId, MembershipId, OrgApiKeyId, OrganizationId, SendId, UserId,
};
use crate::{error::Error, CONFIG};

Expand Down Expand Up @@ -358,13 +358,13 @@ pub fn generate_admin_claims() -> BasicJwtClaims {
}
}

pub fn generate_send_claims(send_id: &str, file_id: &str) -> BasicJwtClaims {
pub fn generate_send_claims(uuid: &SendId, file_id: &str) -> BasicJwtClaims {
let time_now = Utc::now();
BasicJwtClaims {
nbf: time_now.timestamp(),
exp: (time_now + TimeDelta::try_minutes(2).unwrap()).timestamp(),
iss: JWT_SEND_ISSUER.to_string(),
sub: format!("{send_id}/{file_id}"),
sub: format!("{uuid}/{file_id}"),
}
}

Expand Down
4 changes: 2 additions & 2 deletions src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ pub fn generate_id<const N: usize>() -> String {
encode_random_bytes::<N>(HEXLOWER)
}

pub fn generate_send_id() -> String {
// Send IDs are globally scoped, so make them longer to avoid collisions.
pub fn generate_send_file_id() -> String {
// Send File IDs are globally scoped, so make them longer to avoid collisions.
generate_id::<32>() // 256 bits
}

Expand Down
2 changes: 1 addition & 1 deletion src/db/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub use self::organization::{
Membership, MembershipId, MembershipStatus, MembershipType, OrgApiKeyId, Organization, OrganizationApiKey,
OrganizationId,
};
pub use self::send::{Send, SendType};
pub use self::send::{id::SendId, Send, SendType};
pub use self::two_factor::{TwoFactor, TwoFactorType};
pub use self::two_factor_duo_context::TwoFactorDuoContext;
pub use self::two_factor_incomplete::TwoFactorIncomplete;
Expand Down
43 changes: 38 additions & 5 deletions src/db/models/send.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,15 @@ use serde_json::Value;
use crate::util::LowerCase;

use super::{OrganizationId, User, UserId};
use id::SendId;

db_object! {
#[derive(Identifiable, Queryable, Insertable, AsChangeset)]
#[diesel(table_name = sends)]
#[diesel(treat_none_as_null = true)]
#[diesel(primary_key(uuid))]
pub struct Send {
pub uuid: String,
pub uuid: SendId,

pub user_uuid: Option<UserId>,
pub organization_uuid: Option<OrganizationId>,
Expand Down Expand Up @@ -50,7 +51,7 @@ impl Send {
let now = Utc::now().naive_utc();

Self {
uuid: crate::util::get_uuid(),
uuid: SendId::from(crate::util::get_uuid()),
user_uuid: None,
organization_uuid: None,

Expand Down Expand Up @@ -272,14 +273,14 @@ impl Send {
};

let uuid = match Uuid::from_slice(&uuid_vec) {
Ok(u) => u.to_string(),
Ok(u) => SendId::from(u.to_string()),
Err(_) => return None,
};

Self::find_by_uuid(&uuid, conn).await
}

pub async fn find_by_uuid(uuid: &str, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid(uuid: &SendId, conn: &mut DbConn) -> Option<Self> {
db_run! {conn: {
sends::table
.filter(sends::uuid.eq(uuid))
Expand All @@ -289,7 +290,7 @@ impl Send {
}}
}

pub async fn find_by_uuid_and_user(uuid: &str, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
pub async fn find_by_uuid_and_user(uuid: &SendId, user_uuid: &UserId, conn: &mut DbConn) -> Option<Self> {
db_run! {conn: {
sends::table
.filter(sends::uuid.eq(uuid))
Expand Down Expand Up @@ -348,3 +349,35 @@ impl Send {
}}
}
}

// separate namespace to avoid name collision with std::marker::Send
pub mod id {
use derive_more::{AsRef, Deref, Display, From};
use rocket::request::FromParam;
use std::marker::Send;
use std::path::Path;
#[derive(
Clone, Debug, AsRef, Deref, DieselNewType, Display, From, FromForm, Hash, PartialEq, Eq, Serialize, Deserialize,
)]
pub struct SendId(String);

impl AsRef<Path> for SendId {
#[inline]
fn as_ref(&self) -> &Path {
Path::new(&self.0)
}
}

impl<'r> FromParam<'r> for SendId {
type Error = ();

#[inline(always)]
fn from_param(param: &'r str) -> Result<Self, Self::Error> {
if param.chars().all(|c| matches!(c, 'a'..='z' | 'A'..='Z' |'0'..='9' | '-')) {
Ok(Self(param.to_string()))
} else {
Err(())
}
}
}
}

0 comments on commit eff2ea0

Please sign in to comment.