From d0c48f4cd652459562e6b0323f94b0ffb01256cd Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Mon, 9 Sep 2024 14:09:57 +0800 Subject: [PATCH 1/4] refactor: I/O safety for control msg ScmRights --- src/sys/socket/mod.rs | 32 +++++++++++++++++--------------- test/sys/test_socket.rs | 32 +++++++++++++++----------------- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/src/sys/socket/mod.rs b/src/sys/socket/mod.rs index 99b47a5905..65842f2265 100644 --- a/src/sys/socket/mod.rs +++ b/src/sys/socket/mod.rs @@ -19,7 +19,7 @@ use libc::{ use std::io::{IoSlice, IoSliceMut}; #[cfg(feature = "net")] use std::net; -use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd}; +use std::os::unix::io::{AsFd, AsRawFd, FromRawFd, OwnedFd, RawFd, BorrowedFd}; use std::{mem, ptr}; #[deny(missing_docs)] @@ -35,6 +35,9 @@ pub mod sockopt; pub use self::addr::{SockaddrLike, SockaddrStorage}; +#[cfg(feature = "net")] +use crate::sys::socket::addr::{ipv4addr_to_libc, ipv6addr_to_libc}; + #[cfg(solarish)] pub use self::addr::{AddressFamily, UnixAddr}; #[cfg(not(solarish))] @@ -62,9 +65,6 @@ pub use libc::{sa_family_t, sockaddr, sockaddr_storage, sockaddr_un}; #[cfg(feature = "net")] pub use libc::{sockaddr_in, sockaddr_in6}; -#[cfg(feature = "net")] -use crate::sys::socket::addr::{ipv4addr_to_libc, ipv6addr_to_libc}; - /// These constants are used to specify the communication semantics /// when creating a socket with [`socket()`](fn.socket.html) #[derive(Clone, Copy, PartialEq, Eq, Debug)] @@ -556,16 +556,16 @@ feature! { /// ``` /// # #[macro_use] extern crate nix; /// # use nix::sys::time::TimeVal; -/// # use std::os::unix::io::RawFd; +/// # use std::os::fd::OwnedFd; /// # fn main() { /// // Create a buffer for a `ControlMessageOwned::ScmTimestamp` message /// let _ = cmsg_space!(TimeVal); /// // Create a buffer big enough for a `ControlMessageOwned::ScmRights` message /// // with two file descriptors -/// let _ = cmsg_space!([RawFd; 2]); +/// let _ = cmsg_space!([OwnedFd; 2]); /// // Create a buffer big enough for a `ControlMessageOwned::ScmRights` message /// // and a `ControlMessageOwned::ScmTimestamp` message -/// let _ = cmsg_space!(RawFd, TimeVal); +/// let _ = cmsg_space!(OwnedFd, TimeVal); /// # } /// ``` #[macro_export] @@ -655,11 +655,11 @@ impl<'a> Iterator for CmsgIterator<'a> { // alignment issues. // // See https://github.com/nix-rust/nix/issues/999 -#[derive(Clone, Debug, Eq, PartialEq)] +#[derive(Debug)] #[non_exhaustive] pub enum ControlMessageOwned { /// Received version of [`ControlMessage::ScmRights`] - ScmRights(Vec), + ScmRights(Vec), /// Received version of [`ControlMessage::ScmCredentials`] #[cfg(linux_android)] ScmCredentials(UnixCredentials), @@ -908,11 +908,11 @@ impl ControlMessageOwned { - p as usize; match (header.cmsg_level, header.cmsg_type) { (libc::SOL_SOCKET, libc::SCM_RIGHTS) => { - let n = len / mem::size_of::(); + let n = len / mem::size_of::(); let mut fds = Vec::with_capacity(n); for i in 0..n { unsafe { - let fdp = (p as *const RawFd).add(i); + let fdp = (p as *const OwnedFd).add(i); fds.push(ptr::read_unaligned(fdp)); } } @@ -1094,7 +1094,7 @@ impl ControlMessageOwned { /// pattern-match it. /// /// [Further reading](https://man7.org/linux/man-pages/man3/cmsg.3.html) -#[derive(Clone, Copy, Debug, Eq, PartialEq)] +#[derive(Clone, Copy, Debug)] #[non_exhaustive] pub enum ControlMessage<'a> { /// A message of type `SCM_RIGHTS`, containing an array of file @@ -1108,7 +1108,7 @@ pub enum ControlMessage<'a> { /// swallow all but the first `ScmRights` message or fail with `EINVAL`. /// Instead, you can put all fds to be passed into a single `ScmRights` /// message. - ScmRights(&'a [RawFd]), + ScmRights(&'a [BorrowedFd<'a>]), /// A message of type `SCM_CREDENTIALS`, containing the pid, uid and gid of /// a process connected to the socket. /// @@ -1586,13 +1586,14 @@ impl<'a> ControlMessage<'a> { /// # use nix::unistd::pipe; /// # use std::io::IoSlice; /// # use std::os::unix::io::AsRawFd; +/// # use std::os::unix::io::AsFd; /// let (fd1, fd2) = socketpair(AddressFamily::Unix, SockType::Stream, None, /// SockFlag::empty()) /// .unwrap(); /// let (r, w) = pipe().unwrap(); /// /// let iov = [IoSlice::new(b"hello")]; -/// let fds = [r.as_raw_fd()]; +/// let fds = [r.as_fd()]; /// let cmsg = ControlMessage::ScmRights(&fds); /// sendmsg::<()>(fd1.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), None).unwrap(); /// ``` @@ -1602,6 +1603,7 @@ impl<'a> ControlMessage<'a> { /// # use nix::unistd::pipe; /// # use std::io::IoSlice; /// # use std::str::FromStr; +/// # use std::os::unix::io::AsFd; /// # use std::os::unix::io::AsRawFd; /// let localhost = SockaddrIn::from_str("1.2.3.4:8080").unwrap(); /// let fd = socket(AddressFamily::Inet, SockType::Datagram, SockFlag::empty(), @@ -1609,7 +1611,7 @@ impl<'a> ControlMessage<'a> { /// let (r, w) = pipe().unwrap(); /// /// let iov = [IoSlice::new(b"hello")]; -/// let fds = [r.as_raw_fd()]; +/// let fds = [r.as_fd()]; /// let cmsg = ControlMessage::ScmRights(&fds); /// sendmsg(fd.as_raw_fd(), &iov, &[cmsg], MsgFlags::empty(), Some(&localhost)).unwrap(); /// ``` diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index afe5629142..bd47b356ee 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -5,6 +5,7 @@ use nix::sys::socket::{getsockname, AddressFamily, UnixAddr}; use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; use std::net::{SocketAddrV4, SocketAddrV6}; +use std::os::fd::AsFd; use std::os::unix::io::{AsRawFd, RawFd}; use std::path::Path; use std::slice; @@ -846,8 +847,9 @@ pub fn test_scm_rights() { recvmsg, sendmsg, socketpair, AddressFamily, ControlMessage, ControlMessageOwned, MsgFlags, SockFlag, SockType, }; - use nix::unistd::{close, pipe, read, write}; + use nix::unistd::{pipe, read, write}; use std::io::{IoSlice, IoSliceMut}; + use std::os::fd::OwnedFd; let (fd1, fd2) = socketpair( AddressFamily::Unix, @@ -857,11 +859,11 @@ pub fn test_scm_rights() { ) .unwrap(); let (r, w) = pipe().unwrap(); - let mut received_r: Option = None; + let mut received_r: Option = None; { let iov = [IoSlice::new(b"hello")]; - let fds = [r.as_raw_fd()]; + let fds = [r.as_fd()]; let cmsg = ControlMessage::ScmRights(&fds); assert_eq!( sendmsg::<()>( @@ -890,10 +892,10 @@ pub fn test_scm_rights() { .unwrap(); for cmsg in msg.cmsgs().unwrap() { - if let ControlMessageOwned::ScmRights(fd) = cmsg { - assert_eq!(received_r, None); + if let ControlMessageOwned::ScmRights(mut fd) = cmsg { + assert!(received_r.is_none()); assert_eq!(fd.len(), 1); - received_r = Some(fd[0]); + received_r = Some(fd.pop().unwrap()); } else { panic!("unexpected cmsg"); } @@ -908,15 +910,8 @@ pub fn test_scm_rights() { // Ensure that the received file descriptor works write(&w, b"world").unwrap(); let mut buf = [0u8; 5]; - // SAFETY: - // should be safe since we don't use it after close - let borrowed_received_r = - unsafe { std::os::fd::BorrowedFd::borrow_raw(received_r) }; - read(borrowed_received_r, &mut buf).unwrap(); + read(&received_r, &mut buf).unwrap(); assert_eq!(&buf[..], b"world"); - // SAFETY: - // there shouldn't be double close - unsafe { close(received_r).unwrap() }; } // Disable the test on emulated platforms due to not enabled support of AF_ALG in QEMU from rust cross @@ -1333,7 +1328,8 @@ fn test_scm_rights_single_cmsg_multiple_fds() { recvmsg, sendmsg, ControlMessage, ControlMessageOwned, MsgFlags, }; use std::io::{IoSlice, IoSliceMut}; - use std::os::unix::io::{AsRawFd, RawFd}; + use std::os::fd::BorrowedFd; + use std::os::unix::io::AsRawFd; use std::os::unix::net::UnixDatagram; use std::thread; @@ -1342,7 +1338,7 @@ fn test_scm_rights_single_cmsg_multiple_fds() { let mut buf = [0u8; 8]; let mut iovec = [IoSliceMut::new(&mut buf)]; - let mut space = cmsg_space!([RawFd; 2]); + let mut space = cmsg_space!([BorrowedFd; 2]); let msg = recvmsg::<()>( receive.as_raw_fd(), &mut iovec, @@ -1374,7 +1370,9 @@ fn test_scm_rights_single_cmsg_multiple_fds() { let slice = [1u8, 2, 3, 4, 5, 6, 7, 8]; let iov = [IoSlice::new(&slice)]; - let fds = [libc::STDIN_FILENO, libc::STDOUT_FILENO]; // pass stdin and stdout + let stdin_owned = std::io::stdin(); + let stdout_owned = std::io::stdout(); + let fds = [stdin_owned.as_fd(), stdout_owned.as_fd()]; // pass stdin and stdout let cmsg = [ControlMessage::ScmRights(&fds)]; sendmsg::<()>(send.as_raw_fd(), &iov, &cmsg, MsgFlags::empty(), None) .unwrap(); From fb8701c045f6d02fdd7d58659ee6309559c58f76 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Mon, 9 Sep 2024 14:22:41 +0800 Subject: [PATCH 2/4] test: try to fix test on Linux --- test/sys/test_socket.rs | 19 ++++++------------- 1 file changed, 6 insertions(+), 13 deletions(-) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index bd47b356ee..a7daf633a8 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -1556,6 +1556,7 @@ fn test_impl_scm_credentials_and_rights( }; use nix::unistd::{close, getgid, getpid, getuid, pipe, write}; use std::io::{IoSlice, IoSliceMut}; + use std::os::fd::BorrowedFd; let (send, recv) = socketpair( AddressFamily::Unix, @@ -1567,7 +1568,7 @@ fn test_impl_scm_credentials_and_rights( setsockopt(&recv, PassCred, &true).unwrap(); let (r, w) = pipe().unwrap(); - let mut received_r: Option = None; + let mut received_r: Option = None; { let iov = [IoSlice::new(b"hello")]; @@ -1577,7 +1578,7 @@ fn test_impl_scm_credentials_and_rights( gid: getgid().as_raw(), } .into(); - let fds = [r.as_raw_fd()]; + let fds = [r.as_fd()]; let cmsgs = [ ControlMessage::ScmCredentials(&cred), ControlMessage::ScmRights(&fds), @@ -1611,10 +1612,10 @@ fn test_impl_scm_credentials_and_rights( for cmsg in msg.cmsgs()? { match cmsg { - ControlMessageOwned::ScmRights(fds) => { + ControlMessageOwned::ScmRights(mut fds) => { assert_eq!(received_r, None, "already received fd"); assert_eq!(fds.len(), 1); - received_r = Some(fds[0]); + received_r = Some(fds.pop().unwrap()); } ControlMessageOwned::ScmCredentials(cred) => { assert!(received_cred.is_none()); @@ -1637,16 +1638,8 @@ fn test_impl_scm_credentials_and_rights( // Ensure that the received file descriptor works write(&w, b"world").unwrap(); let mut buf = [0u8; 5]; - // SAFETY: - // It should be safe if we don't use this BorrowedFd after close. - let received_r_borrowed = - unsafe { std::os::fd::BorrowedFd::borrow_raw(received_r) }; - read(received_r_borrowed, &mut buf).unwrap(); + read(&received_r, &mut buf).unwrap(); assert_eq!(&buf[..], b"world"); - // SAFETY: - // double-close won't happen - unsafe { close(received_r).unwrap() }; - Ok(()) } From fbbccff7d0705acf65bb906e6366739a1c1b5da7 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Mon, 9 Sep 2024 14:25:47 +0800 Subject: [PATCH 3/4] test: try to fix test on Linux --- test/sys/test_socket.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index a7daf633a8..4b6bfccb67 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -1568,7 +1568,7 @@ fn test_impl_scm_credentials_and_rights( setsockopt(&recv, PassCred, &true).unwrap(); let (r, w) = pipe().unwrap(); - let mut received_r: Option = None; + let mut received_r = None; { let iov = [IoSlice::new(b"hello")]; From 57ebe0a8bcb85088ab461af250d36eec7b260a9d Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Mon, 9 Sep 2024 14:28:12 +0800 Subject: [PATCH 4/4] test: try to fix test on Linux --- test/sys/test_socket.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/sys/test_socket.rs b/test/sys/test_socket.rs index 4b6bfccb67..5362045a8d 100644 --- a/test/sys/test_socket.rs +++ b/test/sys/test_socket.rs @@ -1554,9 +1554,8 @@ fn test_impl_scm_credentials_and_rights( recvmsg, sendmsg, setsockopt, socketpair, ControlMessage, ControlMessageOwned, MsgFlags, SockFlag, SockType, }; - use nix::unistd::{close, getgid, getpid, getuid, pipe, write}; + use nix::unistd::{getgid, getpid, getuid, pipe, write}; use std::io::{IoSlice, IoSliceMut}; - use std::os::fd::BorrowedFd; let (send, recv) = socketpair( AddressFamily::Unix, @@ -1613,7 +1612,7 @@ fn test_impl_scm_credentials_and_rights( for cmsg in msg.cmsgs()? { match cmsg { ControlMessageOwned::ScmRights(mut fds) => { - assert_eq!(received_r, None, "already received fd"); + assert!(received_r.is_none(), "already received fd"); assert_eq!(fds.len(), 1); received_r = Some(fds.pop().unwrap()); }