Skip to content

Commit

Permalink
Key package rotation take 2 (#1079)
Browse files Browse the repository at this point in the history
* Create database tables

* Actually perform rotation

* Fix duplicate import
  • Loading branch information
neekolas authored Sep 20, 2024
1 parent 998623d commit cd0fc5c
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 18 deletions.
102 changes: 96 additions & 6 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use openmls::{
framing::{MlsMessageBodyIn, MlsMessageIn},
group::GroupEpoch,
messages::Welcome,
prelude::tls_codec::{Deserialize, Error as TlsCodecError, Serialize},
prelude::tls_codec::{Deserialize, Error as TlsCodecError},
};
use openmls_traits::OpenMlsProvider;
use prost::EncodeError;
Expand Down Expand Up @@ -586,11 +586,13 @@ where
/// Upload a new key package to the network replacing an existing key package
/// This is expected to be run any time the client receives new Welcome messages
pub async fn rotate_key_package(&self) -> Result<(), ClientError> {
let provider: XmtpOpenMlsProvider = self.store().conn()?.into();

let kp = self.identity().new_key_package(&provider)?;
let kp_bytes = kp.tls_serialize_detached()?;
self.api_client.upload_key_package(kp_bytes, true).await?;
self.store()
.transaction_async(|provider| async move {
self.identity()
.rotate_key_package(&provider, &self.api_client)
.await
})
.await?;

Ok(())
}
Expand Down Expand Up @@ -668,6 +670,7 @@ where
/// Returns any new groups created in the operation
pub async fn sync_welcomes(&self) -> Result<Vec<MlsGroup>, ClientError> {
let envelopes = self.query_welcome_messages(&self.store().conn()?).await?;
let num_envelopes = envelopes.len();
let id = self.installation_public_key();

let groups: Vec<MlsGroup> = stream::iter(envelopes.into_iter())
Expand Down Expand Up @@ -717,6 +720,11 @@ where
.collect()
.await;

// If any welcomes were found, rotate your key package
if num_envelopes > 0 {
self.rotate_key_package().await?;
}

Ok(groups)
}

Expand Down Expand Up @@ -848,12 +856,16 @@ mod tests {
builder::ClientBuilder,
groups::GroupMetadataOptions,
hpke::{decrypt_welcome, encrypt_welcome},
identity::serialize_key_package_hash_ref,
storage::{
consent_record::{ConsentState, ConsentType, StoredConsentRecord},
schema::identity_updates,
},
XmtpApi,
};

use super::Client;

#[tokio::test]
async fn test_group_member_recovery() {
let amal = ClientBuilder::new_test_client(&generate_local_wallet()).await;
Expand Down Expand Up @@ -1179,4 +1191,82 @@ mod tests {
assert_eq!(inbox_consent, ConsentState::Denied);
assert_eq!(address_consent, ConsentState::Denied);
}

async fn get_key_package_init_key<ApiClient: XmtpApi>(
client: &Client<ApiClient>,
installation_id: &[u8],
) -> Vec<u8> {
let kps = client
.get_key_packages_for_installation_ids(vec![installation_id.to_vec()])
.await
.unwrap();
let kp = kps.first().unwrap();

serialize_key_package_hash_ref(&kp.inner, &client.mls_provider().unwrap()).unwrap()
}

#[tokio::test]
async fn test_key_package_rotation() {
let alix_wallet = generate_local_wallet();
let bo_wallet = generate_local_wallet();
let alix = ClientBuilder::new_test_client(&alix_wallet).await;
let bo = ClientBuilder::new_test_client(&bo_wallet).await;
let bo_store = bo.store();

let alix_original_init_key =
get_key_package_init_key(&alix, &alix.installation_public_key()).await;
let bo_original_init_key =
get_key_package_init_key(&bo, &bo.installation_public_key()).await;

// Bo's original key should be deleted
let bo_original_from_db = bo_store
.conn()
.unwrap()
.find_key_package_history_entry_by_hash_ref(bo_original_init_key.clone());
assert!(bo_original_from_db.is_ok());

alix.create_group_with_members(
vec![bo_wallet.get_address()],
None,
GroupMetadataOptions::default(),
)
.await
.unwrap();

bo.sync_welcomes().await.unwrap();

let bo_new_key = get_key_package_init_key(&bo, &bo.installation_public_key()).await;
// Bo's key should have changed
assert_ne!(bo_original_init_key, bo_new_key);

bo.sync_welcomes().await.unwrap();
let bo_new_key_2 = get_key_package_init_key(&bo, &bo.installation_public_key()).await;
// Bo's key should not have changed syncing the second time.
assert_eq!(bo_new_key, bo_new_key_2);

alix.sync_welcomes().await.unwrap();
let alix_key_2 = get_key_package_init_key(&alix, &alix.installation_public_key()).await;
// Alix's key should not have changed at all
assert_eq!(alix_original_init_key, alix_key_2);

alix.create_group_with_members(
vec![bo_wallet.get_address()],
None,
GroupMetadataOptions::default(),
)
.await
.unwrap();
bo.sync_welcomes().await.unwrap();

// Bo should have two groups now
let bo_groups = bo.find_groups(None, None, None, None).unwrap();
assert_eq!(bo_groups.len(), 2);

// Bo's original key should be deleted
let bo_original_after_delete = bo_store
.conn()
.unwrap()
.find_key_package_history_entry_by_hash_ref(bo_original_init_key);
assert!(bo_original_after_delete.is_err());
}
}
75 changes: 63 additions & 12 deletions xmtp_mls/src/identity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use crate::{retryable, Fetch, Store};
use ed25519_dalek::SigningKey;
use log::debug;
use log::info;
use openmls::prelude::hash_ref::HashReference;
use openmls::prelude::tls_codec::Serialize;
use openmls::{
credentials::{errors::BasicCredentialError, BasicCredential, CredentialWithKey},
Expand All @@ -29,6 +30,7 @@ use openmls::{
prelude_test::KeyPackage,
};
use openmls_basic_credential::SignatureKeyPair;
use openmls_traits::storage::StorageProvider;
use openmls_traits::types::CryptoError;
use openmls_traits::OpenMlsProvider;
use prost::Message;
Expand Down Expand Up @@ -162,6 +164,8 @@ pub enum IdentityError {
RequiredIdentityNotFound,
#[error("error creating new identity: {0}")]
NewIdentity(String),
#[error(transparent)]
DieselResult(#[from] diesel::result::Error),
}

impl RetryableError for IdentityError {
Expand All @@ -171,6 +175,7 @@ impl RetryableError for IdentityError {
Self::WrappedApi(err) => retryable!(err),
Self::StorageError(err) => retryable!(err),
Self::OpenMlsStorageError(err) => retryable!(err),
Self::DieselResult(err) => retryable!(err),
_ => false,
}
}
Expand Down Expand Up @@ -424,16 +429,7 @@ impl Identity {
// This is needed to get to the private key when decrypting welcome messages.
let public_init_key = kp.key_package().hpke_init_key().tls_serialize_detached()?;

let key_package_hash_ref = match kp.key_package().hash_ref(provider.crypto()) {
Ok(key_package_hash_ref) => key_package_hash_ref,
Err(_) => return Err(IdentityError::UninitializedIdentity),
};

// Serialize the hash reference (with bincode)
let hash_ref = match bincode::serialize(&key_package_hash_ref) {
Ok(hash_ref) => hash_ref,
Err(_) => return Err(IdentityError::UninitializedIdentity),
};
let hash_ref = serialize_key_package_hash_ref(kp.key_package(), &provider)?;
// Store the hash reference, keyed with the public init key
provider
.storage()
Expand All @@ -455,15 +451,70 @@ impl Identity {
info!("Identity already registered. skipping key package publishing");
return Ok(());
}

self.rotate_key_package(provider, api_client).await?;
self.is_ready.store(true, Ordering::SeqCst);

Ok(StoredIdentity::try_from(self)?.store(provider.conn_ref())?)
}

pub(crate) async fn rotate_key_package<ApiClient: XmtpApi>(
&self,
provider: &XmtpOpenMlsProvider,
api_client: &ApiClientWrapper<ApiClient>,
) -> Result<(), IdentityError> {
let kp = self.new_key_package(provider)?;
let kp_bytes = kp.tls_serialize_detached()?;
let conn = provider.conn_ref();
let hash_ref = serialize_key_package_hash_ref(&kp, provider)?;
let history_id = conn.store_key_package_history_entry(hash_ref)?.id;
let old_id = history_id - 1;

// Find all key packages that are not the current or previous KPs
// We can delete before uploading because this is either run inside a transaction or is being applied to a brand
// new identity
let old_key_packages = conn.find_key_package_history_entries_before_id(old_id)?;
for kp in old_key_packages {
self.delete_key_package(provider, kp.key_package_hash_ref)?;
}
conn.delete_key_package_history_entries_before_id(old_id)?;

api_client.upload_key_package(kp_bytes, true).await?;
self.is_ready.store(true, Ordering::SeqCst);
Ok(())
}

Ok(StoredIdentity::try_from(self)?.store(provider.conn_ref())?)
pub(crate) fn delete_key_package(
&self,
provider: &XmtpOpenMlsProvider,
hash_ref: Vec<u8>,
) -> Result<(), IdentityError> {
let openmls_hash_ref = deserialize_key_package_hash_ref(&hash_ref)?;
provider.storage().delete_key_package(&openmls_hash_ref)?;

Ok(())
}
}

pub(crate) fn serialize_key_package_hash_ref(
kp: &KeyPackage,
provider: &impl OpenMlsProvider<StorageProvider = SqlKeyStore>,
) -> Result<Vec<u8>, IdentityError> {
let key_package_hash_ref = kp
.hash_ref(provider.crypto())
.map_err(|_| IdentityError::UninitializedIdentity)?;
let serialized = bincode::serialize(&key_package_hash_ref)
.map_err(|_| IdentityError::UninitializedIdentity)?;

Ok(serialized)
}

fn deserialize_key_package_hash_ref(hash_ref: &[u8]) -> Result<HashReference, IdentityError> {
let key_package_hash_ref: HashReference =
bincode::deserialize(hash_ref).map_err(|_| IdentityError::UninitializedIdentity)?;

Ok(key_package_hash_ref)
}

async fn sign_with_installation_key(
signature_text: String,
installation_private_key: &[u8; 32],
Expand Down
15 changes: 15 additions & 0 deletions xmtp_mls/src/storage/encrypted_store/key_package_history.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,21 @@ impl DbConnection {

Ok(result)
}

pub fn delete_key_package_history_entries_before_id(
&self,
id: i32,
) -> Result<(), StorageError> {
self.raw_query(|conn| {
diesel::delete(
key_package_history::dsl::key_package_history
.filter(key_package_history::dsl::id.lt(id)),
)
.execute(conn)
})?;

Ok(())
}
}

#[cfg(test)]
Expand Down

0 comments on commit cd0fc5c

Please sign in to comment.