Skip to content

Commit 65a090f

Browse files
committed
Fix simple_scenario_sync and remove async for now
1 parent cf32ac9 commit 65a090f

File tree

4 files changed

+102
-34
lines changed

4 files changed

+102
-34
lines changed

mls-rs-uniffi/src/config/group_state.rs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use mls_rs_core::mls_rs_codec::{MlsDecode, MlsEncode};
44

55
use super::FFICallbackError;
66

7-
#[derive(Clone, Debug, uniffi::Object)]
7+
#[derive(Clone, Debug, uniffi::Record)]
88
pub struct GroupState {
99
pub id: Vec<u8>,
1010
pub data: Vec<u8>,
@@ -16,7 +16,7 @@ impl mls_rs_core::group::GroupState for GroupState {
1616
}
1717
}
1818

19-
#[derive(Clone, Debug, uniffi::Object)]
19+
#[derive(Clone, Debug, uniffi::Record)]
2020
pub struct EpochRecord {
2121
pub id: u64,
2222
pub data: Vec<u8>,
@@ -41,9 +41,9 @@ pub trait GroupStateStorage: Send + Sync + Debug {
4141

4242
async fn write(
4343
&self,
44-
state: Arc<GroupState>,
45-
epoch_inserts: Vec<Arc<EpochRecord>>,
46-
epoch_updates: Vec<Arc<EpochRecord>>,
44+
state: GroupState,
45+
epoch_inserts: Vec<EpochRecord>,
46+
epoch_updates: Vec<EpochRecord>,
4747
) -> Result<(), FFICallbackError>;
4848

4949
async fn max_epoch_id(&self, group_id: Vec<u8>) -> Result<Option<u64>, FFICallbackError>;
@@ -99,16 +99,16 @@ impl mls_rs_core::group::GroupStateStorage for GroupStateStorageWrapper {
9999
ST: mls_rs_core::group::GroupState + MlsEncode + MlsDecode + Send + Sync,
100100
ET: mls_rs_core::group::EpochRecord + MlsEncode + MlsDecode + Send + Sync,
101101
{
102-
let state = Arc::new(GroupState {
102+
let state = GroupState {
103103
id: state.id(),
104104
data: state.mls_encode_to_vec()?,
105-
});
105+
};
106106

107107
let epoch_to_record = |v: ET| -> Result<_, Self::Error> {
108-
Ok(Arc::new(EpochRecord {
108+
Ok(EpochRecord {
109109
id: v.id(),
110110
data: v.mls_encode_to_vec()?,
111-
}))
111+
})
112112
};
113113

114114
let inserts = epoch_inserts

mls-rs-uniffi/src/lib.rs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ fn arc_unwrap_or_clone<T: Clone>(arc: Arc<T>) -> T {
5353
#[uniffi(flat_error)]
5454
#[non_exhaustive]
5555
pub enum Error {
56-
#[error("A mls-rs error occurred")]
56+
#[error("A mls-rs error occurred: {inner}")]
5757
MlsError {
5858
#[from]
5959
inner: mls_rs::error::MlsError,
6060
},
61-
#[error("An unknown error occurred")]
61+
#[error("An unknown error occurred: {inner}")]
6262
AnyError {
6363
#[from]
6464
inner: mls_rs::error::AnyError,
@@ -329,6 +329,19 @@ impl Client {
329329
group_info_extensions,
330330
})
331331
}
332+
333+
/// Load an existing group.
334+
///
335+
/// See [`mls_rs::Client::load_group`] for details.
336+
pub async fn load_group(&self, group_id: Vec<u8>) -> Result<Group, Error> {
337+
self.inner
338+
.load_group(&group_id)
339+
.await
340+
.map(|g| Group {
341+
inner: Arc::new(Mutex::new(g)),
342+
})
343+
.map_err(Into::into)
344+
}
332345
}
333346

334347
#[derive(Clone, Debug, uniffi::Object)]
@@ -423,6 +436,13 @@ async fn signing_identity_to_identifier(
423436
#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
424437
#[uniffi::export]
425438
impl Group {
439+
/// Write the current state of the group to storage defined by
440+
/// [`ClientConfig::group_state_storage`]
441+
pub async fn write_to_storage(&self) -> Result<(), Error> {
442+
let mut group = self.inner().await;
443+
group.write_to_storage().await.map_err(Into::into)
444+
}
445+
426446
/// Perform a commit of received proposals (or an empty commit).
427447
///
428448
/// TODO: ensure `path_required` is always set in

mls-rs-uniffi/test_bindings/simple_scenario_async.py

Lines changed: 0 additions & 19 deletions
This file was deleted.
Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,74 @@
1-
from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client
1+
from mls_rs_uniffi import CipherSuite, generate_signature_keypair, Client, GroupStateStorage, ClientConfig
2+
3+
class EpochData:
4+
def __init__(self, id: "int", data: "bytes"):
5+
self.id = id
6+
self.data = data
7+
8+
class GroupStateData:
9+
def __init__(self, state: "bytes"):
10+
self.state = state
11+
self.epoch_data = []
12+
13+
class PythonGroupStateStorage(GroupStateStorage):
14+
def __init__(self):
15+
self.groups = {}
16+
17+
def state(self, group_id: "bytes"):
18+
group = self.groups.get(group_id.hex())
19+
20+
if group == None:
21+
return None
22+
23+
group.state
24+
25+
def epoch(self, group_id: "bytes",epoch_id: "int"):
26+
group = self.groups[group_id.hex()]
27+
28+
if group == None:
29+
return None
30+
31+
for epoch in group.epoch_data:
32+
if epoch.id == epoch_id:
33+
return epoch
34+
35+
return None
36+
37+
def write(self, state: "GroupState",epoch_inserts: "typing.List[EpochRecord]",epoch_updates: "typing.List[EpochRecord]"):
38+
if self.groups.get(state.id.hex()) == None:
39+
self.groups[state.id.hex()] = GroupStateData(state.data)
40+
41+
group = self.groups[state.id.hex()]
42+
43+
for insert in epoch_inserts:
44+
group.epoch_data.append(insert)
45+
46+
for update in epoch_updates:
47+
for i in range(len(group.epoch_data)):
48+
if group.epoch_data[i].id == update.id:
49+
group.epoch_data[i] = update
50+
51+
def max_epoch_id(self, group_id: "bytes"):
52+
group = self.groups.get(group_id.hex())
53+
54+
if group == None:
55+
return None
56+
57+
last = group.epoch_data.last()
58+
59+
if last == None:
60+
return None
61+
62+
return last.id
63+
64+
group_state_storage = PythonGroupStateStorage()
65+
client_config = ClientConfig(group_state_storage)
266

367
key = generate_signature_keypair(CipherSuite.CURVE25519_AES128)
4-
alice = Client(b'alice', key)
68+
alice = Client(b'alice', key, client_config)
569

670
key = generate_signature_keypair(CipherSuite.CURVE25519_AES128)
7-
bob = Client(b'bob', key)
71+
bob = Client(b'bob', key, client_config)
872

973
alice = alice.create_group(None)
1074
kp = bob.generate_key_package_message()
@@ -15,4 +79,7 @@
1579
msg = alice.encrypt_application_message(b'hello, bob')
1680
output = bob.process_incoming_message(msg)
1781

18-
assert output.data == b'hello, bob'
82+
alice.write_to_storage()
83+
84+
assert output.data == b'hello, bob'
85+
assert len(group_state_storage.groups) == 1

0 commit comments

Comments
 (0)