diff --git a/Cargo.toml b/Cargo.toml index dd0989d..39dd62c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,3 +34,4 @@ pretty_assertions = "1.1.0" test-case = "1.2.3" testcontainers = { git = "https://github.com/kezhuw/testcontainers-rs.git", branch = "zookeeper-client" } futures = "0.3.21" +speculoos = "0.9.0" diff --git a/src/client/mod.rs b/src/client/mod.rs index 7d23815..6bfec48 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,12 +1,13 @@ mod watcher; +use std::future::Future; use std::time::Duration; use const_format::formatcp; -use tokio::sync::{mpsc, oneshot, watch}; +use tokio::sync::{mpsc, watch}; pub use self::watcher::{OneshotWatcher, PersistentWatcher, StateWatcher}; -use super::session::{self, AuthResponser, Depot, Session, SessionOperation, WatchReceiver}; +use super::session::{Depot, Session, SessionOperation, WatchReceiver}; use crate::acl::{Acl, AuthUser}; use crate::error::{ConnectError, Error}; use crate::proto::{ @@ -29,9 +30,11 @@ use crate::proto::{ }; pub use crate::proto::{EnsembleUpdate, Stat}; use crate::record::{self, Record, StaticRecord}; -pub use crate::session::{EventType, SessionId, SessionState, WatchedEvent}; +pub use crate::session::{EventType, SessionId, SessionState, StateReceiver, WatchedEvent}; use crate::util::{self, Ref as _}; +type Result = std::result::Result; + /// CreateMode specifies ZooKeeper znode type. It covers all znode types with help from /// [CreateOptions::with_ttl]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -119,7 +122,7 @@ impl<'a> CreateOptions<'a> { self } - fn validate(&'a self) -> Result<(), Error> { + fn validate(&'a self) -> Result<()> { if let Some(ref ttl) = self.ttl { if self.mode != CreateMode::Persistent && self.mode != CreateMode::PersistentSequential { return Err(Error::BadArguments(&"ttl can only be specified with persistent node")); @@ -157,13 +160,15 @@ impl std::fmt::Display for CreateSequence { /// /// # Notable behaviors /// * All cloned clients share same authentication identities. +/// * All methods construct resulting future by sending request synchronously and polling output +/// asynchronously. This guarantees that requests are sending to server in the order of method +/// call but not future evaluation. #[derive(Clone, Debug)] pub struct Client { root: String, session: (SessionId, Vec), session_timeout: Duration, - requester: mpsc::Sender, - auth_requester: mpsc::Sender<(AuthPacket, AuthResponser)>, + requester: mpsc::UnboundedSender, state_watcher: StateWatcher, } @@ -171,7 +176,7 @@ impl Client { const CONFIG_NODE: &'static str = "/zookeeper/config"; /// Connects to ZooKeeper cluster with specified session timeout. - pub async fn connect(cluster: &str, timeout: Duration) -> Result { + pub async fn connect(cluster: &str, timeout: Duration) -> std::result::Result { return ClientBuilder::new(timeout).connect(cluster).await; } @@ -179,19 +184,18 @@ impl Client { root: String, session: (SessionId, Vec), timeout: Duration, - requester: mpsc::Sender, - auth_requester: mpsc::Sender<(AuthPacket, AuthResponser)>, + requester: mpsc::UnboundedSender, state_receiver: watch::Receiver, ) -> Client { let state_watcher = StateWatcher::new(state_receiver); - Client { root, session, session_timeout: timeout, requester, auth_requester, state_watcher } + Client { root, session, session_timeout: timeout, requester, state_watcher } } - fn validate_path<'a>(&self, path: &'a str) -> Result<(&'a str, bool), Error> { + fn validate_path<'a>(&self, path: &'a str) -> Result<(&'a str, bool)> { return util::validate_path(self.root.as_str(), path, false); } - fn validate_sequential_path<'a>(&self, path: &'a str) -> Result<(&'a str, bool), Error> { + fn validate_sequential_path<'a>(&self, path: &'a str) -> Result<(&'a str, bool)> { util::validate_path(&self.root, path, true) } @@ -227,7 +231,7 @@ impl Client { /// /// # Notable behaviors /// * Existing watchers are not affected. - pub fn chroot(mut self, root: &str) -> Result { + pub fn chroot(mut self, root: &str) -> std::result::Result { let is_zookeeper_root = match util::validate_path("", root, false) { Err(_) => return Err(self), Ok((_, is_zookeeper_root)) => is_zookeeper_root, @@ -239,16 +243,38 @@ impl Client { Ok(self) } - async fn request(&self, code: OpCode, body: &impl Record) -> Result<(Vec, WatchReceiver), Error> { - let (operation, receiver) = session::build_state_operation(code, body); - if self.requester.send(operation).await.is_err() { + fn send_request(&self, code: OpCode, body: &impl Record) -> StateReceiver { + let (operation, receiver) = SessionOperation::new(code, body).with_responser(); + if let Err(mpsc::error::SendError(operation)) = self.requester.send(operation) { let state = self.state(); - return Err(state.to_error()); + operation.responser.send(Err(state.to_error())); } - return receiver.await.unwrap(); + receiver } - fn parse_sequence(client_path: &str, path: &str) -> Result { + async fn wait(result: Result) -> Result + where + F: Future>, { + match result { + Err(err) => Err(err), + Ok(future) => future.await, + } + } + + async fn map_wait(result: Result, f: Fn) -> Result + where + Fu: Future>, + Fn: FnOnce(T) -> U, { + match result { + Err(err) => Err(err), + Ok(future) => match future.await { + Err(err) => Err(err), + Ok(t) => Ok(f(t)), + }, + } + } + + fn parse_sequence(client_path: &str, path: &str) -> Result { if let Some(sequence_path) = client_path.strip_prefix(path) { match sequence_path.parse::() { Err(_) => Err(Error::UnexpectedError(format!("sequential node get no i32 path {}", client_path))), @@ -269,12 +295,21 @@ impl Client { /// * [Error::NoNode] if parent node does not exist. /// * [Error::NoChildrenForEphemerals] if parent node is ephemeral. /// * [Error::InvalidAcl] if acl is invalid or empty. - pub async fn create( - &self, - path: &str, + pub fn create<'a: 'f, 'b: 'f, 'f>( + &'a self, + path: &'b str, + data: &[u8], + options: &CreateOptions<'_>, + ) -> impl Future> + Send + 'f { + Self::wait(self.create_internally(path, data, options)) + } + + fn create_internally<'a: 'f, 'b: 'f, 'f>( + &'a self, + path: &'b str, data: &[u8], options: &CreateOptions<'_>, - ) -> Result<(Stat, CreateSequence), Error> { + ) -> Result> + Send + 'f> { options.validate()?; let create_mode = options.mode; let sequential = create_mode.is_sequential(); @@ -289,13 +324,16 @@ impl Client { }; let flags = create_mode.as_flags(ttl != 0); let request = CreateRequest { path: RootedPath::new(&self.root, leaf), data, acls: options.acls, flags, ttl }; - let (body, _) = self.request(op_code, &request).await?; - let mut buf = body.as_slice(); - let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?; - let client_path = util::strip_root_path(server_path, &self.root)?; - let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) }; - let stat = record::unmarshal::(&mut buf)?; - Ok((stat, sequence)) + let receiver = self.send_request(op_code, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let server_path = record::unmarshal_entity::<&str>(&"server path", &mut buf)?; + let client_path = util::strip_root_path(server_path, &self.root)?; + let sequence = if sequential { Self::parse_sequence(client_path, path)? } else { CreateSequence(-1) }; + let stat = record::unmarshal::(&mut buf)?; + Ok((stat, sequence)) + }) } /// Deletes node with specified path. @@ -304,41 +342,51 @@ impl Client { /// * [Error::NoNode] if such node does not exist. /// * [Error::BadVersion] if such node exists but has different version. /// * [Error::NotEmpty] if such node exists but has children. - pub async fn delete(&self, path: &str, expected_version: Option) -> Result<(), Error> { + pub fn delete(&self, path: &str, expected_version: Option) -> impl Future> + Send { + Self::wait(self.delete_internally(path, expected_version)) + } + + fn delete_internally(&self, path: &str, expected_version: Option) -> Result>> { let (leaf, _) = self.validate_path(path)?; if leaf.is_empty() { return Err(Error::BadArguments(&"can not delete root node")); } let request = DeleteRequest { path: RootedPath::new(&self.root, leaf), version: expected_version.unwrap_or(-1) }; - self.request(OpCode::Delete, &request).await?; - Ok(()) + let receiver = self.send_request(OpCode::Delete, &request); + Ok(async move { + receiver.await?; + Ok(()) + }) } - async fn get_data_internally( + fn get_data_internally( &self, root: &str, - leaf: &str, + path: &str, watch: bool, - ) -> Result<(Vec, Stat, WatchReceiver), Error> { + ) -> Result, Stat, WatchReceiver)>> + Send> { + let (leaf, _) = self.validate_path(path)?; let request = GetRequest { path: RootedPath::new(root, leaf), watch }; - let (mut body, watcher) = self.request(OpCode::GetData, &request).await?; - let data_len = body.len() - Stat::record_len(); - let mut stat_buf = &body[data_len..]; - let stat = record::unmarshal(&mut stat_buf)?; - body.truncate(data_len); - drop(body.drain(..4)); - Ok((body, stat, watcher)) + let receiver = self.send_request(OpCode::GetData, &request); + Ok(async move { + let (mut body, watcher) = receiver.await?; + let data_len = body.len() - Stat::record_len(); + let mut stat_buf = &body[data_len..]; + let stat = record::unmarshal(&mut stat_buf)?; + body.truncate(data_len); + drop(body.drain(..4)); + Ok((body, stat, watcher)) + }) } /// Gets stat and data for node with given path. /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn get_data(&self, path: &str) -> Result<(Vec, Stat), Error> { - let (leaf, _) = self.validate_path(path)?; - let (data, stat, _) = self.get_data_internally(&self.root, leaf, false).await?; - Ok((data, stat)) + pub fn get_data(&self, path: &str) -> impl Future, Stat)>> + Send { + let result = self.get_data_internally(&self.root, path, false); + Self::map_wait(result, |(data, stat, _)| (data, stat)) } /// Gets stat and data for node with given path, and watches node deletion and data change. @@ -350,25 +398,33 @@ impl Client { /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn get_and_watch_data(&self, path: &str) -> Result<(Vec, Stat, OneshotWatcher), Error> { - let (leaf, _) = self.validate_path(path)?; - let (data, stat, watch_receiver) = self.get_data_internally(&self.root, leaf, true).await?; - Ok((data, stat, watch_receiver.into_oneshot(&self.root))) + pub fn get_and_watch_data( + &self, + path: &str, + ) -> impl Future, Stat, OneshotWatcher)>> + Send + '_ { + let result = self.get_data_internally(&self.root, path, true); + Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(&self.root))) } - async fn check_stat_internally(&self, path: &str, watch: bool) -> Result<(Option, WatchReceiver), Error> { + fn check_stat_internally( + &self, + path: &str, + watch: bool, + ) -> Result, WatchReceiver)>>> { let (leaf, _) = self.validate_path(path)?; let request = ExistsRequest { path: RootedPath::new(&self.root, leaf), watch }; - let (body, watcher) = self.request(OpCode::Exists, &request).await?; - let mut buf = body.as_slice(); - let stat = record::try_deserialize(&mut buf)?; - Ok((stat, watcher)) + let receiver = self.send_request(OpCode::Exists, &request); + Ok(async move { + let (body, watcher) = receiver.await?; + let mut buf = body.as_slice(); + let stat = record::try_deserialize(&mut buf)?; + Ok((stat, watcher)) + }) } /// Checks stat for node with given path. - pub async fn check_stat(&self, path: &str) -> Result, Error> { - let (stat, _) = self.check_stat_internally(path, false).await?; - Ok(stat) + pub fn check_stat(&self, path: &str) -> impl Future>> + Send { + Self::map_wait(self.check_stat_internally(path, false), |(stat, _)| stat) } /// Checks stat for node with given path, and watches node creation, deletion and data change. @@ -377,9 +433,12 @@ impl Client { /// * Data change. /// * Node creation and deletion. /// * Session expiration. - pub async fn check_and_watch_stat(&self, path: &str) -> Result<(Option, OneshotWatcher), Error> { - let (stat, watch_receiver) = self.check_stat_internally(path, true).await?; - Ok((stat, watch_receiver.into_oneshot(&self.root))) + pub fn check_and_watch_stat( + &self, + path: &str, + ) -> impl Future, OneshotWatcher)>> + Send + '_ { + let result = self.check_stat_internally(path, true); + Self::map_wait(result, |(stat, watcher)| (stat, watcher.into_oneshot(&self.root))) } /// Sets data for node with given path and returns updated stat. @@ -387,33 +446,56 @@ impl Client { /// # Notable errors /// * [Error::NoNode] if such node does not exist. /// * [Error::BadVersion] if such node exists but has different version. - pub async fn set_data(&self, path: &str, data: &[u8], expected_version: Option) -> Result { + pub fn set_data( + &self, + path: &str, + data: &[u8], + expected_version: Option, + ) -> impl Future> + Send { + Self::wait(self.set_data_internally(path, data, expected_version)) + } + + pub fn set_data_internally( + &self, + path: &str, + data: &[u8], + expected_version: Option, + ) -> Result>> { let (leaf, _) = self.validate_path(path)?; let request = SetDataRequest { path: RootedPath::new(&self.root, leaf), data, version: expected_version.unwrap_or(-1) }; - let (body, _) = self.request(OpCode::SetData, &request).await?; - let mut buf = body.as_slice(); - let stat: Stat = record::unmarshal(&mut buf)?; - Ok(stat) + let receiver = self.send_request(OpCode::SetData, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let stat: Stat = record::unmarshal(&mut buf)?; + Ok(stat) + }) } - async fn list_children_internally(&self, path: &str, watch: bool) -> Result<(Vec, WatchReceiver), Error> { + fn list_children_internally( + &self, + path: &str, + watch: bool, + ) -> Result, WatchReceiver)>>> { let (leaf, _) = self.validate_path(path)?; let request = GetChildrenRequest { path: RootedPath::new(&self.root, leaf), watch }; - let (body, watcher) = self.request(OpCode::GetChildren, &request).await?; - let mut buf = body.as_slice(); - let children = record::unmarshal_entity::>(&"children paths", &mut buf)?; - let children = children.into_iter().map(|child| child.to_owned()).collect(); - Ok((children, watcher)) + let receiver = self.send_request(OpCode::GetChildren, &request); + Ok(async move { + let (body, watcher) = receiver.await?; + let mut buf = body.as_slice(); + let children = record::unmarshal_entity::>(&"children paths", &mut buf)?; + let children = children.into_iter().map(|child| child.to_owned()).collect(); + Ok((children, watcher)) + }) } /// Lists children for node with given path. /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn list_children(&self, path: &str) -> Result, Error> { - let (children, _) = self.list_children_internally(path, false).await?; - Ok(children) + pub fn list_children(&self, path: &str) -> impl Future>> + Send + '_ { + Self::map_wait(self.list_children_internally(path, false), |(children, _)| children) } /// Lists children for node with given path, and watches node deletion, children creation and @@ -426,32 +508,38 @@ impl Client { /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn list_and_watch_children(&self, path: &str) -> Result<(Vec, OneshotWatcher), Error> { - let (children, watcher) = self.list_children_internally(path, true).await?; - Ok((children, watcher.into_oneshot(&self.root))) + pub fn list_and_watch_children( + &self, + path: &str, + ) -> impl Future, OneshotWatcher)>> + Send + '_ { + let result = self.list_children_internally(path, true); + Self::map_wait(result, |(children, watcher)| (children, watcher.into_oneshot(&self.root))) } - async fn get_children_internally( + fn get_children_internally( &self, path: &str, watch: bool, - ) -> Result<(Vec, Stat, WatchReceiver), Error> { + ) -> Result, Stat, WatchReceiver)>>> { let (leaf, _) = self.validate_path(path)?; let request = GetChildrenRequest { path: RootedPath::new(&self.root, leaf), watch }; - let (body, watcher) = self.request(OpCode::GetChildren2, &request).await?; - let mut buf = body.as_slice(); - let response = record::unmarshal::(&mut buf)?; - let children = response.children.into_iter().map(|s| s.to_owned()).collect(); - Ok((children, response.stat, watcher)) + let receiver = self.send_request(OpCode::GetChildren2, &request); + Ok(async move { + let (body, watcher) = receiver.await?; + let mut buf = body.as_slice(); + let response = record::unmarshal::(&mut buf)?; + let children = response.children.into_iter().map(|s| s.to_owned()).collect(); + Ok((children, response.stat, watcher)) + }) } /// Gets stat and children for node with given path. /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn get_children(&self, path: &str) -> Result<(Vec, Stat), Error> { - let (children, stat, _) = self.get_children_internally(path, false).await?; - Ok((children, stat)) + pub fn get_children(&self, path: &str) -> impl Future, Stat)>> + Send { + let result = self.get_children_internally(path, false); + Self::map_wait(result, |(children, stat, _)| (children, stat)) } /// Gets stat and children for node with given path, and watches node deletion, children @@ -464,22 +552,32 @@ impl Client { /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn get_and_watch_children(&self, path: &str) -> Result<(Vec, Stat, OneshotWatcher), Error> { - let (children, stat, watcher) = self.get_children_internally(path, true).await?; - Ok((children, stat, watcher.into_oneshot(&self.root))) + pub fn get_and_watch_children( + &self, + path: &str, + ) -> impl Future, Stat, OneshotWatcher)>> + Send + '_ { + let result = self.get_children_internally(path, true); + Self::map_wait(result, |(children, stat, watcher)| (children, stat, watcher.into_oneshot(&self.root))) } /// Counts descendants number for node with given path. /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn count_descendants_number(&self, path: &str) -> Result { + pub fn count_descendants_number(&self, path: &str) -> impl Future> + Send { + Self::wait(self.count_descendants_number_internally(path)) + } + + fn count_descendants_number_internally(&self, path: &str) -> Result>> { let (leaf, _) = self.validate_path(path)?; let request = RootedPath::new(&self.root, leaf); - let (body, _) = self.request(OpCode::GetAllChildrenNumber, &request).await?; - let mut buf = body.as_slice(); - let n = record::unmarshal_entity::(&"all children number", &mut buf)?; - Ok(n as usize) + let receiver = self.send_request(OpCode::GetAllChildrenNumber, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let n = record::unmarshal_entity::(&"all children number", &mut buf)?; + Ok(n as usize) + }) } /// Lists all ephemerals nodes that created by current session and starts with given path. @@ -488,29 +586,43 @@ impl Client { /// * No [Error::NoNode] if node with give path does not exist. /// * Result will include given path if that node is ephemeral. /// * Returned paths are located at chroot but not ZooKeeper root. - pub async fn list_ephemerals(&self, path: &str) -> Result, Error> { + pub fn list_ephemerals(&self, path: &str) -> impl Future>> + Send + '_ { + Self::wait(self.list_ephemerals_internally(path)) + } + + fn list_ephemerals_internally(&self, path: &str) -> Result>> + Send + '_> { let (leaf, _) = self.validate_path(path)?; let request = RootedPath::new(&self.root, leaf); - let (body, _) = self.request(OpCode::GetEphemerals, &request).await?; - let mut buf = body.as_slice(); - let mut ephemerals = record::unmarshal_entity::>(&"ephemerals", &mut buf)?; - for ephemeral_path in ephemerals.iter_mut() { - util::drain_root_path(ephemeral_path, &self.root)?; - } - Ok(ephemerals) + let receiver = self.send_request(OpCode::GetEphemerals, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let mut ephemerals = record::unmarshal_entity::>(&"ephemerals", &mut buf)?; + for ephemeral_path in ephemerals.iter_mut() { + util::drain_root_path(ephemeral_path, &self.root)?; + } + Ok(ephemerals) + }) } /// Gets acl and stat for node with given path. /// /// # Notable errors /// * [Error::NoNode] if such node does not exist. - pub async fn get_acl(&self, path: &str) -> Result<(Vec, Stat), Error> { + pub fn get_acl(&self, path: &str) -> impl Future, Stat)>> + Send + '_ { + Self::wait(self.get_acl_internally(path)) + } + + fn get_acl_internally(&self, path: &str) -> Result, Stat)>>> { let (leaf, _) = self.validate_path(path)?; let request = RootedPath::new(&self.root, leaf); - let (body, _) = self.request(OpCode::GetACL, &request).await?; - let mut buf = body.as_slice(); - let response: GetAclResponse = record::unmarshal(&mut buf)?; - Ok((response.acl, response.stat)) + let receiver = self.send_request(OpCode::GetACL, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let response: GetAclResponse = record::unmarshal(&mut buf)?; + Ok((response.acl, response.stat)) + }) } /// Sets acl for node with given path and returns updated stat. @@ -518,14 +630,31 @@ impl Client { /// # Notable errors /// * [Error::NoNode] if such node does not exist. /// * [Error::BadVersion] if such node exists but has different acl version. - pub async fn set_acl(&self, path: &str, acl: &[Acl], expected_acl_version: Option) -> Result { + pub fn set_acl( + &self, + path: &str, + acl: &[Acl], + expected_acl_version: Option, + ) -> impl Future> + Send + '_ { + Self::wait(self.set_acl_internally(path, acl, expected_acl_version)) + } + + fn set_acl_internally( + &self, + path: &str, + acl: &[Acl], + expected_acl_version: Option, + ) -> Result>> { let (leaf, _) = self.validate_path(path)?; let request = SetAclRequest { path: RootedPath::new(&self.root, leaf), acl, version: expected_acl_version.unwrap_or(-1) }; - let (body, _) = self.request(OpCode::SetACL, &request).await?; - let mut buf = body.as_slice(); - let stat: Stat = record::unmarshal(&mut buf)?; - Ok(stat) + let receiver = self.send_request(OpCode::SetACL, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let stat: Stat = record::unmarshal(&mut buf)?; + Ok(stat) + }) } /// Watches possible nonexistent path using specified mode. @@ -540,12 +669,23 @@ impl Client { /// persistent watch on same path. /// /// [ZOOKEEPER-4466]: https://issues.apache.org/jira/browse/ZOOKEEPER-4466 - pub async fn watch(&self, path: &str, mode: AddWatchMode) -> Result { + pub fn watch(&self, path: &str, mode: AddWatchMode) -> impl Future> + Send + '_ { + Self::wait(self.watch_internally(path, mode)) + } + + fn watch_internally( + &self, + path: &str, + mode: AddWatchMode, + ) -> Result> + Send + '_> { let (leaf, _) = self.validate_path(path)?; let proto_mode = proto::AddWatchMode::from(mode); let request = PersistentWatchRequest { path: RootedPath::new(&self.root, leaf), mode: proto_mode.into() }; - let (_, watcher) = self.request(OpCode::AddWatch, &request).await?; - Ok(watcher.into_persistent(&self.root)) + let receiver = self.send_request(OpCode::AddWatch, &request); + Ok(async move { + let (_, watcher) = receiver.await?; + Ok(watcher.into_persistent(&self.root)) + }) } /// Syncs with ZooKeeper **leader**. @@ -558,16 +698,24 @@ impl Client { /// /// [ZOOKEEPER-1675]: https://issues.apache.org/jira/browse/ZOOKEEPER-1675 /// [ZOOKEEPER-2136]: https://issues.apache.org/jira/browse/ZOOKEEPER-2136 - pub async fn sync(&self, path: &str) -> Result<(), Error> { + pub fn sync(&self, path: &str) -> impl Future> + Send + '_ { + Self::wait(self.sync_internally(path)) + } + + fn sync_internally(&self, path: &str) -> Result>> { let (leaf, _) = self.validate_path(path)?; let request = SyncRequest { path: RootedPath::new(&self.root, leaf) }; - let (body, _) = self.request(OpCode::Sync, &request).await?; - let mut buf = body.as_slice(); - record::unmarshal_entity::<&str>(&"server path", &mut buf)?; - Ok(()) + let receiver = self.send_request(OpCode::Sync, &request); + Ok(async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + record::unmarshal_entity::<&str>(&"server path", &mut buf)?; + Ok(()) + }) } - /// Authenticates session using given scheme and auth identication. + /// Authenticates session using given scheme and auth identication. This affects only + /// subsequent operations. /// /// # Errors /// * [Error::AuthFailed] if authentication failed. @@ -576,15 +724,16 @@ impl Client { /// # Notable behaviors /// * Same auth will be resubmitted for authentication after session reestablished. /// * This method is resistent to temporary session unavailability, that means - /// [SessionState::Disconnected] will not end authentication. - pub async fn auth(&self, scheme: String, auth: Vec) -> Result<(), Error> { - let (sender, receiver) = oneshot::channel(); - let auth_packet = AuthPacket { scheme, auth }; - if self.auth_requester.send((auth_packet, sender)).await.is_err() { - let state = self.state(); - return Err(state.to_error()); + /// [SessionState::Disconnected] will not end authentication. + /// * It is ok to ignore resulting future of this method as request is sending synchronously + /// and auth failure will fail ZooKeeper session with [SessionState::AuthFailed]. + pub fn auth(&self, scheme: String, auth: Vec) -> impl Future> + Send + '_ { + let request = AuthPacket { scheme, auth }; + let receiver = self.send_request(OpCode::Auth, &request); + async move { + receiver.await?; + Ok(()) } - return receiver.await.unwrap(); } /// Gets all authentication informations attached to current session. @@ -596,23 +745,26 @@ impl Client { /// * [ZOOKEEPER-3969][] Add whoami API and Cli command. /// /// [ZOOKEEPER-3969]: https://issues.apache.org/jira/browse/ZOOKEEPER-3969 - pub async fn list_auth_users(&self) -> Result, Error> { - let (body, _) = self.request(OpCode::WhoAmI, &()).await?; - let mut buf = body.as_slice(); - let authed_users = record::unmarshal_entity::>(&"authed users", &mut buf)?; - Ok(authed_users) + pub fn list_auth_users(&self) -> impl Future>> + Send { + let receiver = self.send_request(OpCode::WhoAmI, &()); + async move { + let (body, _) = receiver.await?; + let mut buf = body.as_slice(); + let authed_users = record::unmarshal_entity::>(&"authed users", &mut buf)?; + Ok(authed_users) + } } /// Gets data for ZooKeeper config node, that is node with path "/zookeeper/config". - pub async fn get_config(&self) -> Result<(Vec, Stat), Error> { - let (data, stat, _) = self.get_data_internally(Self::CONFIG_NODE, Default::default(), false).await?; - Ok((data, stat)) + pub fn get_config(&self) -> impl Future, Stat)>> + Send { + let result = self.get_data_internally("", Self::CONFIG_NODE, false); + Self::map_wait(result, |(data, stat, _)| (data, stat)) } /// Gets stat and data for ZooKeeper config node, that is node with path "/zookeeper/config". - pub async fn get_and_watch_config(&self) -> Result<(Vec, Stat, OneshotWatcher), Error> { - let (data, stat, watcher) = self.get_data_internally(Self::CONFIG_NODE, Default::default(), true).await?; - Ok((data, stat, watcher.into_oneshot(""))) + pub fn get_and_watch_config(&self) -> impl Future, Stat, OneshotWatcher)>> + Send { + let result = self.get_data_internally("", Self::CONFIG_NODE, true); + Self::map_wait(result, |(data, stat, watcher)| (data, stat, watcher.into_oneshot(""))) } /// Updates ZooKeeper ensemble. @@ -622,20 +774,23 @@ impl Client { /// /// # References /// See [ZooKeeper Dynamic Reconfiguration](https://zookeeper.apache.org/doc/current/zookeeperReconfig.html). - pub async fn update_ensemble<'a, I: Iterator + Clone>( + pub fn update_ensemble<'a, I: Iterator + Clone>( &self, update: EnsembleUpdate<'a, I>, expected_version: Option, - ) -> Result<(Vec, Stat), Error> { + ) -> impl Future, Stat)>> + Send { let request = ReconfigRequest { update, version: expected_version.unwrap_or(-1) }; - let (mut body, _) = self.request(OpCode::Reconfig, &request).await?; - let mut buf = body.as_slice(); - let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?; - let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?; - let data_len = data.len(); - body.truncate(data_len + 4); - drop(body.drain(..4)); - Ok((body, stat)) + let receiver = self.send_request(OpCode::Reconfig, &request); + async move { + let (mut body, _) = receiver.await?; + let mut buf = body.as_slice(); + let data: &str = record::unmarshal_entity(&"reconfig data", &mut buf)?; + let stat = record::unmarshal_entity(&"reconfig stat", &mut buf)?; + let data_len = data.len(); + body.truncate(data_len + 4); + drop(body.drain(..4)); + Ok((body, stat)) + } } } @@ -666,32 +821,24 @@ impl ClientBuilder { } /// Connects to ZooKeeper cluster. - /// - /// # Notable behaviors - /// * On success, authes were consumed. - pub async fn connect(&mut self, cluster: &str) -> Result { + pub async fn connect(&mut self, cluster: &str) -> std::result::Result { let (hosts, root) = util::parse_connect_string(cluster)?; let mut buf = Vec::with_capacity(4096); let mut connecting_depot = Depot::for_connecting(); - let authes = std::mem::take(&mut self.authes); - let (mut session, state_receiver) = Session::new(self.timeout, authes, self.readonly); + let (mut session, state_receiver) = Session::new(self.timeout, &self.authes, self.readonly); let mut hosts_iter = hosts.iter().copied(); let sock = match session.start(&mut hosts_iter, &mut buf, &mut connecting_depot).await { Ok(sock) => sock, - Err(err) => { - self.authes = std::mem::take(&mut session.authes); - return Err(ConnectError::from(err)); - }, + Err(err) => return Err(ConnectError::from(err)), }; - let (sender, receiver) = mpsc::channel(512); - let (auth_sender, auth_receiver) = mpsc::channel(10); + let (sender, receiver) = mpsc::unbounded_channel(); let servers = hosts.into_iter().map(|addr| addr.to_value()).collect(); let session_info = (session.session_id, session.session_password.clone()); let session_timeout = session.session_timeout; tokio::spawn(async move { - session.serve(servers, sock, buf, connecting_depot, receiver, auth_receiver).await; + session.serve(servers, sock, buf, connecting_depot, receiver).await; }); - let client = Client::new(root.to_string(), session_info, session_timeout, sender, auth_sender, state_receiver); + let client = Client::new(root.to_string(), session_info, session_timeout, sender, state_receiver); Ok(client) } } diff --git a/src/proto/consts.rs b/src/proto/consts.rs index 3fa2009..ebae47f 100644 --- a/src/proto/consts.rs +++ b/src/proto/consts.rs @@ -4,7 +4,16 @@ use num_enum::{IntoPrimitive, TryFromPrimitive}; #[derive(Copy, Clone, Debug, PartialEq, Eq, IntoPrimitive)] pub enum PredefinedXid { Notification = -1, + /// ZooKeeper server [hard-code -2 as ping response xid][ping-xid], so we have to use this and make sure + /// at most one ping in wire. + /// + /// ping-xid: https://github.com/apache/zookeeper/blob/de7c5869d372e46af43979134d0e30b49d2319b1/zookeeper-server/src/main/java/org/apache/zookeeper/server/FinalRequestProcessor.java#L215 Ping = -2, + + /// Fortunately, ZooKeeper server [use xid from header](auth-xid) to reply auth request, so we can have + /// multiple auth requets in network. + /// + /// auth-xid: https://github.com/apache/zookeeper/blob/de7c5869d372e46af43979134d0e30b49d2319b1/zookeeper-server/src/main/java/org/apache/zookeeper/server/ZooKeeperServer.java#L1621 Auth = -4, SetWatches = -8, } diff --git a/src/proto/request_header.rs b/src/proto/request_header.rs index eb54bb0..d4382a4 100644 --- a/src/proto/request_header.rs +++ b/src/proto/request_header.rs @@ -13,7 +13,6 @@ impl RequestHeader { pub fn with_code(code: OpCode) -> RequestHeader { let xid = match code { OpCode::Ping => PredefinedXid::Ping.into(), - OpCode::Auth => PredefinedXid::Auth.into(), OpCode::SetWatches | OpCode::SetWatches2 => PredefinedXid::SetWatches.into(), _ => 0, }; diff --git a/src/session/depot.rs b/src/session/depot.rs index cc56c56..84a7771 100644 --- a/src/session/depot.rs +++ b/src/session/depot.rs @@ -4,25 +4,23 @@ use std::io::{self, IoSlice}; use hashbrown::HashMap; use strum::IntoEnumIterator; use tokio::net::TcpStream; -use tokio::sync::oneshot; -use super::request::{self, MarshalledRequest, Operation, SessionOperation, StateResponser}; +use super::request::{MarshalledRequest, Operation, SessionOperation, StateResponser}; use super::types::WatchMode; use super::xid::Xid; use super::SessionId; use crate::error::Error; -use crate::proto::{AuthPacket, OpCode, RemoveWatchesRequest}; - -pub type AuthResponser = oneshot::Sender>; +use crate::proto::{OpCode, PredefinedXid, RemoveWatchesRequest}; #[derive(Default)] pub struct Depot { xid: Xid, + pending_authes: Vec, + writing_slices: Vec>, writing_operations: VecDeque, - written_operations: VecDeque, - pending_auth: Option<(AuthPacket, AuthResponser)>, + written_operations: HashMap, watching_paths: HashMap<(&'static str, WatchMode), usize>, unwatching_paths: HashMap<(&'static str, WatchMode), SessionOperation>, @@ -33,10 +31,10 @@ impl Depot { let writing_capacity = 128usize; Depot { xid: Default::default(), + pending_authes: Vec::with_capacity(5), writing_slices: Vec::with_capacity(writing_capacity), writing_operations: VecDeque::with_capacity(writing_capacity), - written_operations: VecDeque::with_capacity(128), - pending_auth: None, + written_operations: HashMap::with_capacity(128), watching_paths: HashMap::with_capacity(32), unwatching_paths: HashMap::with_capacity(32), } @@ -45,28 +43,40 @@ impl Depot { pub fn for_connecting() -> Depot { Depot { xid: Default::default(), + pending_authes: Default::default(), writing_slices: Vec::with_capacity(10), writing_operations: VecDeque::with_capacity(10), - written_operations: VecDeque::with_capacity(10), - pending_auth: None, + written_operations: HashMap::with_capacity(10), watching_paths: HashMap::new(), unwatching_paths: HashMap::new(), } } + /// Clear all buffered operations from previous run. pub fn clear(&mut self) { + self.pending_authes.clear(); self.writing_slices.clear(); self.watching_paths.clear(); + self.unwatching_paths.clear(); self.writing_operations.clear(); self.written_operations.clear(); } - pub fn error(&mut self, err: Error) { - self.written_operations.drain(..).for_each(|operation| { + /// Error out ongoing operations except authes. + pub fn error(&mut self, err: &Error) { + self.written_operations.drain().for_each(|(_, operation)| { + if operation.request.get_code() == OpCode::Auth { + self.pending_authes.push(operation); + return; + } operation.responser.send(Err(err.clone())); }); self.writing_operations.drain(..).for_each(|operation| { if let Operation::Session(operation) = operation { + if operation.request.get_code() == OpCode::Auth { + self.pending_authes.push(operation); + return; + } operation.responser.send(Err(err.clone())); } }); @@ -77,42 +87,28 @@ impl Depot { self.watching_paths.clear(); } - pub fn is_empty(&self) -> bool { - self.writing_operations.is_empty() && self.written_operations.is_empty() - } - - pub fn pop_pending_auth(&mut self) -> Option<(AuthPacket, AuthResponser)> { - self.pending_auth.take() + /// Terminate all ongoing operations including authes. + pub fn terminate(&mut self, err: Error) { + self.error(&err); + for SessionOperation { responser, .. } in self.pending_authes.drain(..) { + responser.send(Err(err.clone())); + } } - pub fn has_pending_auth(&self) -> bool { - self.pending_auth.is_some() + /// Check whether there is any ongoing operations. + pub fn is_empty(&self) -> bool { + self.writing_operations.is_empty() && self.written_operations.is_empty() } - pub fn pop_reqeust(&mut self, xid: i32) -> Result { - match self.written_operations.pop_front() { + pub fn pop_request(&mut self, xid: i32) -> Result { + match self.written_operations.remove(&xid) { None => Err(Error::UnexpectedError(format!("recv response with xid {} but no pending request", xid))), - Some(operation) => { - let request_xid = operation.request.get_xid(); - if xid == request_xid { - return Ok(operation); - } - self.written_operations.push_front(operation); - Err(Error::UnexpectedError(format!("expect response xid {} but got {}", xid, request_xid))) - }, + Some(operation) => Ok(operation), } } pub fn pop_ping(&mut self) -> Result<(), Error> { - if let Some(operation) = self.written_operations.pop_front() { - let op_code = operation.request.get_code(); - if op_code != OpCode::Ping { - self.written_operations.push_front(operation); - return Err(Error::UnexpectedError(format!("expect pending ping request, got {}", op_code))); - } - return Ok(()); - } - Err(Error::UnexpectedError("expect pending ping request, got none".to_string())) + self.pop_request(PredefinedXid::Ping.into()).map(|_| ()) } pub fn push_operation(&mut self, operation: Operation) { @@ -126,9 +122,11 @@ impl Depot { } pub fn start(&mut self) { - if let Some((auth, responser)) = self.pending_auth.take() { - self.push_auth(auth, responser); + let mut pending_authes = std::mem::take(&mut self.pending_authes); + for operation in pending_authes.drain(..) { + self.push_session(operation); } + self.pending_authes = pending_authes; } fn cancel_unwatch(&mut self, path: &'static str, mode: WatchMode) { @@ -195,12 +193,6 @@ impl Depot { .any(|mode| self.watching_paths.contains_key(&(path, mode))) } - pub fn push_auth(&mut self, auth: AuthPacket, responser: AuthResponser) { - let operation = request::build_auth_operation(OpCode::Auth, &auth); - self.pending_auth = Some((auth, responser)); - self.push_operation(Operation::Auth(operation)); - } - pub fn write_operations(&mut self, sock: &TcpStream, session_id: SessionId) -> Result<(), Error> { let result = sock.try_write_vectored(self.writing_slices.as_slice()); let mut written_bytes = match result { @@ -226,13 +218,18 @@ impl Depot { .unwrap_or(self.writing_slices.len()); if written_slices != 0 { self.writing_slices.drain(..written_slices); - let written = self.writing_operations.drain(..written_slices).filter_map(|operation| { - if let Operation::Session(operation) = operation { - return Some(operation); - } - None - }); - self.written_operations.extend(written); + self.writing_operations + .drain(..written_slices) + .filter_map(|operation| { + if let Operation::Session(operation) = operation { + return Some(operation); + } + None + }) + .for_each(|operation| { + let xid = operation.request.get_xid(); + self.written_operations.insert(xid, operation); + }); } if written_bytes != 0 { let (_, rest) = self.writing_slices[0].split_at(written_bytes); diff --git a/src/session/mod.rs b/src/session/mod.rs index 1e78c83..71c6242 100644 --- a/src/session/mod.rs +++ b/src/session/mod.rs @@ -14,25 +14,39 @@ use tokio::select; use tokio::sync::mpsc; use tokio::time::{self, Instant}; -pub use self::depot::{AuthResponser, Depot}; +pub use self::depot::Depot; use self::event::WatcherEvent; -pub use self::request::{build_state_operation, MarshalledRequest, Operation, SessionOperation, StateResponser}; +pub use self::request::{ + ConnectOperation, + MarshalledRequest, + Operation, + SessionOperation, + StateReceiver, + StateResponser, +}; pub use self::types::{EventType, SessionId, SessionState, WatchedEvent}; pub use self::watch::{OneshotReceiver, PersistentReceiver, WatchReceiver}; use self::watch::{WatchManager, WatcherId}; use crate::error::Error; -use crate::proto::{ - AuthPacket, - ConnectRequest, - ConnectResponse, - ErrorCode, - OpCode, - PredefinedXid, - ReplyHeader, - RequestHeader, -}; +use crate::proto::{AuthPacket, ConnectRequest, ConnectResponse, ErrorCode, OpCode, PredefinedXid, ReplyHeader}; use crate::record; +trait RequestOperation { + fn into_responser(self) -> StateResponser; +} + +impl RequestOperation for SessionOperation { + fn into_responser(self) -> StateResponser { + self.responser + } +} + +impl RequestOperation for (WatcherId, StateResponser) { + fn into_responser(self) -> StateResponser { + self.1 + } +} + pub struct Session { timeout: Duration, readonly: bool, @@ -51,7 +65,7 @@ pub struct Session { pub session_password: Vec, session_readonly: bool, - pub authes: Vec, + pub authes: Vec, state_sender: tokio::sync::watch::Sender, watch_manager: WatchManager, @@ -61,7 +75,7 @@ pub struct Session { impl Session { pub fn new( timeout: Duration, - authes: Vec, + authes: &[AuthPacket], readonly: bool, ) -> (Session, tokio::sync::watch::Receiver) { let (state_sender, state_receiver) = tokio::sync::watch::channel(SessionState::Disconnected); @@ -85,7 +99,7 @@ impl Session { session_password: Vec::with_capacity(16), session_readonly: false, - authes, + authes: authes.iter().map(|auth| MarshalledRequest::new_request(OpCode::Auth, auth)).collect(), state_sender, watch_manager, unwatch_receiver: Some(unwatch_receiver), @@ -94,10 +108,11 @@ impl Session { (session, state_receiver) } - async fn quit(&mut self, mut requester: mpsc::Receiver, err: &Error) { + async fn close_requester(mut requester: mpsc::UnboundedReceiver, err: &Error) { requester.close(); while let Some(operation) = requester.recv().await { - operation.responser.send(Err(err.clone())); + let responser = operation.into_responser(); + responser.send(Err(err.clone())); } } @@ -107,12 +122,11 @@ impl Session { sock: TcpStream, mut buf: Vec, mut connecting_trans: Depot, - mut requester: mpsc::Receiver, - mut auth_requester: mpsc::Receiver<(AuthPacket, AuthResponser)>, + mut requester: mpsc::UnboundedReceiver, ) { let mut depot = Depot::for_serving(); let mut unwatch_requester = self.unwatch_receiver.take().unwrap(); - self.serve_once(sock, &mut buf, &mut depot, &mut requester, &mut auth_requester, &mut unwatch_requester).await; + self.serve_once(sock, &mut buf, &mut depot, &mut requester, &mut unwatch_requester).await; while !self.session_state.is_terminated() { let mut hosts = servers.iter().map(|(host, port)| (host.as_str(), *port)); let sock = match self.start(&mut hosts, &mut buf, &mut connecting_trans).await { @@ -123,14 +137,12 @@ impl Session { }, Ok(sock) => sock, }; - self.serve_once(sock, &mut buf, &mut depot, &mut requester, &mut auth_requester, &mut unwatch_requester) - .await; + self.serve_once(sock, &mut buf, &mut depot, &mut requester, &mut unwatch_requester).await; } let err = self.state_error(); - self.quit(requester, &err).await; - if let Some((_, responser)) = depot.pop_pending_auth() { - responser.send(Err(err)).ignore(); - } + Self::close_requester(requester, &err).await; + Self::close_requester(unwatch_requester, &err).await; + depot.terminate(err); } fn state_error(&self) -> Error { @@ -170,14 +182,13 @@ impl Session { sock: TcpStream, buf: &mut Vec, depot: &mut Depot, - requester: &mut mpsc::Receiver, - auth_requester: &mut mpsc::Receiver<(AuthPacket, AuthResponser)>, + requester: &mut mpsc::UnboundedReceiver, unwatch_requester: &mut mpsc::UnboundedReceiver<(WatcherId, StateResponser)>, ) { - if let Err(err) = self.serve_session(&sock, buf, depot, requester, auth_requester, unwatch_requester).await { + if let Err(err) = self.serve_session(&sock, buf, depot, requester, unwatch_requester).await { self.resolve_serve_error(&err); log::debug!("ZooKeeper session {} state {} error {}", self.session_id, self.session_state, err); - depot.error(err); + depot.error(&err); } else { self.change_state(SessionState::Disconnected); self.change_state(SessionState::Closed); @@ -231,7 +242,13 @@ impl Session { }; let SessionOperation { responser, request, .. } = operation; let (op_code, watcher) = self.handle_session_watcher(&request, error_code, depot); - if error_code == ErrorCode::Ok || (error_code == ErrorCode::NoNode && op_code == OpCode::Exists) { + if error_code == ErrorCode::Ok || (op_code == OpCode::Exists && error_code == ErrorCode::NoNode) { + if op_code == OpCode::Auth { + if responser.send_empty() { + self.authes.push(request); + } + return; + } let mut buf = request.0; buf.clear(); buf.extend_from_slice(body); @@ -248,13 +265,7 @@ impl Session { } else if header.err == ErrorCode::AuthFailed.into() { return Err(Error::AuthFailed); } - if header.xid == PredefinedXid::Auth.into() { - if let Some((auth, responser)) = depot.pop_pending_auth() { - self.authes.push(auth); - responser.send(Ok(())).ignore(); - } - return Ok(()); - } else if header.xid == PredefinedXid::Notification.into() { + if header.xid == PredefinedXid::Notification.into() { self.handle_notification(body, depot)?; return Ok(()); } else if header.xid == PredefinedXid::Ping.into() { @@ -265,7 +276,7 @@ impl Session { } return Ok(()); } - let operation = depot.pop_reqeust(header.xid)?; + let operation = depot.pop_request(header.xid)?; self.handle_session_reply(operation, header.err, body, depot); Ok(()) } @@ -364,14 +375,12 @@ impl Session { sock: &TcpStream, buf: &mut Vec, depot: &mut Depot, - requester: &mut mpsc::Receiver, - auth_requester: &mut mpsc::Receiver<(AuthPacket, AuthResponser)>, + requester: &mut mpsc::UnboundedReceiver, unwatch_requester: &mut mpsc::UnboundedReceiver<(WatcherId, StateResponser)>, ) -> Result<(), Error> { let mut tick = time::interval(self.tick_timeout); tick.set_missed_tick_behavior(time::MissedTickBehavior::Skip); let mut channel_closed = false; - let mut auth_closed = false; depot.start(); while !(channel_closed && depot.is_empty()) { select! { @@ -394,10 +403,6 @@ impl Session { depot.write_operations(sock, self.session_id)?; self.last_send = Instant::now(); }, - r = auth_requester.recv(), if !auth_closed && !depot.has_pending_auth() => match r { - Some((auth, responser)) => depot.push_auth(auth, responser), - None => auth_closed = true, - }, r = unwatch_requester.recv() => if let Some((watcher_id, responser)) = r { self.watch_manager.remove_watcher(watcher_id, responser, depot); }, @@ -439,8 +444,7 @@ impl Session { } fn send_ping(&mut self, depot: &mut Depot, now: Instant) { - let header = RequestHeader::with_code(OpCode::Ping); - let operation = request::build_session_operation(&header); + let operation = SessionOperation::new_without_body(OpCode::Ping); depot.push_operation(Operation::Session(operation)); self.last_send = now; self.last_ping = Some(self.last_send); @@ -455,14 +459,14 @@ impl Session { password: self.session_password.as_slice(), readonly: self.readonly, }; - let operation = request::build_connect_operation(&request); + let operation = ConnectOperation::new(&request); depot.push_operation(Operation::Connect(operation)); } fn send_authes(&self, depot: &mut Depot) { self.authes.iter().for_each(|auth| { - let operation = request::build_auth_operation(OpCode::Auth, auth); - depot.push_operation(Operation::Auth(operation)); + let operation = SessionOperation::from(auth.clone()); + depot.push_session(operation); }); } diff --git a/src/session/request.rs b/src/session/request.rs index 076f326..174e6ec 100644 --- a/src/session/request.rs +++ b/src/session/request.rs @@ -1,3 +1,7 @@ +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; + use bytes::{Buf, BufMut}; use ignore_result::Ignore; use tokio::sync::oneshot; @@ -79,7 +83,6 @@ impl MarshalledRequest { pub enum Operation { Connect(ConnectOperation), - Auth(AuthOperation), Session(SessionOperation), } @@ -88,7 +91,6 @@ impl Operation { match self { Operation::Connect(operation) => operation.request.as_slice(), Operation::Session(operation) => operation.request.as_slice(), - Operation::Auth(operation) => operation.request.as_slice(), } } } @@ -97,11 +99,14 @@ pub struct ConnectOperation { pub request: Vec, } -pub struct AuthOperation { - pub request: MarshalledRequest, +impl ConnectOperation { + pub fn new(request: &ConnectRequest) -> Self { + let buf = proto::build_record_request(request); + Self { request: buf } + } } -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct MarshalledRequest(pub Vec); #[derive(Debug)] @@ -110,7 +115,62 @@ pub struct SessionOperation { pub responser: StateResponser, } -pub type StateReceiver = oneshot::Receiver, WatchReceiver), Error>>; +impl SessionOperation { + pub fn new(code: OpCode, body: &dyn Record) -> Self { + let request = MarshalledRequest::new_request(code, body); + Self { request, responser: Default::default() } + } + + pub fn new_without_body(code: OpCode) -> Self { + let header = RequestHeader::with_code(code); + let request = MarshalledRequest::new_record(&header); + Self { request, responser: StateResponser::default() } + } + + pub fn with_responser(self) -> (Self, StateReceiver) { + let (sender, receiver) = oneshot::channel(); + let request = self.request; + let code = request.get_code(); + let operation = Self { request, responser: StateResponser::new(sender) }; + (operation, StateReceiver { code, receiver }) + } +} + +impl From for SessionOperation { + fn from(request: MarshalledRequest) -> Self { + SessionOperation { request, responser: StateResponser::none() } + } +} + +pub struct StateReceiver { + code: OpCode, + receiver: oneshot::Receiver, WatchReceiver), Error>>, +} + +impl StateReceiver { + pub fn new(code: OpCode, receiver: oneshot::Receiver, WatchReceiver), Error>>) -> Self { + Self { code, receiver } + } +} + +impl Future for StateReceiver { + type Output = Result<(Vec, WatchReceiver), Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let code = self.code; + let receiver = unsafe { self.map_unchecked_mut(|r| &mut r.receiver) }; + match receiver.poll(cx) { + Poll::Pending => Poll::Pending, + Poll::Ready(result) => match result { + Err(_) => { + Poll::Ready(Err(Error::UnexpectedError(format!("BUG: {} expect response, but got none", code)))) + }, + Ok(r) => Poll::Ready(r), + }, + } + } +} + type StateSender = oneshot::Sender, WatchReceiver), Error>>; #[derive(Default, Debug)] @@ -125,35 +185,19 @@ impl StateResponser { StateResponser(None) } - pub fn send(mut self, result: Result<(Vec, WatchReceiver), Error>) { + pub fn is_none(&self) -> bool { + self.0.is_none() + } + + pub fn send(mut self, result: Result<(Vec, WatchReceiver), Error>) -> bool { if let Some(sender) = self.0.take() { sender.send(result).ignore(); + return true; } + false } - pub fn send_empty(self) { - self.send(Ok((Vec::new(), WatchReceiver::None))); + pub fn send_empty(self) -> bool { + self.send(Ok((Vec::new(), WatchReceiver::None))) } } - -pub fn build_connect_operation(request: &ConnectRequest) -> ConnectOperation { - let buf = proto::build_record_request(request); - ConnectOperation { request: buf } -} - -pub fn build_auth_operation(code: OpCode, body: &dyn Record) -> AuthOperation { - let request = MarshalledRequest::new_request(code, body); - AuthOperation { request } -} - -pub fn build_state_operation(code: OpCode, body: &dyn Record) -> (SessionOperation, StateReceiver) { - let request = MarshalledRequest::new_request(code, body); - let (sender, receiver) = oneshot::channel(); - let operation = SessionOperation { request, responser: StateResponser::new(sender) }; - (operation, receiver) -} - -pub fn build_session_operation(request: &dyn Record) -> SessionOperation { - let request = MarshalledRequest::new_record(request); - SessionOperation { request, responser: StateResponser::default() } -} diff --git a/src/session/watch.rs b/src/session/watch.rs index d01b256..79c62b6 100644 --- a/src/session/watch.rs +++ b/src/session/watch.rs @@ -6,7 +6,7 @@ use tokio::sync::{mpsc, oneshot}; use super::depot::Depot; use super::event::WatcherEvent; -use super::request::{self, Operation, StateResponser}; +use super::request::{Operation, SessionOperation, StateReceiver, StateResponser}; use super::types::{EventType, SessionState, WatchMode, WatchedEvent}; use crate::error::Error; use crate::proto::{ErrorCode, OpCode, SetWatchesRequest}; @@ -71,7 +71,8 @@ impl OneshotReceiver { let unwatch = unsafe { self.into_unwatch() }; let (sender, receiver) = oneshot::channel(); unwatch.send((id, StateResponser::new(sender))).ignore(); - receiver.await.unwrap()?; + let receiver = StateReceiver::new(OpCode::RemoveWatches, receiver); + receiver.await?; Ok(()) } } @@ -114,7 +115,8 @@ impl PersistentReceiver { let unwatch = unsafe { self.into_unwatch() }; let (sender, receiver) = oneshot::channel(); unwatch.send((id, StateResponser::new(sender))).ignore(); - receiver.await.unwrap()?; + let receiver = StateReceiver::new(OpCode::RemoveWatches, receiver); + receiver.await?; Ok(()) } } @@ -396,7 +398,7 @@ impl WatchManager { fn send_and_clear_watches(&self, last_zxid: i64, paths: &mut [Vec<&str>; 5], i: usize, depot: &mut Depot) { let (n, op_code) = if i <= 2 { (3, OpCode::SetWatches) } else { (5, OpCode::SetWatches2) }; let request = SetWatchesRequest { relative_zxid: last_zxid, paths: &paths[..n] }; - let (operation, _) = request::build_state_operation(op_code, &request); + let operation = SessionOperation::new(op_code, &request); depot.push_operation(Operation::Session(operation)); paths[..=i].iter_mut().for_each(|v| v.clear()); } diff --git a/tests/zookeeper.rs b/tests/zookeeper.rs index 9518040..12a2d65 100644 --- a/tests/zookeeper.rs +++ b/tests/zookeeper.rs @@ -4,6 +4,7 @@ use futures::future; use pretty_assertions::assert_eq; use rand::distributions::Standard; use rand::{self, Rng}; +use speculoos::prelude::*; use testcontainers::clients::Cli as DockerCli; use testcontainers::core::{Healthcheck, WaitFor}; use testcontainers::images::generic::GenericImage; @@ -24,7 +25,7 @@ fn zookeeper_image() -> GenericImage { .with_healthcheck(healthcheck) .with_env_var( "SERVER_JVMFLAGS", - "-Dzookeeper.DigestAuthenticationProvider.superDigest=super:D/InIHSb7yEEbrWz8b9l71RjZJU=", + "-Dzookeeper.DigestAuthenticationProvider.superDigest=super:D/InIHSb7yEEbrWz8b9l71RjZJU= -Dzookeeper.enableEagerACLCheck=true", ) .with_wait_for(WaitFor::Healthcheck) } @@ -104,6 +105,40 @@ async fn test_no_node() { ); } +#[tokio::test] +async fn test_request_order() { + let docker = DockerCli::default(); + let zookeeper = docker.run(zookeeper_image()); + let zk_port = zookeeper.get_host_port(2181); + + let cluster = format!("127.0.0.1:{}", zk_port); + let client = zk::Client::connect(&cluster, Duration::from_secs(20)).await.unwrap(); + + let create_options = zk::CreateOptions::new(zk::CreateMode::Persistent, zk::Acl::anyone_all()); + + let path = "/abc"; + let child_path = "/abc/efg"; + + let create = client.create(path, Default::default(), &create_options); + let get_data = client.get_and_watch_children(path); + let get_child_data = client.get_data(child_path); + let (child_stat, _) = client.create(child_path, Default::default(), &create_options).await.unwrap(); + let (stat, _) = create.await.unwrap(); + + assert_that!(child_stat.czxid).is_greater_than(stat.czxid); + + let (children, stat1, watcher) = get_data.await.unwrap(); + + assert_that!(children).is_empty(); + assert_that!(stat1).is_equal_to(stat); + + let child_event = watcher.changed().await; + assert_that!(child_event.event_type).is_equal_to(zk::EventType::NodeChildrenChanged); + assert_that!(child_event.path).is_equal_to(path.to_owned()); + + assert_that!(get_child_data.await).is_equal_to(Err(zk::Error::NoNode)); +} + #[tokio::test] async fn test_data_node() { let docker = DockerCli::default();