diff --git a/src/protocol/resp/src/request/mod.rs b/src/protocol/resp/src/request/mod.rs index 534ccebc..ffe3d5c7 100644 --- a/src/protocol/resp/src/request/mod.rs +++ b/src/protocol/resp/src/request/mod.rs @@ -12,10 +12,12 @@ use std::sync::Arc; mod badd; mod get; +mod ping; mod set; pub use badd::BAddRequest; pub use get::GetRequest; +pub use ping::PingRequest; pub use set::SetRequest; #[derive(Default)] @@ -98,6 +100,9 @@ impl Parse for RequestParser { Some(b"set") | Some(b"SET") => { SetRequest::try_from(message).map(Request::from) } + Some(b"ping") | Some(b"PING") => { + PingRequest::try_from(message).map(Request::from) + } _ => Err(Error::new(ErrorKind::Other, "unknown command")), }, _ => { @@ -121,6 +126,7 @@ impl Compose for Request { Self::BAdd(r) => r.compose(buf), Self::Get(r) => r.compose(buf), Self::Set(r) => r.compose(buf), + Self::Ping(r) => r.compose(buf), } } } @@ -130,6 +136,7 @@ pub enum Request { BAdd(BAddRequest), Get(GetRequest), Set(SetRequest), + Ping(PingRequest), } impl From for Request { @@ -150,11 +157,18 @@ impl From for Request { } } +impl From for Request { + fn from(other: PingRequest) -> Self { + Self::Ping(other) + } +} + #[derive(Debug, PartialEq, Eq)] pub enum Command { BAdd, Get, Set, + Ping, } impl TryFrom<&[u8]> for Command { @@ -165,6 +179,7 @@ impl TryFrom<&[u8]> for Command { b"badd" | b"BADD" => Ok(Command::BAdd), b"get" | b"GET" => Ok(Command::Get), b"set" | b"SET" => Ok(Command::Set), + b"ping" | b"PING" => Ok(Command::Ping), _ => Err(()), } } diff --git a/src/protocol/resp/src/request/ping.rs b/src/protocol/resp/src/request/ping.rs new file mode 100644 index 00000000..5be8eb26 --- /dev/null +++ b/src/protocol/resp/src/request/ping.rs @@ -0,0 +1,59 @@ +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use super::*; +use std::io::{Error, ErrorKind}; + +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::redundant_allocation)] +pub struct PingRequest {} + +impl TryFrom for PingRequest { + type Error = Error; + + fn try_from(other: Message) -> Result { + if let Message::Array(array) = other { + if array.inner.is_none() { + return Err(Error::new(ErrorKind::Other, "malformed command")); + } + Ok(Self {}) + } else { + Err(Error::new(ErrorKind::Other, "malformed command")) + } + } +} + +impl PingRequest { + pub fn new() -> Self { + Self {} + } +} + +impl From<&PingRequest> for Message { + fn from(_: &PingRequest) -> Message { + Message::Array(Array { + inner: Some(vec![Message::BulkString(BulkString::new(b"Ping"))]), + }) + } +} + +impl Compose for PingRequest { + fn compose(&self, buf: &mut dyn BufMut) -> usize { + let message = Message::from(self); + message.compose(buf) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parser() { + let parser = RequestParser::new(); + assert_eq!( + parser.parse(b"PING\r\n").unwrap().into_inner(), + Request::Ping(PingRequest::new()) + ); + } +} diff --git a/src/proxy/momento/src/frontend.rs b/src/proxy/momento/src/frontend.rs index 55a7729c..7701fd02 100644 --- a/src/proxy/momento/src/frontend.rs +++ b/src/proxy/momento/src/frontend.rs @@ -102,6 +102,11 @@ pub(crate) async fn handle_resp_client( break; } } + resp::Request::Ping(_) => { + if resp::ping(&mut socket).await.is_err() { + break; + } + } _ => { println!("bad request"); let _ = socket.write_all(b"CLIENT_ERROR\r\n").await; diff --git a/src/proxy/momento/src/protocol/resp/mod.rs b/src/proxy/momento/src/protocol/resp/mod.rs index fb65889e..ed1b7e04 100644 --- a/src/proxy/momento/src/protocol/resp/mod.rs +++ b/src/proxy/momento/src/protocol/resp/mod.rs @@ -5,7 +5,9 @@ pub use protocol_resp::{Request, RequestParser}; mod get; +mod ping; mod set; pub use get::*; +pub use ping::*; pub use set::*; diff --git a/src/proxy/momento/src/protocol/resp/ping.rs b/src/proxy/momento/src/protocol/resp/ping.rs new file mode 100644 index 00000000..3cac6541 --- /dev/null +++ b/src/proxy/momento/src/protocol/resp/ping.rs @@ -0,0 +1,22 @@ +// Licensed under the Apache License, Version 2.0 +// http://www.apache.org/licenses/LICENSE-2.0 + +use crate::Error; +use net::TCP_SEND_BYTE; +use session::{SESSION_SEND, SESSION_SEND_BYTE, SESSION_SEND_EX}; +use tokio::io::AsyncWriteExt; + +const PONG_RSP: &[u8; 7] = b"+PONG\r\n"; + +pub async fn ping(socket: &mut tokio::net::TcpStream) -> Result<(), Error> { + let mut response_buf = Vec::new(); + response_buf.extend_from_slice(PONG_RSP); + SESSION_SEND.increment(); + SESSION_SEND_BYTE.add(response_buf.len() as _); + TCP_SEND_BYTE.add(response_buf.len() as _); + if let Err(e) = socket.write_all(&response_buf).await { + SESSION_SEND_EX.increment(); + return Err(e); + } + Ok(()) +}