diff --git a/teos/proto/teos/v2/appointment.proto b/teos/proto/teos/v2/appointment.proto index ecb873ee..07aab6f1 100644 --- a/teos/proto/teos/v2/appointment.proto +++ b/teos/proto/teos/v2/appointment.proto @@ -4,7 +4,8 @@ package teos.v2; import "common/teos/v2/appointment.proto"; message GetAppointmentsRequest { - // Request the information of appointments with specific locator and user_id (optional) . + // Request the information of appointments with specific locator. + // If a user id is provided (optional), request only appointments belonging to that user. bytes locator = 1; optional bytes user_id = 2; diff --git a/teos/src/api/internal.rs b/teos/src/api/internal.rs index fdfd13dd..51543fbb 100644 --- a/teos/src/api/internal.rs +++ b/teos/src/api/internal.rs @@ -298,26 +298,23 @@ impl PrivateTowerServices for Arc { ) })?; - let appointments: Vec = self + let mut matching_appointments: Vec = self .watcher .get_watcher_appointments_with_locator(locator, user_id) .into_values() - .map(|appointment| appointment.inner) - .collect(); - - let mut matching_appointments: Vec = appointments - .into_iter() .map(|appointment| common_msgs::AppointmentData { appointment_data: Some( - common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()), + common_msgs::appointment_data::AppointmentData::Appointment( + appointment.inner.into(), + ), ), }) .collect(); - for (_, tracker) in self + for tracker in self .watcher .get_responder_trackers_with_locator(locator, user_id) - .into_iter() + .into_values() { matching_appointments.push(common_msgs::AppointmentData { appointment_data: Some(common_msgs::appointment_data::AppointmentData::Tracker( @@ -445,6 +442,8 @@ mod tests_private_api { use bitcoin::hashes::Hash; use bitcoin::Txid; + use rand::{self, thread_rng, Rng}; + use crate::responder::{ConfirmationStatus, TransactionTracker}; use crate::test_utils::{ create_api, generate_dummy_appointment, generate_dummy_appointment_with_user, @@ -531,6 +530,19 @@ mod tests_private_api { .into_inner(); assert!(matches!(response, msgs::GetAppointmentsResponse { .. })); + + let user_id = get_random_user_id().to_vec(); + let locator = Locator::new(get_random_tx().txid()).to_vec(); + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator, + user_id: Some(user_id), + })) + .await + .unwrap() + .into_inner(); + + assert!(matches!(response, msgs::GetAppointmentsResponse { .. })); } #[tokio::test] @@ -544,10 +556,24 @@ mod tests_private_api { // The number of different appointments to create for this dispute tx. let appointments_to_create = 4 * i + 7; + // Create a specific user + let random_appointment_num = thread_rng().gen_range(1..appointments_to_create); + let (random_user_sk, random_user_pk) = get_random_keypair(); + let random_user_id = UserId(random_user_pk); + internal_api.watcher.register(random_user_id).unwrap(); + + let distinct_appointment_numbers = appointments_to_create - random_appointment_num + 1; + // Add that many appointments to the watcher. - for _ in 0..appointments_to_create { - let (user_sk, user_pk) = get_random_keypair(); - internal_api.watcher.register(UserId(user_pk)).unwrap(); + for i in 0..appointments_to_create { + let user_sk = if i < random_appointment_num { + random_user_sk + } else { + let (user_sk, user_pk) = get_random_keypair(); + let user_id = UserId(user_pk); + internal_api.watcher.register(user_id).unwrap(); + user_sk + }; let appointment = generate_dummy_appointment(Some(&dispute_txid)).inner; let signature = cryptography::sign(&appointment.to_vec(), &user_sk).unwrap(); internal_api @@ -556,9 +582,10 @@ mod tests_private_api { .unwrap(); } + // Add that many appointments to the watcher. let locator = Locator::new(dispute_txid); - // Query for the current locator and assert it retrieves correct appointments. + // Query for the current locator without the optional user_id let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), @@ -568,8 +595,10 @@ mod tests_private_api { .unwrap() .into_inner(); - // The response should contain `appointments_to_create` appointments, all having the locator of the current iteration. - assert_eq!(response.appointments.len(), appointments_to_create); + // Verify that the response contain only distinct appointments. + assert_eq!(response.appointments.len(), distinct_appointment_numbers); + + // Verify that all appointments have the locator of the current iteration. for app_data in response.appointments { assert!(matches!( app_data.appointment_data, @@ -581,6 +610,30 @@ mod tests_private_api { )) if Locator::from_slice(app_loc).unwrap() == locator )); } + + // Query for the current locator with the optional user_id present + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.to_vec(), + user_id: Some(random_user_id.to_vec()), + })) + .await + .unwrap() + .into_inner(); + + // Verify that only a single appointment is returned + assert_eq!(response.appointments.len(), 1 ); + + // Verify that the appointment have the current locator + assert!(matches!( + response.appointments[0].appointment_data, + Some(common_msgs::appointment_data::AppointmentData::Appointment( + common_msgs::Appointment { + locator: ref app_loc, + .. + } + )) if Locator::from_slice(app_loc).unwrap() == locator + )); } } @@ -596,11 +649,19 @@ mod tests_private_api { // The number of different trackers to create for this dispute tx. let trackers_to_create = 4 * i + 7; + let random_tracker_num = thread_rng().gen_range(0..trackers_to_create); + let random_user_id = get_random_user_id(); + // Add that many trackers to the responder. - for _ in 0..trackers_to_create { + for i in 0..trackers_to_create { + let user_id = if i == random_tracker_num { + random_user_id + } else { + get_random_user_id() + }; let tracker = TransactionTracker::new( breach.clone(), - get_random_user_id(), + user_id, ConfirmationStatus::ConfirmedIn(100), ); internal_api @@ -610,7 +671,7 @@ mod tests_private_api { let locator = Locator::new(dispute_tx.txid()); - // Query for the current locator and assert it retrieves correct trackers. + // Query for the current locator without the optional user_id. let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), @@ -620,7 +681,7 @@ mod tests_private_api { .unwrap() .into_inner(); - // The response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration. + // Verify that the response should contain `trackers_to_create` trackers, all with dispute txid that matches with the locator of the current iteration. assert_eq!(response.appointments.len(), trackers_to_create); for app_data in response.appointments { assert!(matches!( @@ -633,6 +694,28 @@ mod tests_private_api { )) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator )); } + + // Query for the current locator with the optional user_id present. + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.to_vec(), + user_id: Some(random_user_id.to_vec()), + })) + .await + .unwrap() + .into_inner(); + + // Verify that only a single appointment is returned and the correct locator is found + assert_eq!(response.appointments.len(), 1); + assert!(matches!( + response.appointments[0].appointment_data, + Some(common_msgs::appointment_data::AppointmentData::Tracker( + common_msgs::Tracker { + ref dispute_txid, + .. + } + )) if Locator::new(Txid::from_slice(dispute_txid).unwrap()) == locator + )); } } diff --git a/teos/src/dbm.rs b/teos/src/dbm.rs index 3d38e348..cd900b96 100644 --- a/teos/src/dbm.rs +++ b/teos/src/dbm.rs @@ -338,13 +338,15 @@ impl DBM { "SELECT a.UUID, a.locator, a.encrypted_blob, a.to_self_delay, a.user_signature, a.start_block, a.user_id FROM appointments as a LEFT JOIN trackers as t ON a.UUID=t.UUID WHERE t.UUID IS NULL".to_string(); - // If a locator and an optional user_id were passed, filter based on it. - if let Some((_, user_id)) = locator_and_userid { + // If a locator was passed, filter based on it. + if locator_and_userid.is_some() { sql.push_str(" AND a.locator=(?1)"); - if user_id.is_some() { - sql.push_str(" AND a.user_id=(?2)"); - } - }; + } + + // If a user_id is passed, filter even more. + if locator_and_userid.is_some_and(|inner| inner.1.is_some()) { + sql.push_str(" AND a.user_id=(?2)"); + } let mut stmt = self.connection.prepare(&sql).unwrap(); @@ -611,12 +613,14 @@ impl DBM { FROM trackers as t INNER JOIN appointments as a ON t.UUID=a.UUID" .to_string(); - // If a locator and an optional user_id were passed, filter based on it. - if let Some((_, user_id)) = locator_and_userid { + // If a locator was passed, filter based on it. + if locator_and_userid.is_some() { sql.push_str(" AND a.locator=(?1)"); - if user_id.is_some() { - sql.push_str(" AND a.user_id=(?2)"); - } + } + + // If a user_id is passed, filter even more. + if locator_and_userid.is_some_and(|inner| inner.1.is_some()) { + sql.push_str(" AND a.user_id=(?2)"); } let mut stmt = self.connection.prepare(&sql).unwrap(); @@ -1201,7 +1205,7 @@ mod tests { let dispute_txid = dispute_tx.txid(); let locator = Locator::new(dispute_txid); - // create user id + // Create user id let user_id = get_random_user_id(); let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); dbm.store_user(user_id, &user).unwrap(); @@ -1212,14 +1216,14 @@ mod tests { dbm.store_appointment(uuid, &appointment).unwrap(); appointments.insert(uuid, appointment.clone()); - // create random appointments + // Create random appointments for _ in 1..11 { let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None); dbm.store_appointment(uuid, &appointment).unwrap(); appointments.insert(uuid, appointment); } - // Returns empty if no appointment matches both userid and locator + // Verify that no appointment is returned if there is not an exact match of user_id + locator assert_eq!( dbm.load_appointments(Some((locator, Some(get_random_user_id()))),), HashMap::new() @@ -1229,17 +1233,17 @@ mod tests { HashMap::new() ); - // Returns particular appointments if they match both userid and locator + // Verify that the expected appointment is returned, if the correct user_id and locator is given assert_eq!( dbm.load_appointments(Some((locator, Some(user_id))),), HashMap::from([(uuid, appointment)]) ); - // Create a tracker from existing appointment + // Create a tracker from the existing appointment let tracker = get_random_tracker(user_id, ConfirmationStatus::InMempoolSince(100)); dbm.store_tracker(uuid, &tracker).unwrap(); - // ensure that no tracker is returned + // Verify that an appointment is not returned, if it is triggered (there's a tracker for it) assert_eq!( dbm.load_appointments(Some((locator, Some(user_id))),), HashMap::new() @@ -1624,7 +1628,7 @@ mod tests { let locator = Locator::new(dispute_txid); let status = ConfirmationStatus::InMempoolSince(42); - // create user id + // Create user id let user_id = get_random_user_id(); let user = UserInfo::new(AVAILABLE_SLOTS, SUBSCRIPTION_START, SUBSCRIPTION_EXPIRY); dbm.store_user(user_id, &user).unwrap(); @@ -1637,7 +1641,7 @@ mod tests { dbm.store_tracker(uuid, &tracker).unwrap(); trackers.insert(uuid, tracker.clone()); - // create random trackers + // Create random trackers for _ in 1..11 { let (uuid, appointment) = generate_dummy_appointment_with_user(user_id, None); let tracker = get_random_tracker(user_id, status); @@ -1645,7 +1649,7 @@ mod tests { dbm.store_tracker(uuid, &tracker).unwrap(); } - // Returns empty if no tracker matches both userid and locator + // Verify that no tracker is returned if there is not an exact match of user_id + locator assert_eq!( dbm.load_trackers(Some((locator, Some(get_random_user_id()))),), HashMap::new() @@ -1655,7 +1659,7 @@ mod tests { HashMap::new() ); - // Returns particular trackers if they match both userid and locator + // Verify that the expected tracker is returned if both the correct user_id and locator are provided assert_eq!( dbm.load_trackers(Some((locator, Some(user_id))),), HashMap::from([(uuid, tracker)]) diff --git a/teos/src/watcher.rs b/teos/src/watcher.rs index 1ee78af1..8a39441e 100644 --- a/teos/src/watcher.rs +++ b/teos/src/watcher.rs @@ -423,7 +423,8 @@ impl Watcher { self.dbm.lock().unwrap().load_appointments(None) } - /// Gets all the appointments matching a specific locator and an optional user id from the [Watcher] (from the database). + /// Gets all the appointments matching a specific locator + /// If a user id is provided (optional), only the appointments matching that user are returned pub(crate) fn get_watcher_appointments_with_locator( &self, locator: Locator,