@@ -634,3 +634,265 @@ mod tests {
634
634
assert ! ( store. is_message_known( & hash) . await . unwrap( ) ) ;
635
635
}
636
636
}
637
+
638
+ #[ cfg( test) ]
639
+ mod integration_tests {
640
+ use std:: {
641
+ collections:: HashMap ,
642
+ sync:: { Arc , Mutex , OnceLock } ,
643
+ } ;
644
+
645
+ use async_trait:: async_trait;
646
+ use ruma:: {
647
+ events:: secret:: request:: SecretName , DeviceId , OwnedDeviceId , RoomId , TransactionId , UserId ,
648
+ } ;
649
+
650
+ use super :: MemoryStore ;
651
+ use crate :: {
652
+ cryptostore_integration_tests, cryptostore_integration_tests_time,
653
+ olm:: {
654
+ InboundGroupSession , OlmMessageHash , OutboundGroupSession , PrivateCrossSigningIdentity ,
655
+ StaticAccountData ,
656
+ } ,
657
+ store:: { BackupKeys , Changes , CryptoStore , PendingChanges , RoomKeyCounts , RoomSettings } ,
658
+ types:: events:: room_key_withheld:: RoomKeyWithheldEvent ,
659
+ Account , GossipRequest , GossippedSecret , ReadOnlyDevice , ReadOnlyUserIdentities ,
660
+ SecretInfo , Session , TrackedUser ,
661
+ } ;
662
+
663
+ /// Holds on to a MemoryStore during a test, and moves it back into STORES
664
+ /// when this is dropped
665
+ #[ derive( Clone , Debug ) ]
666
+ struct PersistentMemoryStore ( Arc < MemoryStore > ) ;
667
+
668
+ impl PersistentMemoryStore {
669
+ fn new ( ) -> Self {
670
+ Self ( Arc :: new ( MemoryStore :: new ( ) ) )
671
+ }
672
+
673
+ fn get_static_account ( & self ) -> Option < StaticAccountData > {
674
+ self . 0 . get_static_account ( )
675
+ }
676
+ }
677
+
678
+ impl MemoryStore {
679
+ fn get_static_account ( & self ) -> Option < StaticAccountData > {
680
+ self . account . read ( ) . unwrap ( ) . as_ref ( ) . map ( |acc| acc. static_data ( ) . clone ( ) )
681
+ }
682
+ }
683
+
684
+ /// Return a clone of the store for the test with the supplied name. Note:
685
+ /// dropping this store won't destroy its data, since
686
+ /// [PersistentMemoryStore] is a reference-counted smart pointer
687
+ /// to an underlying [MemoryStore].
688
+ async fn get_store ( name : & str , _passphrase : Option < & str > ) -> PersistentMemoryStore {
689
+ // Holds on to one [PersistentMemoryStore] per test, so even if the test drops
690
+ // the store, we keep its data alive. This simulates the behaviour of
691
+ // the other stores, which keep their data in a real DB, allowing us to
692
+ // test MemoryStore using the same code.
693
+ static STORES : OnceLock < Mutex < HashMap < String , PersistentMemoryStore > > > = OnceLock :: new ( ) ;
694
+ let stores = STORES . get_or_init ( || Mutex :: new ( HashMap :: new ( ) ) ) ;
695
+
696
+ stores
697
+ . lock ( )
698
+ . unwrap ( )
699
+ . entry ( name. to_owned ( ) )
700
+ . or_insert_with ( PersistentMemoryStore :: new)
701
+ . clone ( )
702
+ }
703
+
704
+ /// Forwards all methods to the underlying [MemoryStore].
705
+ #[ async_trait]
706
+ impl CryptoStore for PersistentMemoryStore {
707
+ type Error = <MemoryStore as CryptoStore >:: Error ;
708
+
709
+ async fn load_account ( & self ) -> Result < Option < Account > , Self :: Error > {
710
+ self . 0 . load_account ( ) . await
711
+ }
712
+
713
+ async fn load_identity ( & self ) -> Result < Option < PrivateCrossSigningIdentity > , Self :: Error > {
714
+ self . 0 . load_identity ( ) . await
715
+ }
716
+
717
+ async fn save_changes ( & self , changes : Changes ) -> Result < ( ) , Self :: Error > {
718
+ self . 0 . save_changes ( changes) . await
719
+ }
720
+
721
+ async fn save_pending_changes ( & self , changes : PendingChanges ) -> Result < ( ) , Self :: Error > {
722
+ self . 0 . save_pending_changes ( changes) . await
723
+ }
724
+
725
+ async fn get_sessions (
726
+ & self ,
727
+ sender_key : & str ,
728
+ ) -> Result < Option < Arc < tokio:: sync:: Mutex < Vec < Session > > > > , Self :: Error > {
729
+ self . 0 . get_sessions ( sender_key) . await
730
+ }
731
+
732
+ async fn get_inbound_group_session (
733
+ & self ,
734
+ room_id : & RoomId ,
735
+ session_id : & str ,
736
+ ) -> Result < Option < InboundGroupSession > , Self :: Error > {
737
+ self . 0 . get_inbound_group_session ( room_id, session_id) . await
738
+ }
739
+
740
+ async fn get_withheld_info (
741
+ & self ,
742
+ room_id : & RoomId ,
743
+ session_id : & str ,
744
+ ) -> Result < Option < RoomKeyWithheldEvent > , Self :: Error > {
745
+ self . 0 . get_withheld_info ( room_id, session_id) . await
746
+ }
747
+
748
+ async fn get_inbound_group_sessions (
749
+ & self ,
750
+ ) -> Result < Vec < InboundGroupSession > , Self :: Error > {
751
+ self . 0 . get_inbound_group_sessions ( ) . await
752
+ }
753
+
754
+ async fn inbound_group_session_counts ( & self ) -> Result < RoomKeyCounts , Self :: Error > {
755
+ self . 0 . inbound_group_session_counts ( ) . await
756
+ }
757
+
758
+ async fn inbound_group_sessions_for_backup (
759
+ & self ,
760
+ limit : usize ,
761
+ ) -> Result < Vec < InboundGroupSession > , Self :: Error > {
762
+ self . 0 . inbound_group_sessions_for_backup ( limit) . await
763
+ }
764
+
765
+ async fn mark_inbound_group_sessions_as_backed_up (
766
+ & self ,
767
+ room_and_session_ids : & [ ( & RoomId , & str ) ] ,
768
+ ) -> Result < ( ) , Self :: Error > {
769
+ self . 0 . mark_inbound_group_sessions_as_backed_up ( room_and_session_ids) . await
770
+ }
771
+
772
+ async fn reset_backup_state ( & self ) -> Result < ( ) , Self :: Error > {
773
+ self . 0 . reset_backup_state ( ) . await
774
+ }
775
+
776
+ async fn load_backup_keys ( & self ) -> Result < BackupKeys , Self :: Error > {
777
+ self . 0 . load_backup_keys ( ) . await
778
+ }
779
+
780
+ async fn get_outbound_group_session (
781
+ & self ,
782
+ room_id : & RoomId ,
783
+ ) -> Result < Option < OutboundGroupSession > , Self :: Error > {
784
+ self . 0 . get_outbound_group_session ( room_id) . await
785
+ }
786
+
787
+ async fn load_tracked_users ( & self ) -> Result < Vec < TrackedUser > , Self :: Error > {
788
+ self . 0 . load_tracked_users ( ) . await
789
+ }
790
+
791
+ async fn save_tracked_users ( & self , users : & [ ( & UserId , bool ) ] ) -> Result < ( ) , Self :: Error > {
792
+ self . 0 . save_tracked_users ( users) . await
793
+ }
794
+
795
+ async fn get_device (
796
+ & self ,
797
+ user_id : & UserId ,
798
+ device_id : & DeviceId ,
799
+ ) -> Result < Option < ReadOnlyDevice > , Self :: Error > {
800
+ self . 0 . get_device ( user_id, device_id) . await
801
+ }
802
+
803
+ async fn get_user_devices (
804
+ & self ,
805
+ user_id : & UserId ,
806
+ ) -> Result < HashMap < OwnedDeviceId , ReadOnlyDevice > , Self :: Error > {
807
+ self . 0 . get_user_devices ( user_id) . await
808
+ }
809
+
810
+ async fn get_user_identity (
811
+ & self ,
812
+ user_id : & UserId ,
813
+ ) -> Result < Option < ReadOnlyUserIdentities > , Self :: Error > {
814
+ self . 0 . get_user_identity ( user_id) . await
815
+ }
816
+
817
+ async fn is_message_known (
818
+ & self ,
819
+ message_hash : & OlmMessageHash ,
820
+ ) -> Result < bool , Self :: Error > {
821
+ self . 0 . is_message_known ( message_hash) . await
822
+ }
823
+
824
+ async fn get_outgoing_secret_requests (
825
+ & self ,
826
+ request_id : & TransactionId ,
827
+ ) -> Result < Option < GossipRequest > , Self :: Error > {
828
+ self . 0 . get_outgoing_secret_requests ( request_id) . await
829
+ }
830
+
831
+ async fn get_secret_request_by_info (
832
+ & self ,
833
+ secret_info : & SecretInfo ,
834
+ ) -> Result < Option < GossipRequest > , Self :: Error > {
835
+ self . 0 . get_secret_request_by_info ( secret_info) . await
836
+ }
837
+
838
+ async fn get_unsent_secret_requests ( & self ) -> Result < Vec < GossipRequest > , Self :: Error > {
839
+ self . 0 . get_unsent_secret_requests ( ) . await
840
+ }
841
+
842
+ async fn delete_outgoing_secret_requests (
843
+ & self ,
844
+ request_id : & TransactionId ,
845
+ ) -> Result < ( ) , Self :: Error > {
846
+ self . 0 . delete_outgoing_secret_requests ( request_id) . await
847
+ }
848
+
849
+ async fn get_secrets_from_inbox (
850
+ & self ,
851
+ secret_name : & SecretName ,
852
+ ) -> Result < Vec < GossippedSecret > , Self :: Error > {
853
+ self . 0 . get_secrets_from_inbox ( secret_name) . await
854
+ }
855
+
856
+ async fn delete_secrets_from_inbox (
857
+ & self ,
858
+ secret_name : & SecretName ,
859
+ ) -> Result < ( ) , Self :: Error > {
860
+ self . 0 . delete_secrets_from_inbox ( secret_name) . await
861
+ }
862
+
863
+ async fn get_room_settings (
864
+ & self ,
865
+ room_id : & RoomId ,
866
+ ) -> Result < Option < RoomSettings > , Self :: Error > {
867
+ self . 0 . get_room_settings ( room_id) . await
868
+ }
869
+
870
+ async fn get_custom_value ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Self :: Error > {
871
+ self . 0 . get_custom_value ( key) . await
872
+ }
873
+
874
+ async fn set_custom_value ( & self , key : & str , value : Vec < u8 > ) -> Result < ( ) , Self :: Error > {
875
+ self . 0 . set_custom_value ( key, value) . await
876
+ }
877
+
878
+ async fn remove_custom_value ( & self , key : & str ) -> Result < ( ) , Self :: Error > {
879
+ self . 0 . remove_custom_value ( key) . await
880
+ }
881
+
882
+ async fn try_take_leased_lock (
883
+ & self ,
884
+ lease_duration_ms : u32 ,
885
+ key : & str ,
886
+ holder : & str ,
887
+ ) -> Result < bool , Self :: Error > {
888
+ self . 0 . try_take_leased_lock ( lease_duration_ms, key, holder) . await
889
+ }
890
+
891
+ async fn next_batch_token ( & self ) -> Result < Option < String > , Self :: Error > {
892
+ self . 0 . next_batch_token ( ) . await
893
+ }
894
+ }
895
+
896
+ cryptostore_integration_tests ! ( ) ;
897
+ cryptostore_integration_tests_time ! ( ) ;
898
+ }
0 commit comments