1
- use std:: sync:: { Arc , Mutex } ;
1
+ use std:: {
2
+ path:: Path ,
3
+ sync:: { Arc , Mutex } ,
4
+ time:: { Duration , Instant } ,
5
+ } ;
2
6
3
7
use anyhow:: Result ;
4
8
use assert_matches:: assert_matches;
@@ -14,14 +18,20 @@ use matrix_sdk::{
14
18
api:: client:: room:: create_room:: v3:: Request as CreateRoomRequest ,
15
19
events:: {
16
20
key:: verification:: { request:: ToDeviceKeyVerificationRequestEvent , VerificationMethod } ,
17
- room:: message:: {
18
- MessageType , OriginalSyncRoomMessageEvent , RoomMessageEventContent ,
19
- SyncRoomMessageEvent ,
21
+ room:: {
22
+ encryption:: RoomEncryptionEventContent ,
23
+ message:: {
24
+ MessageType , OriginalSyncRoomMessageEvent , RoomMessageEventContent ,
25
+ SyncRoomMessageEvent ,
26
+ } ,
20
27
} ,
21
28
} ,
29
+ EventEncryptionAlgorithm , EventId , OwnedEventId , OwnedRoomId , RoomId ,
22
30
} ,
23
- Client ,
31
+ Client , Room ,
24
32
} ;
33
+ use serde_json:: json;
34
+ use tempfile:: tempdir;
25
35
use tracing:: warn;
26
36
27
37
use crate :: helpers:: { SyncTokenAwareClient , TestClientBuilder } ;
@@ -289,6 +299,178 @@ async fn test_mutual_sas_verification() -> Result<()> {
289
299
Ok ( ( ) )
290
300
}
291
301
302
+ struct ClientWrapper {
303
+ pub client : SyncTokenAwareClient ,
304
+ events : Arc < Mutex < Vec < OwnedEventId > > > ,
305
+ }
306
+
307
+ impl ClientWrapper {
308
+ async fn new ( username : & str ) -> Self {
309
+ Self :: from_client_builder ( TestClientBuilder :: new ( username) . use_sqlite ( ) ) . await
310
+ }
311
+
312
+ async fn with_sqlite_dir ( username : & str , sqlite_dir : & Path ) -> Self {
313
+ Self :: from_client_builder ( TestClientBuilder :: new ( username) . use_sqlite_dir ( sqlite_dir) ) . await
314
+ }
315
+
316
+ async fn from_client_builder ( builder : TestClientBuilder ) -> Self {
317
+ let events = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
318
+
319
+ let client = SyncTokenAwareClient :: new (
320
+ builder
321
+ . encryption_settings ( Self :: encryption_settings ( ) )
322
+ . build ( )
323
+ . await
324
+ . expect ( "Failed to create client" ) ,
325
+ ) ;
326
+
327
+ let events_clone = events. clone ( ) ;
328
+ client. add_event_handler ( |ev : OriginalSyncRoomMessageEvent , _: Client | async move {
329
+ events_clone. lock ( ) . unwrap ( ) . push ( ev. event_id . clone ( ) )
330
+ } ) ;
331
+
332
+ Self { client, events }
333
+ }
334
+
335
+ fn encryption_settings ( ) -> EncryptionSettings {
336
+ EncryptionSettings { auto_enable_cross_signing : true , ..Default :: default ( ) }
337
+ }
338
+
339
+ fn timeout ( ) -> Duration {
340
+ Duration :: from_secs ( 10 )
341
+ }
342
+
343
+ async fn create_room ( & self , invite : & [ & ClientWrapper ] ) -> OwnedRoomId {
344
+ let invite = invite. iter ( ) . map ( |cw| cw. client . user_id ( ) . unwrap ( ) . to_owned ( ) ) . collect ( ) ;
345
+
346
+ let request = assign ! ( CreateRoomRequest :: new( ) , {
347
+ invite,
348
+ is_direct: true ,
349
+ } ) ;
350
+
351
+ let room = self . client . create_room ( request) . await . expect ( "Failed to create room" ) ;
352
+ self . enable_encryption ( & room, 1 ) . await ;
353
+ room. room_id ( ) . to_owned ( )
354
+ }
355
+
356
+ async fn enable_encryption ( & self , room : & Room , rotation_period_msgs : usize ) {
357
+ // Adapted from crates/matrix-sdk/src/room/mod.rs enable_encryption
358
+ if !room. is_encrypted ( ) . await . expect ( "Failed to check encrypted" ) {
359
+ let content: RoomEncryptionEventContent = serde_json:: from_value ( json ! ( {
360
+ "algorithm" : EventEncryptionAlgorithm :: MegolmV1AesSha2 ,
361
+ "rotation_period_msgs" : rotation_period_msgs,
362
+ } ) )
363
+ . expect ( "Failed parsing encryption JSON" ) ;
364
+ room. send_state_event ( content) . await . expect ( "Failed to send state event" ) ;
365
+
366
+ self . client . sync_once ( ) . await . expect ( "Failed to sync" ) ;
367
+ }
368
+ }
369
+
370
+ async fn join ( & self , room_id : & RoomId ) {
371
+ let room = self . wait_until_room_exists ( room_id) . await ;
372
+ room. join ( ) . await . expect ( "Unable to join room" )
373
+ }
374
+
375
+ /// Wait (syncing if needed) until the room with supplied ID exists, or time out
376
+ async fn wait_until_room_exists ( & self , room_id : & RoomId ) -> Room {
377
+ let end_time = Instant :: now ( ) + Self :: timeout ( ) ;
378
+ while Instant :: now ( ) < end_time {
379
+ let room = self . client . get_room ( room_id) ;
380
+ if let Some ( room) = room {
381
+ return room;
382
+ }
383
+ self . client . sync_once ( ) . await . expect ( "Sync failed" ) ;
384
+ }
385
+ panic ! ( "Timed out waiting for room {room_id} to exist" ) ;
386
+ }
387
+
388
+ /// Wait (syncing if needed) until the user appears in the supplied room, or time out
389
+ async fn wait_until_user_in_room ( & self , room_id : & RoomId , other : & ClientWrapper ) {
390
+ let room = self . wait_until_room_exists ( room_id) . await ;
391
+ let user_id = other. client . user_id ( ) . unwrap ( ) ;
392
+
393
+ let end_time = Instant :: now ( ) + Self :: timeout ( ) ;
394
+ while Instant :: now ( ) < end_time {
395
+ if room. get_member_no_sync ( user_id) . await . expect ( "get_member failed" ) . is_some ( ) {
396
+ return ;
397
+ }
398
+ self . client . sync_once ( ) . await . expect ( "Sync failed" ) ;
399
+ }
400
+ panic ! ( "Timed out waiting for user {user_id} to be in room {room_id}" ) ;
401
+ }
402
+
403
+ /// Wait (syncing if needed) until the event with this ID appears, or time out
404
+ async fn wait_until_received ( & self , event_id : & EventId ) {
405
+ let event_id = event_id. to_owned ( ) ;
406
+ let end_time = Instant :: now ( ) + Self :: timeout ( ) ;
407
+ while Instant :: now ( ) < end_time {
408
+ if self . events . lock ( ) . unwrap ( ) . contains ( & event_id) {
409
+ return ;
410
+ }
411
+ self . client . sync_once ( ) . await . expect ( "Sync failed" ) ;
412
+ }
413
+ panic ! ( "Timed out waiting for event {event_id} to be received" ) ;
414
+ }
415
+
416
+ /// Send a text message in the supplied room and return the event ID
417
+ async fn send ( & self , room_id : & RoomId , message : & str ) -> OwnedEventId {
418
+ let room = self . wait_until_room_exists ( room_id) . await ;
419
+
420
+ room. send ( RoomMessageEventContent :: text_plain ( message. to_owned ( ) ) )
421
+ . await
422
+ . expect ( "Sending message failed" )
423
+ . event_id
424
+ . to_owned ( )
425
+ }
426
+ }
427
+
428
+ #[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
429
+ async fn test_multiple_clients_share_crypto_state ( ) -> Result < ( ) > {
430
+ let alice_sqlite_dir = tempdir ( ) ?;
431
+ let alice1 = ClientWrapper :: with_sqlite_dir ( "alice" , alice_sqlite_dir. path ( ) ) . await ;
432
+ let alice2 = ClientWrapper :: with_sqlite_dir ( "alice" , alice_sqlite_dir. path ( ) ) . await ;
433
+ let bob = ClientWrapper :: new ( "bob" ) . await ;
434
+
435
+ warn ! ( "alice's device: {}" , alice1. client. device_id( ) . unwrap( ) ) ;
436
+ warn ! ( "bob's device: {}" , bob. client. device_id( ) . unwrap( ) ) ;
437
+
438
+ // TODO: surely both alice clients share the same device ID because they are sharing the same DB?
439
+ //assert_eq!(alice1.client.device_id(), alice2.client.device_id());
440
+
441
+ let room_id = alice1. create_room ( & [ & bob] ) . await ;
442
+
443
+ warn ! ( "alice1 has created and enabled encryption in the room" ) ;
444
+
445
+ bob. join ( & room_id) . await ;
446
+ alice1. wait_until_user_in_room ( & room_id, & bob) . await ;
447
+
448
+ warn ! ( "alice1 and bob are both aware of each other in the e2ee room" ) ;
449
+
450
+ let msg1 = bob. send ( & room_id, "msg1_from_bob" ) . await ;
451
+ alice1. wait_until_received ( & msg1) . await ;
452
+
453
+ warn ! ( "alice1 received msg1 from bob" ) ;
454
+
455
+ let msg2 = bob. send ( & room_id, "msg2_from_bob" ) . await ;
456
+ alice2. wait_until_received ( & msg2) . await ;
457
+
458
+ warn ! ( "alice2 received msg2 from bob" ) ;
459
+
460
+ let msg3 = alice1. send ( & room_id, "msg3_from_alice" ) . await ;
461
+ bob. wait_until_received ( & msg3) . await ;
462
+
463
+ warn ! ( "bob received msg3 from alice1" ) ;
464
+
465
+ let msg4 = bob. send ( & room_id, "msg4_from_bob" ) . await ;
466
+ alice1. wait_until_received ( & msg4) . await ;
467
+ alice2. wait_until_received ( & msg4) . await ;
468
+
469
+ warn ! ( "alice1 and alice2 both received msg4 from bob" ) ;
470
+
471
+ Ok ( ( ) )
472
+ }
473
+
292
474
#[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
293
475
async fn test_mutual_qrcode_verification ( ) -> Result < ( ) > {
294
476
let encryption_settings =
0 commit comments