@@ -55,6 +55,7 @@ pub struct MemoryStore {
55
55
sessions : SessionStore ,
56
56
inbound_group_sessions : GroupSessionStore ,
57
57
outbound_group_sessions : StdRwLock < Vec < OutboundGroupSession > > ,
58
+ private_identity : StdRwLock < Option < PrivateCrossSigningIdentity > > ,
58
59
tracked_users : StdRwLock < Vec < TrackedUser > > ,
59
60
olm_hashes : StdRwLock < HashMap < String , HashSet < String > > > ,
60
61
devices : DeviceStore ,
@@ -77,6 +78,7 @@ impl Default for MemoryStore {
77
78
sessions : SessionStore :: new ( ) ,
78
79
inbound_group_sessions : GroupSessionStore :: new ( ) ,
79
80
outbound_group_sessions : Default :: default ( ) ,
81
+ private_identity : Default :: default ( ) ,
80
82
tracked_users : Default :: default ( ) ,
81
83
olm_hashes : Default :: default ( ) ,
82
84
devices : DeviceStore :: new ( ) ,
@@ -127,6 +129,10 @@ impl MemoryStore {
127
129
fn save_outbound_group_sessions ( & self , mut sessions : Vec < OutboundGroupSession > ) {
128
130
self . outbound_group_sessions . write ( ) . unwrap ( ) . append ( & mut sessions) ;
129
131
}
132
+
133
+ fn save_private_identity ( & self , private_identity : Option < PrivateCrossSigningIdentity > ) {
134
+ * self . private_identity . write ( ) . unwrap ( ) = private_identity;
135
+ }
130
136
}
131
137
132
138
type Result < T > = std:: result:: Result < T , Infallible > ;
@@ -141,7 +147,7 @@ impl CryptoStore for MemoryStore {
141
147
}
142
148
143
149
async fn load_identity ( & self ) -> Result < Option < PrivateCrossSigningIdentity > > {
144
- Ok ( None )
150
+ Ok ( self . private_identity . read ( ) . unwrap ( ) . clone ( ) )
145
151
}
146
152
147
153
async fn next_batch_token ( & self ) -> Result < Option < String > > {
@@ -160,6 +166,7 @@ impl CryptoStore for MemoryStore {
160
166
self . save_sessions ( changes. sessions ) . await ;
161
167
self . save_inbound_group_sessions ( changes. inbound_group_sessions ) ;
162
168
self . save_outbound_group_sessions ( changes. outbound_group_sessions ) ;
169
+ self . save_private_identity ( changes. private_identity ) ;
163
170
164
171
self . save_devices ( changes. devices . new ) ;
165
172
self . save_devices ( changes. devices . changed ) ;
@@ -485,7 +492,10 @@ mod tests {
485
492
486
493
use crate :: {
487
494
identities:: device:: testing:: get_device,
488
- olm:: { tests:: get_account_and_session_test_helper, InboundGroupSession , OlmMessageHash } ,
495
+ olm:: {
496
+ tests:: get_account_and_session_test_helper, InboundGroupSession , OlmMessageHash ,
497
+ PrivateCrossSigningIdentity ,
498
+ } ,
489
499
store:: { memorystore:: MemoryStore , Changes , CryptoStore , PendingChanges } ,
490
500
} ;
491
501
@@ -571,6 +581,22 @@ mod tests {
571
581
assert_eq ! ( loaded_tracked_users. len( ) , 2 ) ;
572
582
}
573
583
584
+ #[ async_test]
585
+ async fn test_private_identity_store ( ) {
586
+ // Given a private identity
587
+ let private_identity = PrivateCrossSigningIdentity :: empty ( user_id ! ( "@u:s" ) ) ;
588
+
589
+ // When we save it to the store
590
+ let store = MemoryStore :: new ( ) ;
591
+ store. save_private_identity ( Some ( private_identity. clone ( ) ) ) ;
592
+
593
+ // Then we can get it out again
594
+ let loaded_identity =
595
+ store. load_identity ( ) . await . expect ( "failed to load private identity" ) . unwrap ( ) ;
596
+
597
+ assert_eq ! ( loaded_identity. user_id( ) , user_id!( "@u:s" ) ) ;
598
+ }
599
+
574
600
#[ async_test]
575
601
async fn test_device_store ( ) {
576
602
let device = get_device ( ) ;
0 commit comments