|
1 |
| -use std::sync::{Arc, Mutex}; |
| 1 | +use std::{ |
| 2 | + collections::{BTreeMap, HashSet}, |
| 3 | + sync::{ |
| 4 | + atomic::{AtomicBool, Ordering}, |
| 5 | + Arc, Mutex, |
| 6 | + }, |
| 7 | +}; |
2 | 8 |
|
3 | 9 | use futures_util::{pin_mut, StreamExt as _};
|
4 |
| -use matrix_sdk::test_utils::logged_in_client_with_server; |
| 10 | +use matrix_sdk::{ |
| 11 | + config::RequestConfig, |
| 12 | + matrix_auth::{MatrixSession, MatrixSessionTokens}, |
| 13 | + test_utils::{logged_in_client_with_server, test_client_builder_with_server}, |
| 14 | + SessionMeta, |
| 15 | +}; |
5 | 16 | use matrix_sdk_base::crypto::store::Changes;
|
6 | 17 | use matrix_sdk_test::async_test;
|
7 | 18 | use matrix_sdk_ui::encryption_sync_service::{
|
8 | 19 | EncryptionSyncPermit, EncryptionSyncService, WithLocking,
|
9 | 20 | };
|
| 21 | +use ruma::{device_id, user_id}; |
| 22 | +use serde::Deserialize; |
10 | 23 | use serde_json::json;
|
11 | 24 | use tokio::sync::Mutex as AsyncMutex;
|
12 |
| -use wiremock::{Mock, MockGuard, MockServer, Request, ResponseTemplate}; |
| 25 | +use wiremock::{ |
| 26 | + matchers::{method, path}, |
| 27 | + Mock, MockGuard, MockServer, Request, ResponseTemplate, |
| 28 | +}; |
13 | 29 |
|
14 | 30 | use crate::{
|
| 31 | + mock_sync, |
15 | 32 | sliding_sync::{check_requests, PartialSlidingSyncRequest, SlidingSyncMatcher},
|
16 | 33 | sliding_sync_then_assert_request_and_fake_response,
|
17 | 34 | };
|
@@ -320,3 +337,231 @@ async fn test_encryption_sync_always_reloads_todevice_token() -> anyhow::Result<
|
320 | 337 |
|
321 | 338 | Ok(())
|
322 | 339 | }
|
| 340 | + |
| 341 | +#[async_test] |
| 342 | +async fn test_notification_client_does_not_upload_duplicate_one_time_keys() -> anyhow::Result<()> { |
| 343 | + use tempfile::tempdir; |
| 344 | + |
| 345 | + let dir = tempdir().unwrap(); |
| 346 | + let user_id = user_id!("@example:morpheus.localhost"); |
| 347 | + |
| 348 | + let (builder, server) = test_client_builder_with_server().await; |
| 349 | + let client = builder |
| 350 | + .request_config(RequestConfig::new().disable_retry()) |
| 351 | + .sqlite_store(dir.path(), None) |
| 352 | + .build() |
| 353 | + .await |
| 354 | + .unwrap(); |
| 355 | + |
| 356 | + let session = MatrixSession { |
| 357 | + meta: SessionMeta { user_id: user_id.into(), device_id: device_id!("DEVICEID").to_owned() }, |
| 358 | + tokens: MatrixSessionTokens { access_token: "1234".to_owned(), refresh_token: None }, |
| 359 | + }; |
| 360 | + |
| 361 | + client.restore_session(session.to_owned()).await.unwrap(); |
| 362 | + |
| 363 | + tracing::info!("Creating the notification client"); |
| 364 | + let notification_client = client |
| 365 | + .notification_client() |
| 366 | + .await |
| 367 | + .expect("We should be able to build a notification client"); |
| 368 | + |
| 369 | + let sync_permit = Arc::new(AsyncMutex::new(EncryptionSyncPermit::new_for_testing())); |
| 370 | + let sync_permit_guard = sync_permit.lock_owned().await; |
| 371 | + let encryption_sync = |
| 372 | + EncryptionSyncService::new("tests".to_owned(), client.clone(), None, WithLocking::Yes) |
| 373 | + .await?; |
| 374 | + |
| 375 | + let stream = encryption_sync.sync(sync_permit_guard); |
| 376 | + pin_mut!(stream); |
| 377 | + |
| 378 | + Mock::given(method("POST")) |
| 379 | + .and(path("/_matrix/client/r0/keys/query")) |
| 380 | + .respond_with(ResponseTemplate::new(200).set_body_json(json!({}))) |
| 381 | + .mount(&server) |
| 382 | + .await; |
| 383 | + |
| 384 | + tracing::info!("First sync, uploading 50 one-time keys"); |
| 385 | + |
| 386 | + sliding_sync_then_assert_request_and_fake_response! { |
| 387 | + [server, stream] |
| 388 | + assert request = { |
| 389 | + "conn_id": "encryption", |
| 390 | + "extensions": { |
| 391 | + "e2ee": { |
| 392 | + "enabled": true |
| 393 | + }, |
| 394 | + "to_device": { |
| 395 | + "enabled": true |
| 396 | + } |
| 397 | + } |
| 398 | + }, |
| 399 | + respond with = { |
| 400 | + "pos": "0", |
| 401 | + "extensions": { |
| 402 | + "to_device": { |
| 403 | + "next_batch": "nb0" |
| 404 | + }, |
| 405 | + } |
| 406 | + }, |
| 407 | + }; |
| 408 | + |
| 409 | + #[derive(Debug, Deserialize)] |
| 410 | + struct UploadRequest { |
| 411 | + one_time_keys: BTreeMap<String, serde_json::Value>, |
| 412 | + } |
| 413 | + |
| 414 | + let found_duplicate = Arc::new(AtomicBool::new(false)); |
| 415 | + let uploaded_key_ids = Arc::new(Mutex::new(HashSet::new())); |
| 416 | + |
| 417 | + Mock::given(method("POST")) |
| 418 | + .and(path("/_matrix/client/r0/keys/upload")) |
| 419 | + .respond_with({ |
| 420 | + let found_duplicate = found_duplicate.clone(); |
| 421 | + let uploaded_key_ids = uploaded_key_ids.clone(); |
| 422 | + |
| 423 | + move |request: &Request| { |
| 424 | + let request: UploadRequest = request |
| 425 | + .body_json() |
| 426 | + .expect("The /keys/upload request should contain one-time keys"); |
| 427 | + |
| 428 | + let mut uploaded_key_ids = uploaded_key_ids.lock().unwrap(); |
| 429 | + |
| 430 | + let new_key_ids: HashSet<String> = request.one_time_keys.into_keys().collect(); |
| 431 | + |
| 432 | + tracing::warn!(?new_key_ids, "Got a new /keys/upload request"); |
| 433 | + |
| 434 | + let duplicates: HashSet<_> = uploaded_key_ids.intersection(&new_key_ids).collect(); |
| 435 | + |
| 436 | + if let Some(duplicate) = duplicates.into_iter().next() { |
| 437 | + tracing::error!("Duplicate one-time keys were uploaded."); |
| 438 | + |
| 439 | + found_duplicate.store(true, Ordering::SeqCst); |
| 440 | + |
| 441 | + ResponseTemplate::new(400).set_body_json(json!({ |
| 442 | + "errcode": "M_WAT", |
| 443 | + "error:": format!("One time key {duplicate} already exists!") |
| 444 | + })) |
| 445 | + } else { |
| 446 | + tracing::trace!("No duplicate one-time keys found."); |
| 447 | + uploaded_key_ids.extend(new_key_ids); |
| 448 | + |
| 449 | + ResponseTemplate::new(200).set_body_json(json!({ |
| 450 | + "one_time_key_counts": { |
| 451 | + "signed_curve25519": 50 |
| 452 | + } |
| 453 | + })) |
| 454 | + } |
| 455 | + } |
| 456 | + }) |
| 457 | + .expect(4) |
| 458 | + .mount(&server) |
| 459 | + .await; |
| 460 | + |
| 461 | + tracing::info!("Main sync not gets told that a one-time key has been used up."); |
| 462 | + |
| 463 | + sliding_sync_then_assert_request_and_fake_response! { |
| 464 | + [server, stream] |
| 465 | + assert request = { |
| 466 | + "conn_id": "encryption", |
| 467 | + "extensions": { |
| 468 | + "to_device": { |
| 469 | + "since": "nb0", |
| 470 | + }, |
| 471 | + } |
| 472 | + }, |
| 473 | + respond with = { |
| 474 | + "pos": "2", |
| 475 | + "extensions": { |
| 476 | + "to_device": { |
| 477 | + "next_batch": "nb2" |
| 478 | + }, |
| 479 | + "e2ee": { |
| 480 | + "device_one_time_keys_count": { |
| 481 | + "signed_curve25519": 49 |
| 482 | + } |
| 483 | + } |
| 484 | + } |
| 485 | + }, |
| 486 | + }; |
| 487 | + |
| 488 | + assert!( |
| 489 | + !found_duplicate.load(Ordering::SeqCst), |
| 490 | + "The main sync should not have caused a duplicate one-time key" |
| 491 | + ); |
| 492 | + |
| 493 | + mock_sync( |
| 494 | + &server, |
| 495 | + json!({ |
| 496 | + "next_batch": "foo", |
| 497 | + "device_one_time_keys_count": { |
| 498 | + "signed_curve25519": 49 |
| 499 | + } |
| 500 | + }), |
| 501 | + None, |
| 502 | + ) |
| 503 | + .await; |
| 504 | + |
| 505 | + tracing::info!("The notification client now syncs and tries to upload some one-time keys"); |
| 506 | + |
| 507 | + notification_client |
| 508 | + .sync_once(Default::default()) |
| 509 | + .await |
| 510 | + .expect("The notification client should be able to sync successfully"); |
| 511 | + |
| 512 | + tracing::info!("Back to the main sync"); |
| 513 | + |
| 514 | + sliding_sync_then_assert_request_and_fake_response! { |
| 515 | + [server, stream] |
| 516 | + assert request = { |
| 517 | + "conn_id": "encryption", |
| 518 | + "extensions": { |
| 519 | + "to_device": { |
| 520 | + "since": "foo", |
| 521 | + }, |
| 522 | + } |
| 523 | + }, |
| 524 | + respond with = { |
| 525 | + "pos": "2", |
| 526 | + "extensions": { |
| 527 | + "to_device": { |
| 528 | + "next_batch": "nb4" |
| 529 | + }, |
| 530 | + "e2ee": { |
| 531 | + "device_one_time_keys_count": { |
| 532 | + "signed_curve25519": 49 |
| 533 | + } |
| 534 | + } |
| 535 | + } |
| 536 | + }, |
| 537 | + }; |
| 538 | + |
| 539 | + sliding_sync_then_assert_request_and_fake_response! { |
| 540 | + [server, stream] |
| 541 | + assert request = { |
| 542 | + "conn_id": "encryption", |
| 543 | + "extensions": { |
| 544 | + "to_device": { |
| 545 | + "since": "nb4", |
| 546 | + }, |
| 547 | + } |
| 548 | + }, |
| 549 | + respond with = { |
| 550 | + "pos": "2", |
| 551 | + "extensions": { |
| 552 | + "to_device": { |
| 553 | + "next_batch": "nb5" |
| 554 | + }, |
| 555 | + } |
| 556 | + }, |
| 557 | + }; |
| 558 | + |
| 559 | + assert!( |
| 560 | + !found_duplicate.load(Ordering::SeqCst), |
| 561 | + "Duplicate one-time keys should not have been created" |
| 562 | + ); |
| 563 | + |
| 564 | + server.verify().await; |
| 565 | + |
| 566 | + Ok(()) |
| 567 | +} |
0 commit comments