@@ -634,3 +634,263 @@ mod tests {
634
634
assert ! ( store. is_message_known( & hash) . await . unwrap( ) ) ;
635
635
}
636
636
}
637
+
638
+ #[ cfg( test) ]
639
+ mod integration_tests {
640
+ use std:: {
641
+ collections:: HashMap ,
642
+ sync:: { Arc , Mutex , OnceLock } ,
643
+ } ;
644
+
645
+ use async_trait:: async_trait;
646
+ use ruma:: {
647
+ events:: secret:: request:: SecretName , DeviceId , OwnedDeviceId , RoomId , TransactionId , UserId ,
648
+ } ;
649
+
650
+ use super :: MemoryStore ;
651
+ use crate :: {
652
+ cryptostore_integration_tests, cryptostore_integration_tests_time,
653
+ olm:: {
654
+ InboundGroupSession , OlmMessageHash , OutboundGroupSession , PrivateCrossSigningIdentity ,
655
+ StaticAccountData ,
656
+ } ,
657
+ store:: { BackupKeys , Changes , CryptoStore , PendingChanges , RoomKeyCounts , RoomSettings } ,
658
+ types:: events:: room_key_withheld:: RoomKeyWithheldEvent ,
659
+ Account , GossipRequest , GossippedSecret , ReadOnlyDevice , ReadOnlyUserIdentities ,
660
+ SecretInfo , Session , TrackedUser ,
661
+ } ;
662
+
663
+ /// Holds on to a MemoryStore during a test, and moves it back into STORES
664
+ /// when this is dropped
665
+ #[ derive( Clone , Debug ) ]
666
+ struct PersistentMemoryStore ( Arc < MemoryStore > ) ;
667
+
668
+ impl PersistentMemoryStore {
669
+ fn new ( ) -> Self {
670
+ Self ( Arc :: new ( MemoryStore :: new ( ) ) )
671
+ }
672
+
673
+ fn get_static_account ( & self ) -> Option < StaticAccountData > {
674
+ self . 0 . get_static_account ( )
675
+ }
676
+ }
677
+
678
+ impl MemoryStore {
679
+ fn get_static_account ( & self ) -> Option < StaticAccountData > {
680
+ self . account . read ( ) . unwrap ( ) . as_ref ( ) . map ( |acc| acc. static_data ( ) . clone ( ) )
681
+ }
682
+ }
683
+
684
+ /// Return a clone of the store for the test with the supplied name. Note: dropping this store
685
+ /// won't destroy its data, since [PersistentMemoryStore] is a reference-counted smart pointer
686
+ /// to an underlying [MemoryStore].
687
+ async fn get_store ( name : & str , _passphrase : Option < & str > ) -> PersistentMemoryStore {
688
+ // Holds on to one [PersistentMemoryStore] per test, so even if the test drops the store, we
689
+ // keep its data alive. This simulates the behaviour of the other stores, which keep their
690
+ // data in a real DB, allowing us to test MemoryStore using the same code.
691
+ static STORES : OnceLock < Mutex < HashMap < String , PersistentMemoryStore > > > = OnceLock :: new ( ) ;
692
+ let stores = STORES . get_or_init ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
693
+
694
+ stores
695
+ . lock ( )
696
+ . unwrap ( )
697
+ . entry ( name. to_owned ( ) )
698
+ . or_insert_with ( || PersistentMemoryStore :: new ( ) )
699
+ . clone ( )
700
+ }
701
+
702
+ /// Forwards all methods to the underlying [MemoryStore].
703
+ #[ async_trait]
704
+ impl CryptoStore for PersistentMemoryStore {
705
+ type Error = <MemoryStore as CryptoStore >:: Error ;
706
+
707
+ async fn load_account ( & self ) -> Result < Option < Account > , Self :: Error > {
708
+ self . 0 . load_account ( ) . await
709
+ }
710
+
711
+ async fn load_identity ( & self ) -> Result < Option < PrivateCrossSigningIdentity > , Self :: Error > {
712
+ self . 0 . load_identity ( ) . await
713
+ }
714
+
715
+ async fn save_changes ( & self , changes : Changes ) -> Result < ( ) , Self :: Error > {
716
+ self . 0 . save_changes ( changes) . await
717
+ }
718
+
719
+ async fn save_pending_changes ( & self , changes : PendingChanges ) -> Result < ( ) , Self :: Error > {
720
+ self . 0 . save_pending_changes ( changes) . await
721
+ }
722
+
723
+ async fn get_sessions (
724
+ & self ,
725
+ sender_key : & str ,
726
+ ) -> Result < Option < Arc < tokio:: sync:: Mutex < Vec < Session > > > > , Self :: Error > {
727
+ self . 0 . get_sessions ( sender_key) . await
728
+ }
729
+
730
+ async fn get_inbound_group_session (
731
+ & self ,
732
+ room_id : & RoomId ,
733
+ session_id : & str ,
734
+ ) -> Result < Option < InboundGroupSession > , Self :: Error > {
735
+ self . 0 . get_inbound_group_session ( room_id, session_id) . await
736
+ }
737
+
738
+ async fn get_withheld_info (
739
+ & self ,
740
+ room_id : & RoomId ,
741
+ session_id : & str ,
742
+ ) -> Result < Option < RoomKeyWithheldEvent > , Self :: Error > {
743
+ self . 0 . get_withheld_info ( room_id, session_id) . await
744
+ }
745
+
746
+ async fn get_inbound_group_sessions (
747
+ & self ,
748
+ ) -> Result < Vec < InboundGroupSession > , Self :: Error > {
749
+ self . 0 . get_inbound_group_sessions ( ) . await
750
+ }
751
+
752
+ async fn inbound_group_session_counts ( & self ) -> Result < RoomKeyCounts , Self :: Error > {
753
+ self . 0 . inbound_group_session_counts ( ) . await
754
+ }
755
+
756
+ async fn inbound_group_sessions_for_backup (
757
+ & self ,
758
+ limit : usize ,
759
+ ) -> Result < Vec < InboundGroupSession > , Self :: Error > {
760
+ self . 0 . inbound_group_sessions_for_backup ( limit) . await
761
+ }
762
+
763
+ async fn mark_inbound_group_sessions_as_backed_up (
764
+ & self ,
765
+ room_and_session_ids : & [ ( & RoomId , & str ) ] ,
766
+ ) -> Result < ( ) , Self :: Error > {
767
+ self . 0 . mark_inbound_group_sessions_as_backed_up ( room_and_session_ids) . await
768
+ }
769
+
770
+ async fn reset_backup_state ( & self ) -> Result < ( ) , Self :: Error > {
771
+ self . 0 . reset_backup_state ( ) . await
772
+ }
773
+
774
+ async fn load_backup_keys ( & self ) -> Result < BackupKeys , Self :: Error > {
775
+ self . 0 . load_backup_keys ( ) . await
776
+ }
777
+
778
+ async fn get_outbound_group_session (
779
+ & self ,
780
+ room_id : & RoomId ,
781
+ ) -> Result < Option < OutboundGroupSession > , Self :: Error > {
782
+ self . 0 . get_outbound_group_session ( room_id) . await
783
+ }
784
+
785
+ async fn load_tracked_users ( & self ) -> Result < Vec < TrackedUser > , Self :: Error > {
786
+ self . 0 . load_tracked_users ( ) . await
787
+ }
788
+
789
+ async fn save_tracked_users ( & self , users : & [ ( & UserId , bool ) ] ) -> Result < ( ) , Self :: Error > {
790
+ self . 0 . save_tracked_users ( users) . await
791
+ }
792
+
793
+ async fn get_device (
794
+ & self ,
795
+ user_id : & UserId ,
796
+ device_id : & DeviceId ,
797
+ ) -> Result < Option < ReadOnlyDevice > , Self :: Error > {
798
+ self . 0 . get_device ( user_id, device_id) . await
799
+ }
800
+
801
+ async fn get_user_devices (
802
+ & self ,
803
+ user_id : & UserId ,
804
+ ) -> Result < HashMap < OwnedDeviceId , ReadOnlyDevice > , Self :: Error > {
805
+ self . 0 . get_user_devices ( user_id) . await
806
+ }
807
+
808
+ async fn get_user_identity (
809
+ & self ,
810
+ user_id : & UserId ,
811
+ ) -> Result < Option < ReadOnlyUserIdentities > , Self :: Error > {
812
+ self . 0 . get_user_identity ( user_id) . await
813
+ }
814
+
815
+ async fn is_message_known (
816
+ & self ,
817
+ message_hash : & OlmMessageHash ,
818
+ ) -> Result < bool , Self :: Error > {
819
+ self . 0 . is_message_known ( message_hash) . await
820
+ }
821
+
822
+ async fn get_outgoing_secret_requests (
823
+ & self ,
824
+ request_id : & TransactionId ,
825
+ ) -> Result < Option < GossipRequest > , Self :: Error > {
826
+ self . 0 . get_outgoing_secret_requests ( request_id) . await
827
+ }
828
+
829
+ async fn get_secret_request_by_info (
830
+ & self ,
831
+ secret_info : & SecretInfo ,
832
+ ) -> Result < Option < GossipRequest > , Self :: Error > {
833
+ self . 0 . get_secret_request_by_info ( secret_info) . await
834
+ }
835
+
836
+ async fn get_unsent_secret_requests ( & self ) -> Result < Vec < GossipRequest > , Self :: Error > {
837
+ self . 0 . get_unsent_secret_requests ( ) . await
838
+ }
839
+
840
+ async fn delete_outgoing_secret_requests (
841
+ & self ,
842
+ request_id : & TransactionId ,
843
+ ) -> Result < ( ) , Self :: Error > {
844
+ self . 0 . delete_outgoing_secret_requests ( request_id) . await
845
+ }
846
+
847
+ async fn get_secrets_from_inbox (
848
+ & self ,
849
+ secret_name : & SecretName ,
850
+ ) -> Result < Vec < GossippedSecret > , Self :: Error > {
851
+ self . 0 . get_secrets_from_inbox ( secret_name) . await
852
+ }
853
+
854
+ async fn delete_secrets_from_inbox (
855
+ & self ,
856
+ secret_name : & SecretName ,
857
+ ) -> Result < ( ) , Self :: Error > {
858
+ self . 0 . delete_secrets_from_inbox ( secret_name) . await
859
+ }
860
+
861
+ async fn get_room_settings (
862
+ & self ,
863
+ room_id : & RoomId ,
864
+ ) -> Result < Option < RoomSettings > , Self :: Error > {
865
+ self . 0 . get_room_settings ( room_id) . await
866
+ }
867
+
868
+ async fn get_custom_value ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Self :: Error > {
869
+ self . 0 . get_custom_value ( key) . await
870
+ }
871
+
872
+ async fn set_custom_value ( & self , key : & str , value : Vec < u8 > ) -> Result < ( ) , Self :: Error > {
873
+ self . 0 . set_custom_value ( key, value) . await
874
+ }
875
+
876
+ async fn remove_custom_value ( & self , key : & str ) -> Result < ( ) , Self :: Error > {
877
+ self . 0 . remove_custom_value ( key) . await
878
+ }
879
+
880
+ async fn try_take_leased_lock (
881
+ & self ,
882
+ lease_duration_ms : u32 ,
883
+ key : & str ,
884
+ holder : & str ,
885
+ ) -> Result < bool , Self :: Error > {
886
+ self . 0 . try_take_leased_lock ( lease_duration_ms, key, holder) . await
887
+ }
888
+
889
+ async fn next_batch_token ( & self ) -> Result < Option < String > , Self :: Error > {
890
+ self . 0 . next_batch_token ( ) . await
891
+ }
892
+ }
893
+
894
+ cryptostore_integration_tests ! ( ) ;
895
+ cryptostore_integration_tests_time ! ( ) ;
896
+ }
0 commit comments