diff --git a/crates/fleetspeak/Cargo.toml b/crates/fleetspeak/Cargo.toml index f5adaa3..b062ab2 100644 --- a/crates/fleetspeak/Cargo.toml +++ b/crates/fleetspeak/Cargo.toml @@ -24,4 +24,4 @@ protobuf = { workspace = true } libc = { version = "0.2.147" } [target.'cfg(target_family = "windows")'.dependencies] -windows-sys = { version = "0.48.0", features = ["Win32_Foundation", "Win32_Storage_FileSystem"] } +windows-sys = { version = "0.48.0", features = ["Win32_Foundation", "Win32_Storage_FileSystem", "Win32_System_IO"] } diff --git a/crates/fleetspeak/src/io.rs b/crates/fleetspeak/src/io.rs index 1ac5fcb..54bb306 100644 --- a/crates/fleetspeak/src/io.rs +++ b/crates/fleetspeak/src/io.rs @@ -9,6 +9,56 @@ use byteorder::{LittleEndian, ReadBytesExt as _, WriteBytesExt as _}; use crate::Message; +#[cfg(target_family = "unix")] +mod unix; + +#[cfg(target_family = "windows")] +mod windows; + +mod sys { + #[cfg(target_family = "unix")] + pub use crate::io::unix::*; + + #[cfg(target_family = "windows")] + pub use crate::io::windows::*; +} + +pub use self::sys::{ + CommsInRaw, + CommsOutRaw, +}; + +/// An error returned in case instantiating communicaton channels fails. +#[derive(Clone, Debug)] +pub struct CommsEnvError { + repr: CommsEnvErrorRepr, +} + +#[derive(Clone, Debug)] +enum CommsEnvErrorRepr { + /// Communication channel is not specified in the environment. + NotSpecified, + /// Communication channel specified in the environment is not valid. + NotParsable(std::ffi::OsString), +} + +impl std::fmt::Display for CommsEnvError { + + fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.repr { + CommsEnvErrorRepr::NotSpecified => { + write!(fmt, "communication channel not specified") + } + CommsEnvErrorRepr::NotParsable(value) => { + write!(fmt, "invalid communication channel value: {value:?}") + } + } + } +} + +impl std::error::Error for CommsEnvError { +} + /// Executes the handshake procedure. /// /// The handshake procedure consists of writing and reading magic numbers from diff --git a/crates/fleetspeak/src/io/unix.rs b/crates/fleetspeak/src/io/unix.rs new file mode 100644 index 0000000..a1598e6 --- /dev/null +++ b/crates/fleetspeak/src/io/unix.rs @@ -0,0 +1,129 @@ +// Copyright 2024 Google LLC +// +// Use of this source code is governed by an MIT-style license that can be found +// in the LICENSE file or at https://opensource.org/licenses/MIT. + +use super::{CommsEnvError, CommsEnvErrorRepr}; + +/// Alternative for [`std::io::Stdin`] for communicating with Fleetspeak. +/// +/// Reading from this communication channel is not synchronized nor buffered. +pub struct CommsInRaw { + /// File descriptor of the input channel passeed by the Fleetspeak process. + fd: libc::c_int, +} + +/// Alternative for [`std::io::Stdout`] for communicating with Fleetspeak. +/// +/// Writing to this communication channel is not synchronized nor buffered. +pub struct CommsOutRaw { + /// File descriptor of the output channel passeed by the Fleetspeak process. + fd: libc::c_int, +} + +impl CommsInRaw { + + /// Returns a [`CommsIn`] instance given by the parent Fleetspeak process. + pub fn from_env() -> Result { + Ok(CommsInRaw { + fd: env_var_fd("FLEETSPEAK_COMMS_CHANNEL_INFD")?, + }) + } +} + +impl CommsOutRaw { + + /// Returns a [`CommsOut`] instance given by the parent Fleetspeak process. + pub fn from_env() -> Result { + Ok(CommsOutRaw { + fd: env_var_fd("FLEETSPEAK_COMMS_CHANNEL_OUTFD")?, + }) + } +} + +impl std::io::Read for CommsInRaw { + + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + // SAFETY: We do not have any assumptions on `self.fd`. We usually want + // it to be a valid file descriptor but since it is passed to us from + // the parent process, we cannot guarantee that it actually is. + // + // However, there is no unsafety here: in case we are not allowed to do + // a read operation on this supposed descriptor, it will simply fail + // (e.g. with `EBADF` if this is not actually a descriptor). + // + // The rest is just a function call as described in the docs [1, 2]: we + // pass a valid buffer and the number of bytes that we want to read + // (which is equal to the length of the buffer). We verify the result + // afterwards. + // + // [1]: https://man7.org/linux/man-pages/man2/read.2.html + // [2]: https://pubs.opengroup.org/onlinepubs/009604599/functions/read.html + let count = unsafe { + libc::read(self.fd, buf.as_mut_ptr().cast(), buf.len()) + }; + + if count < 0 { + return Err(std::io::Error::last_os_error()); + } + + Ok(count as usize) + } +} + +impl std::io::Write for CommsOutRaw { + + fn write(&mut self, buf: &[u8]) -> std::io::Result { + // SAFETY: We do not have any assumptions on `self.fd`. We usually want + // it to be a valid file descriptor but since it is passed to us from + // the parent process, we cannot guarantee that it actually is. + // + // However, there is no unsafety here: in case we are not allowed to do + // a write operation on this supposed descriptor, it will simply fail + // (e.g. with `EBADF` if this is not actually a descriptor). + // + // The rest is just a function call as described in the docs [1, 2]: we + // pass a valid buffer and the number of bytes that we want to write + // (which is equal to the length of the buffer). We verify the result + // afterwards. + // + // [1]: https://man7.org/linux/man-pages/man2/write.2.html + // [2]: https://pubs.opengroup.org/onlinepubs/9699919799/functions/write.html + let count = unsafe { + libc::write(self.fd, buf.as_ptr().cast(), buf.len()) + }; + + if count < 0 { + return Err(std::io::Error::last_os_error()); + } + + Ok(count as usize) + } + + fn flush(&mut self) -> std::io::Result<()> { + // We use `libc::write` for writing data which is not buffered, there + // is nothing to flush. + Ok(()) + } +} + +/// Retrieves a file descriptor specified in the given environment variable. +fn env_var_fd(key: K) -> Result +where + K: AsRef, +{ + match std::env::var(key) { + Ok(fd) => match fd.parse::() { + Ok(fd) => Ok(fd), + Err(_) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotParsable(fd.into()), + }), + } + Err(std::env::VarError::NotPresent) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotSpecified, + }), + Err(std::env::VarError::NotUnicode(value)) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotParsable(value), + }), + } +} diff --git a/crates/fleetspeak/src/io/windows.rs b/crates/fleetspeak/src/io/windows.rs new file mode 100644 index 0000000..4a3a257 --- /dev/null +++ b/crates/fleetspeak/src/io/windows.rs @@ -0,0 +1,189 @@ +// Copyright 2024 Google LLC +// +// Use of this source code is governed by an MIT-style license that can be found +// in the LICENSE file or at https://opensource.org/licenses/MIT. + +use super::{CommsEnvError, CommsEnvErrorRepr}; + +/// Alternative for [`std::io::Stdin`] for communicating with Fleetspeak. +/// +/// Reading from this communication channel is not synchronized nor buffered. +pub struct CommsInRaw { + /// File handle of the input channel passed by the Fleetspeak process. + handle: windows_sys::Win32::Foundation::HANDLE, +} + +/// Alternative for [`std::io::Stdout`] for communicating with Fleetspeak. +/// +/// Writing to this communication channel is not synchronized nor buffered. +pub struct CommsOutRaw { + /// File handle of the output channel passed by the Fleetspeak process. + handle: windows_sys::Win32::Foundation::HANDLE, +} + +impl CommsInRaw { + + /// Returns a [`CommsIn`] instance given by the parent Fleetspeak process. + pub fn from_env() -> Result { + Ok(CommsInRaw { + handle: env_var_handle("FLEETSPEAK_COMMS_CHANNEL_INFD")?, + }) + } +} + +impl CommsOutRaw { + + /// Returns a [`CommsOut`] instance given by the parent Fleetspeak process. + pub fn from_env() -> Result { + Ok(CommsOutRaw { + handle: env_var_handle("FLEETSPEAK_COMMS_CHANNEL_OUTFD")?, + }) + } +} + +impl std::io::Read for CommsInRaw { + + fn read(&mut self, buf: &mut [u8]) -> std::io::Result { + let buf_len = u32::try_from(buf.len()) + .map_err(|_| std::io::ErrorKind::InvalidInput)?; + + let mut count = std::mem::MaybeUninit::uninit(); + + // SAFETY: We do not have any assumptons on `self.handle`. We usually + // want it to be a valid file handle but since it is passed to us from + // the parent process, we cannot guarantee that it actually is. + // + // And this is why things are a bit fuzzy when it comes to safety: MSDN + // documentation for this function [1] does not explicitly mention what + // happens if we pass it an invalid handle. However, we know that there + // exists the `ERROR_INVALID_HANDLE` [2] error code and other functions + // are explicitly documented (e.g. `FlushFileBuffers` [3]) to return it + // in case the handle is invalid. Moreover, from empirical study we know + // that it is the case for `ReadFile` as well. + // + // The rest is just a function call as described in the docs: we pass a + // valid buffer and the number of bytes we want to read (which we first + // verify to fit the `u32` type required by the API). After the call we + // check whether it succeeded. + // + // [1]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-readfile + // [2]: https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- + // [3]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-flushfilebuffers + let status = unsafe { + windows_sys::Win32::Storage::FileSystem::ReadFile( + self.handle, + // TODO(@panhania): Upgrade `windows-sys` crate and remove cast. + buf.as_mut_ptr().cast::(), + buf_len, + count.as_mut_ptr(), + std::ptr::null_mut(), + ) + }; + + if status == windows_sys::Win32::Foundation::FALSE { + return Err(std::io::Error::last_os_error()); + } + + // SAFETY: We verified that the call to `ReadFile` succeeded and thus + // `count` is guaranteed to be initialized to the number of bytes that + // were read. + let count = unsafe { count.assume_init() }; + + Ok(count as usize) + } +} + +impl std::io::Write for CommsOutRaw { + + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let buf_len = u32::try_from(buf.len()) + .map_err(|_| std::io::ErrorKind::InvalidInput)?; + + let mut count = std::mem::MaybeUninit::uninit(); + + // SAFETY: We do not have any assumptons on `self.handle`. We usually + // want it to be a valid file handle but since it is passed to us from + // the parent process, we cannot guarantee that it actually is. + // + // And this is why things are a bit fuzzy when it comes to safety: MSDN + // documentation for this function [1] does not explicitly mention what + // happens if we pass it an invalid handle. However, we know that there + // exists the `ERROR_INVALID_HANDLE` [2] error code and other functions + // are explicitly documented (e.g. `FlushFileBuffers` [3]) to return it + // in case the handle is invalid. Moreover, from empirical study we know + // that it is the case for `WriteFile` as well. + // + // The rest is just a function call as described in the docs: we pass a + // valid buffer and the number of bytes we want to write (which we first + // verify to fit the `u32` type required by the API). After the call we + // check whether it succeeded. + // + // [1]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-writefile + // [2]: https://learn.microsoft.com/en-us/windows/win32/debug/system-error-codes--0-499- + // [3]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-flushfilebuffers + let status = unsafe { + windows_sys::Win32::Storage::FileSystem::WriteFile( + self.handle, + buf.as_ptr(), + buf_len, + count.as_mut_ptr(), + std::ptr::null_mut(), + ) + }; + + if status == windows_sys::Win32::Foundation::FALSE { + return Err(std::io::Error::last_os_error()); + } + + // SAFETY: We verified that the call to `WriteFile` succeeded and thus + // `count` is guaranteed to be initialized to the number of bytes that + // were written. + let count = unsafe { count.assume_init() }; + + Ok(count as usize) + } + + fn flush(&mut self) -> std::io::Result<()> { + // SAFETY: We do not have any assumptons on `self.handle`. We usually + // want it to be a valid file handle but since it is passed to use from + // the parent process, we cannot guarantee that it actually is. + // + // However, there is no unsafety here: in case the handle is not valid, + // this function will cause `ERROR_INVALID_HANDLE` to be raised [1]. We + // verify the status after the call. + // + // [1]: https://learn.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-flushfilebuffers + let status = unsafe { + windows_sys::Win32::Storage::FileSystem::FlushFileBuffers( + self.handle, + ) + }; + + if status == windows_sys::Win32::Foundation::FALSE { + return Err(std::io::Error::last_os_error()); + }; + + Ok(()) + } +} + +/// Retrieves a file handle specified in the given environment variable. +fn env_var_handle(key: K) -> Result +where + K: AsRef, +{ + match std::env::var(key) { + Ok(string) => match string.parse() { + Ok(handle) => Ok(handle), + Err(_) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotParsable(string.into()), + }), + } + Err(std::env::VarError::NotPresent) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotSpecified, + }), + Err(std::env::VarError::NotUnicode(string)) => Err(CommsEnvError { + repr: CommsEnvErrorRepr::NotParsable(string), + }), + } +} diff --git a/crates/fleetspeak/src/lib.rs b/crates/fleetspeak/src/lib.rs index 2794e1f..f34cbab 100644 --- a/crates/fleetspeak/src/lib.rs +++ b/crates/fleetspeak/src/lib.rs @@ -213,16 +213,27 @@ pub fn receive_with_heartbeat(rate: Duration) -> Message { /// sending heartbeat signals) when another thread might be busy with reading /// messages. struct Connection { - input: Mutex<&'static mut std::fs::File>, - output: Mutex<&'static mut std::fs::File>, + input: Mutex>, + output: Mutex>, } lazy_static! { static ref CONNECTION: Connection = { - let input = file_from_env_var("FLEETSPEAK_COMMS_CHANNEL_INFD"); - let output = file_from_env_var("FLEETSPEAK_COMMS_CHANNEL_OUTFD"); + let mut input = match crate::io::CommsInRaw::from_env() { + Ok(input) => std::io::BufReader::new(input), + Err(error) => { + panic!("invalid input communication channel: {error}"); + } + }; - crate::io::handshake(input, output) + let mut output = match crate::io::CommsOutRaw::from_env() { + Ok(output) => std::io::BufWriter::new(output), + Err(error) => { + panic!("invalid output commmunication channel: {error}"); + } + }; + + crate::io::handshake(&mut input, &mut output) .expect("handshake failure"); log::info!("handshake successful"); @@ -243,9 +254,9 @@ lazy_static! { /// /// Any I/O error returned by the executed function indicates a fatal connection /// failure and ends with a panic. -fn execute(mutex: &Mutex<&mut std::fs::File>, f: F) -> T +fn execute(mutex: &Mutex, f: F) -> T where - F: FnOnce(&mut std::fs::File) -> std::io::Result, + F: FnOnce(&mut C) -> std::io::Result, { let mut file = mutex.lock().expect("poisoned connection mutex"); match f(&mut file) { @@ -253,98 +264,3 @@ where Err(error) => panic!("connection failure: {}", error), } } - -/// Creates a [`File`] object specified in the given environment variable. -/// -/// Note that this function will panic if the environment variable `var` is not -/// a valid file descriptor (in which case the library cannot be initialized and -/// the service is unlikely to work anyway). -/// -/// This function returns a static mutable reference to ensure that the file is -/// never dropped. -/// -/// [`File`]: std::fs::File -fn file_from_env_var(var: &str) -> &'static mut std::fs::File { - let fd = std::env::var(var) - .expect(&format!("invalid variable `{}`", var)) - .parse() - .expect(&format!("failed to parse file descriptor")); - - // We run inside a critical section to guarantee that between verifying the - // descriptor and using `std::File::from_raw_*` functions (see the comments - // below) nothing closes it inbetween. - let mutex = Mutex::new(()); - let guard = mutex.lock() - .unwrap(); - - // SAFETY: `std::fs::File::from_raw_fd` requires the file descriptor to be - // valid and open. We verify this through `fcntl` and panic if the check - // fails. Because we run within a critical section we are sure that the file - // was not closed at the moment we wrap the raw descriptor. - // - // Note that the whole issue is more subtle than this. While we uphold the - // safety requirements of `from_raw_fd`, we cannot guarantee that we are - // exclusive owner of the descriptor or that the descriptor remains open - // throughout the entirety of the process lifetime which might lead to other - // kinds of undefined behaviour. As an additional safety measure we return - // a static mutable reference to ensure that the file destructor is never - // called. See the discussion in [1]. - // - // [1]: https://github.com/rust-lang/unsafe-code-guidelines/issues/434 - #[cfg(target_family = "unix")] - let file = unsafe { - if libc::fcntl(fd, libc::F_GETFD) == -1 { - let error = std::io::Error::last_os_error(); - panic!("invalid file descriptor '{fd}': {error}"); - } - - std::os::unix::io::FromRawFd::from_raw_fd(fd) - }; - - // SAFETY: `std::fs::File::from_raw_handle` requires the file handle to be - // valid, open and closable via `CloseHandle`. We verify this through a call - // to `GetFileType`: if the call fails or returns an unexpected file type, - // we panic. We expect the type to be a named pipe: this is what Fleetspeak - // should pass as and it is closable via `CloseHandle` [1] as required. We - // run within a critical section we are sure that the file was not closed at - // the moment we wrap the raw handle. - // - // See also remarks in the comment for the Unix branch of this method. - // - // [1]: https://learn.microsoft.com/en-us/windows/win32/api/handleapi/nf-handleapi-closehandle#remarks - #[cfg(target_family = "windows")] - let file = unsafe { - use windows_sys::Win32::{ - Foundation::*, - Storage::FileSystem::*, - }; - - // We use `identity` to specify the type for the `parse` call above. - let handle = std::convert::identity::(fd); - - // We fail both in case there is something wrong with the handle (in - // which case the call to `GetFileType` should return unknown file type - // and set the last error value) or the file type is not as expected. - let file_type = GetFileType(handle); - if file_type != FILE_TYPE_PIPE { - let code = GetLastError(); - if code != NO_ERROR { - let error = std::io::Error::from_raw_os_error(code as i32); - panic!("invalid file descriptor '{handle}': {error}"); - } else { - panic!("wrong type of file descriptor '{handle}': {file_type}"); - } - } - - // `HANDLE` type from `windows-sys` and from the standard library are - // different (one is a pointer and one is `isize`), so we have to cast - // between them. - let handle = handle as std::os::windows::raw::HANDLE; - - std::os::windows::io::FromRawHandle::from_raw_handle(handle) - }; - - drop(guard); - - Box::leak(Box::new(file)) -}