diff --git a/teos/proto/teos/v2/appointment.proto b/teos/proto/teos/v2/appointment.proto index 67c66797..0d09086d 100644 --- a/teos/proto/teos/v2/appointment.proto +++ b/teos/proto/teos/v2/appointment.proto @@ -4,9 +4,10 @@ package teos.v2; import "common/teos/v2/appointment.proto"; message GetAppointmentsRequest { - // Request the information of appointments with specific locator. + // Request the information of appointments with specific locator and user_id (optional) . bytes locator = 1; + bytes user_id = 2; } message GetAppointmentsResponse { diff --git a/teos/src/api/internal.rs b/teos/src/api/internal.rs index 38396bc0..2df4d052 100644 --- a/teos/src/api/internal.rs +++ b/teos/src/api/internal.rs @@ -2,7 +2,6 @@ use std::sync::{Arc, Condvar, Mutex}; use tonic::{Code, Request, Response, Status}; use triggered::Trigger; -use crate::extended_appointment::UUID; use crate::protos as msgs; use crate::protos::private_tower_services_server::PrivateTowerServices; use crate::protos::public_tower_services_server::PublicTowerServices; @@ -280,27 +279,40 @@ impl PrivateTowerServices for Arc { .map_or("an unknown address".to_owned(), |a| a.to_string()) ); - let mut matching_appointments = vec![]; - let locator = Locator::from_slice(&request.into_inner().locator).map_err(|_| { + let req_data = request.into_inner(); + let locator = Locator::from_slice(&req_data.locator).map_err(|_| { Status::new( Code::InvalidArgument, "The provided locator does not match the expected format (16-byte hexadecimal string)", ) })?; - for (_, appointment) in self + let mut appointments: Vec<(UserId, Appointment)> = self .watcher .get_watcher_appointments_with_locator(locator) + .into_values() + .map(|appointment| (appointment.user_id, appointment.inner)) + .collect(); + + let user_id_slice = req_data.user_id; + if !(&user_id_slice.is_empty()) { + let user_id = UserId::from_slice(&user_id_slice).map_err(|_| { + Status::new( + Code::InvalidArgument, + "The Provided user_id does not match expected format (33-byte hex string)", + ) + })?; + appointments.retain(|(appointment_user_id, _)| *appointment_user_id == user_id); + } + + let mut matching_appointments: Vec = appointments .into_iter() - { - matching_appointments.push(common_msgs::AppointmentData { + .map(|(_, appointment)| common_msgs::AppointmentData { appointment_data: Some( - common_msgs::appointment_data::AppointmentData::Appointment( - appointment.inner.into(), - ), + common_msgs::appointment_data::AppointmentData::Appointment(appointment.into()), ), }) - } + .collect(); for (_, tracker) in self .watcher @@ -511,7 +523,10 @@ mod tests_private_api { let locator = Locator::new(get_random_tx().txid()).to_vec(); let response = internal_api - .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator })) + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator, + user_id: Vec::new(), + })) .await .unwrap() .into_inner(); @@ -519,6 +534,86 @@ mod tests_private_api { assert!(matches!(response, msgs::GetAppointmentsResponse { .. })); } + #[tokio::test] + async fn test_get_appointments_with_and_without_user_id() { + // setup + let (internal_api, _s) = create_api().await; + let random_txn = get_random_tx(); + let (user_sk1, user_pk1) = get_random_keypair(); + let user_id1 = UserId(user_pk1); + let (user_sk2, user_pk2) = get_random_keypair(); + let user_id2 = UserId(user_pk2); + internal_api.watcher.register(user_id1).unwrap(); + internal_api.watcher.register(user_id2).unwrap(); + let appointment1 = + generate_dummy_appointment_with_user(user_id1, Some(&random_txn.clone().txid())) + .1 + .inner; + let signature1 = cryptography::sign(&appointment1.to_vec(), &user_sk1).unwrap(); + let appointment2 = + generate_dummy_appointment_with_user(user_id2, Some(&random_txn.clone().txid())) + .1 + .inner; + let signature2 = cryptography::sign(&appointment2.to_vec(), &user_sk2).unwrap(); + internal_api + .watcher + .add_appointment(appointment1.clone(), signature1) + .unwrap(); + internal_api + .watcher + .add_appointment(appointment2.clone(), signature2) + .unwrap(); + + let locator = &appointment1.locator; + + // returns all appointments if user_id is absent + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.clone().to_vec(), + user_id: Vec::new(), + })) + .await + .unwrap() + .into_inner(); + let dummy_appointments = response.appointments; + assert_eq!(&dummy_appointments.len(), &2); + let responses: Vec> = dummy_appointments + .into_iter() + .filter_map(|data| { + if let Some(common_msgs::appointment_data::AppointmentData::Appointment( + appointment, + )) = data.appointment_data + { + return Some(appointment.locator); + } + return None; + }) + .collect(); + assert_eq!(responses[0], locator.clone().to_vec()); + assert_eq!(responses[1], locator.clone().to_vec()); + + // returns specific appointments if user_id is absent + let response = internal_api + .get_appointments(Request::new(msgs::GetAppointmentsRequest { + locator: locator.clone().to_vec(), + user_id: user_id1.clone().to_vec(), + })) + .await + .unwrap() + .into_inner(); + let dummy_appointments = response.appointments; + assert_eq!(&dummy_appointments.len(), &1); + let dummy_appointmnets_data = &dummy_appointments[0].appointment_data; + assert!( + matches!(dummy_appointmnets_data.clone(), Some(common_msgs::appointment_data::AppointmentData::Appointment( + common_msgs::Appointment { + locator: ref app_loc, + .. + } + )) if app_loc.clone() == locator.to_vec() ) + ) + } + #[tokio::test] async fn test_get_appointments_watcher() { let (internal_api, _s) = create_api().await; @@ -548,6 +643,7 @@ mod tests_private_api { let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), + user_id: Vec::new(), })) .await .unwrap() @@ -599,6 +695,7 @@ mod tests_private_api { let response = internal_api .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), + user_id: Vec::new(), })) .await .unwrap() @@ -747,7 +844,10 @@ mod tests_private_api { assert_eq!(response.available_slots, SLOTS - 1); assert_eq!(response.subscription_expiry, START_HEIGHT as u32 + DURATION); - assert_eq!(response.appointments, Vec::from([appointment.inner.locator.to_vec()])); + assert_eq!( + response.appointments, + Vec::from([appointment.inner.locator.to_vec()]) + ); } #[tokio::test] diff --git a/teos/src/cli.rs b/teos/src/cli.rs index 3ef1d9ff..5ef9928f 100644 --- a/teos/src/cli.rs +++ b/teos/src/cli.rs @@ -75,11 +75,21 @@ async fn main() { println!("{}", pretty_json(&appointments.into_inner()).unwrap()); } Command::GetAppointments(appointments_data) => { + let mut user_id = Vec::new(); + if let Some(i) = &appointments_data.user_id { + match UserId::from_str(i.as_str()) { + Ok(parsed_user_id) => { + user_id = parsed_user_id.to_vec(); + } + Err(err) => handle_error(err), + } + } match Locator::from_hex(&appointments_data.locator) { Ok(locator) => { match client .get_appointments(Request::new(msgs::GetAppointmentsRequest { locator: locator.to_vec(), + user_id, })) .await { diff --git a/teos/src/cli_config.rs b/teos/src/cli_config.rs index ba085b8d..8c23c467 100644 --- a/teos/src/cli_config.rs +++ b/teos/src/cli_config.rs @@ -31,6 +31,8 @@ pub struct GetUserData { pub struct GetAppointmentsData { /// The locator of the appointments (16-byte hexadecimal string). pub locator: String, + /// The user identifier (33-byte compressed public key). + pub user_id: Option, } /// Holds all the command line options and commands.