diff --git a/crates/sargon/src/system/sargon_os/profile_state_holder.rs b/crates/sargon/src/system/sargon_os/profile_state_holder.rs index a93a8acf1..1787a74cc 100644 --- a/crates/sargon/src/system/sargon_os/profile_state_holder.rs +++ b/crates/sargon/src/system/sargon_os/profile_state_holder.rs @@ -94,10 +94,9 @@ impl ProfileStateHolder { where F: Fn(&Profile) -> T, { - let guard = self - .profile_state - .try_read() - .expect("Implementing hosts should not read and write Profile from multiple threads."); + let guard = self.profile_state.read().expect( + "Stop execution due to the profile state lock being poisoned", + ); let state = &*guard; match state { @@ -113,10 +112,9 @@ impl ProfileStateHolder { where F: Fn(&Profile) -> Result, { - let guard = self - .profile_state - .try_read() - .expect("Implementing hosts should not read and write Profile from multiple threads."); + let guard = self.profile_state.read().expect( + "Stop execution due to the profile state lock being poisoned", + ); let state = &*guard; match state { @@ -133,10 +131,9 @@ impl ProfileStateHolder { &self, profile_state: ProfileState, ) -> Result<()> { - let mut lock = self - .profile_state - .try_write() - .map_err(|_| CommonError::UnableToAcquireWriteLockForProfile)?; + let mut lock = self.profile_state.write().expect( + "Stop execution due to the profile state lock being poisoned", + ); *lock = profile_state; Ok(()) @@ -149,25 +146,27 @@ impl ProfileStateHolder { where F: Fn(&mut Profile) -> Result, { - self.profile_state - .try_write() - .map_err(|_| CommonError::UnableToAcquireWriteLockForProfile) - .and_then(|mut guard| { - let state = &mut *guard; - - match state { - ProfileState::Loaded(ref mut profile) => mutate(profile), - _ => Err(CommonError::ProfileStateNotLoaded { - current_state: state.to_string(), - }), - } - }) + let mut guard = self.profile_state.write().expect( + "Stop execution due to the profile state lock being poisoned", + ); + + let state = &mut *guard; + + match state { + ProfileState::Loaded(ref mut profile) => mutate(profile), + _ => Err(CommonError::ProfileStateNotLoaded { + current_state: state.to_string(), + }), + } } } #[cfg(test)] mod tests { use crate::prelude::*; + use std::sync::{Arc, RwLock}; + use std::thread; + use std::time::Duration; #[test] fn test_new_none_profile_state_holder() { @@ -210,4 +209,150 @@ mod tests { state, ) } + + #[test] + fn test_concurrent_access_read_after_write() { + let state = ProfileState::Loaded(Profile::sample()); + let sut = ProfileStateHolder::new(state.clone()); + let state_holder = Arc::new(sut); + + let state_holder_clone = Arc::clone(&state_holder); + + // Spawn a thread that acquires a write lock + let handle = thread::spawn(move || { + let _write_lock = + state_holder_clone.update_profile_with(|profile| { + profile.networks.try_update_with( + &NetworkID::Mainnet, + |network| { + let _res = network.accounts.try_insert_unique( + Account::sample_mainnet_carol(), + ); + }, + ) + }); + thread::sleep(Duration::from_millis(200)); + }); + + // Give the other thread time to acquire the write lock + thread::sleep(Duration::from_millis(100)); + + let mainnet_accounts = state_holder.current_network().unwrap().accounts; + + handle.join().unwrap(); // Wait for the thread to finish + + let mut expected_accounts = Accounts::sample_mainnet(); + expected_accounts.insert(Account::sample_mainnet_carol()); + pretty_assertions::assert_eq!(mainnet_accounts, expected_accounts) + } + + #[test] + fn test_concurrent_access_writes_order_is_preserved() { + let profile = Profile::sample(); + let state = ProfileState::Loaded(profile); + let sut = ProfileStateHolder::new(state.clone()); + let state_holder = Arc::new(sut); + + let first_mainnet_account = state_holder + .access_profile_with(|profile| { + profile + .networks + .first() + .unwrap() + .accounts + .first() + .unwrap() + .clone() + }) + .unwrap(); + + let mut handles = vec![]; + + for i in 0..5 { + let state_holder_clone = Arc::clone(&state_holder); + let handle = thread::spawn(move || { + let _write_lock = + state_holder_clone.update_profile_with(|profile| { + profile.networks.try_update_with( + &NetworkID::Mainnet, + |network| { + let _res = network.accounts.try_update_with( + &first_mainnet_account.address, + |account| { + let display_name = + account.display_name.value.clone(); + account.display_name = DisplayName::new( + display_name + + i.to_string().as_str(), + ) + .unwrap() + }, + ); + }, + ) + }); + // Hold the lock for a while to simulate a long-running write operation + thread::sleep(Duration::from_millis(200)); + }); + thread::sleep(Duration::from_millis(100)); + + handles.push(handle); + } + + for handle in handles { + handle.join().unwrap(); + } + + let result_name = state_holder + .access_profile_with(|profile| { + profile + .networks + .first() + .unwrap() + .accounts + .first() + .unwrap() + .display_name + .value + .clone() + }) + .unwrap(); + + let expected_name = first_mainnet_account.display_name.value + "01234"; + + pretty_assertions::assert_eq!(expected_name, result_name) + } + + #[test] + #[should_panic] + fn test_concurrent_access_poisoned_lock_panics() { + let state = ProfileState::Loaded(Profile::sample()); + let sut = ProfileStateHolder::new(state.clone()); + let state_holder = Arc::new(sut); + + let state_holder_clone = Arc::clone(&state_holder); + + // Spawn a thread that acquires a write lock + let handle = thread::spawn(move || { + let _write_lock = + state_holder_clone.update_profile_with(|profile| { + profile.networks.try_update_with( + &NetworkID::Mainnet, + |network| { + let _res = network + .accounts + .try_insert_unique( + Account::sample_mainnet_carol(), + ) + .unwrap(); + panic!("Simulate panic in thread"); + }, + ) + }); + }); + + let _ = handle.join(); // Wait for the thread to finish + + state_holder.current_network().unwrap(); + } }