Skip to content

Commit a921a5d

Browse files
Safer poll timeout
1 parent e756c96 commit a921a5d

File tree

3 files changed

+263
-9
lines changed

3 files changed

+263
-9
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ This project adheres to [Semantic Versioning](https://semver.org/).
7777
([#1870](https://github.com/nix-rust/nix/pull/1870))
7878
- The `length` argument of `sys::mman::mmap` is now of type `NonZeroUsize`.
7979
([#1873](https://github.com/nix-rust/nix/pull/1873))
80+
- The `timeout` argument of `poll::poll` is now of type `poll::PollTimeout`.
81+
([#1876](https://github.com/nix-rust/nix/pull/1876))
8082

8183
### Fixed
8284

src/poll.rs

+258-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
//! Wait for events to trigger on specific file descriptors
2+
use std::convert::TryFrom;
3+
use std::fmt;
24
use std::os::unix::io::{AsRawFd, RawFd};
5+
use std::time::Duration;
36

47
use crate::errno::Errno;
58
use crate::Result;
@@ -132,6 +135,255 @@ libc_bitflags! {
132135
}
133136
}
134137

138+
/// Timeout argument for [`poll`].
139+
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
140+
pub struct PollTimeout(i32);
141+
142+
/// Error type for [`PollTimeout::try_from::<i128>::()`].
143+
#[derive(Debug, Clone, Copy)]
144+
pub enum TryFromI128Error {
145+
/// Value is less than -1.
146+
Underflow(crate::Errno),
147+
/// Value is greater than [`i32::MAX`].
148+
Overflow(<i32 as TryFrom<i128>>::Error),
149+
}
150+
impl fmt::Display for TryFromI128Error {
151+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
152+
match self {
153+
Self::Underflow(err) => write!(f, "Underflow: {err}"),
154+
Self::Overflow(err) => write!(f, "Overflow: {err}"),
155+
}
156+
}
157+
}
158+
impl std::error::Error for TryFromI128Error {}
159+
160+
/// Error type for [`PollTimeout::try_from::<i68>()`].
161+
#[derive(Debug, Clone, Copy)]
162+
pub enum TryFromI64Error {
163+
/// Value is less than -1.
164+
Underflow(crate::Errno),
165+
/// Value is greater than [`i32::MAX`].
166+
Overflow(<i32 as TryFrom<i64>>::Error),
167+
}
168+
impl fmt::Display for TryFromI64Error {
169+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
170+
match self {
171+
Self::Underflow(err) => write!(f, "Underflow: {err}"),
172+
Self::Overflow(err) => write!(f, "Overflow: {err}"),
173+
}
174+
}
175+
}
176+
impl std::error::Error for TryFromI64Error {}
177+
178+
// These cases implement slightly different conversions that make using generics impossible without
179+
// specialization.
180+
impl PollTimeout {
181+
/// Blocks indefinitely.
182+
pub const NONE: Self = Self(-1);
183+
/// Returns immediately.
184+
pub const ZERO: Self = Self(0);
185+
/// Blocks for at most [`std::i32::MAX`] milliseconds.
186+
pub const MAX: Self = Self(i32::MAX);
187+
/// Returns if `self` equals [`PollTimeout::NONE`].
188+
pub fn is_none(&self) -> bool {
189+
*self == Self::NONE
190+
}
191+
/// Returns if `self` does not equal [`PollTimeout::NONE`].
192+
pub fn is_some(&self) -> bool {
193+
!self.is_none()
194+
}
195+
/// Returns the timeout in milliseconds if there is some, otherwise returns `None`.
196+
pub fn timeout(&self) -> Option<i32> {
197+
self.is_some().then_some(self.0)
198+
}
199+
}
200+
impl TryFrom<Duration> for PollTimeout {
201+
type Error = <i32 as TryFrom<u128>>::Error;
202+
fn try_from(x: Duration) -> std::result::Result<Self, Self::Error> {
203+
Ok(Self(i32::try_from(x.as_millis())?))
204+
}
205+
}
206+
impl TryFrom<u128> for PollTimeout {
207+
type Error = <i32 as TryFrom<u128>>::Error;
208+
fn try_from(x: u128) -> std::result::Result<Self, Self::Error> {
209+
Ok(Self(i32::try_from(x)?))
210+
}
211+
}
212+
impl TryFrom<u64> for PollTimeout {
213+
type Error = <i32 as TryFrom<u64>>::Error;
214+
fn try_from(x: u64) -> std::result::Result<Self, Self::Error> {
215+
Ok(Self(i32::try_from(x)?))
216+
}
217+
}
218+
impl TryFrom<u32> for PollTimeout {
219+
type Error = <i32 as TryFrom<u32>>::Error;
220+
fn try_from(x: u32) -> std::result::Result<Self, Self::Error> {
221+
Ok(Self(i32::try_from(x)?))
222+
}
223+
}
224+
impl From<u16> for PollTimeout {
225+
fn from(x: u16) -> Self {
226+
Self(i32::from(x))
227+
}
228+
}
229+
impl From<u8> for PollTimeout {
230+
fn from(x: u8) -> Self {
231+
Self(i32::from(x))
232+
}
233+
}
234+
impl TryFrom<i128> for PollTimeout {
235+
type Error = TryFromI128Error;
236+
fn try_from(x: i128) -> std::result::Result<Self, Self::Error> {
237+
match x {
238+
-1 => Ok(Self::NONE),
239+
millis @ 0.. => Ok(Self(
240+
i32::try_from(millis).map_err(TryFromI128Error::Overflow)?,
241+
)),
242+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
243+
_ => Err(TryFromI128Error::Underflow(Errno::EINVAL)),
244+
}
245+
}
246+
}
247+
impl TryFrom<i64> for PollTimeout {
248+
type Error = TryFromI64Error;
249+
fn try_from(x: i64) -> std::result::Result<Self, Self::Error> {
250+
match x {
251+
-1 => Ok(Self::NONE),
252+
millis @ 0.. => Ok(Self(
253+
i32::try_from(millis).map_err(TryFromI64Error::Overflow)?,
254+
)),
255+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
256+
_ => Err(TryFromI64Error::Underflow(Errno::EINVAL)),
257+
}
258+
}
259+
}
260+
impl TryFrom<i32> for PollTimeout {
261+
type Error = Errno;
262+
fn try_from(x: i32) -> Result<Self> {
263+
match x {
264+
-1 => Ok(Self::NONE),
265+
millis @ 0.. => Ok(Self(millis)),
266+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
267+
_ => Err(Errno::EINVAL),
268+
}
269+
}
270+
}
271+
impl TryFrom<i16> for PollTimeout {
272+
type Error = Errno;
273+
fn try_from(x: i16) -> Result<Self> {
274+
match x {
275+
-1 => Ok(Self::NONE),
276+
millis @ 0.. => Ok(Self(millis.into())),
277+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
278+
_ => Err(Errno::EINVAL),
279+
}
280+
}
281+
}
282+
impl TryFrom<i8> for PollTimeout {
283+
type Error = Errno;
284+
fn try_from(x: i8) -> Result<Self> {
285+
match x {
286+
-1 => Ok(Self::NONE),
287+
millis @ 0.. => Ok(Self(millis.into())),
288+
// EINVAL (ppoll()) The timeout value expressed in *ip is invalid (negative).
289+
_ => Err(Errno::EINVAL),
290+
}
291+
}
292+
}
293+
impl TryFrom<PollTimeout> for Duration {
294+
type Error = ();
295+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
296+
match x.timeout() {
297+
// SAFETY: `x.0` is always positive.
298+
Some(millis) => Ok(Duration::from_millis(unsafe {
299+
u64::try_from(millis).unwrap_unchecked()
300+
})),
301+
None => Err(()),
302+
}
303+
}
304+
}
305+
impl TryFrom<PollTimeout> for u128 {
306+
type Error = ();
307+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
308+
match x.timeout() {
309+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
310+
Some(millis) => {
311+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
312+
}
313+
None => Err(()),
314+
}
315+
}
316+
}
317+
impl TryFrom<PollTimeout> for u64 {
318+
type Error = ();
319+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
320+
match x.timeout() {
321+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
322+
Some(millis) => {
323+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
324+
}
325+
None => Err(()),
326+
}
327+
}
328+
}
329+
impl TryFrom<PollTimeout> for u32 {
330+
type Error = ();
331+
fn try_from(x: PollTimeout) -> std::result::Result<Self, ()> {
332+
match x.timeout() {
333+
// SAFETY: When `x.timeout()` returns `Some(a)`, `a` is always positive.
334+
Some(millis) => {
335+
Ok(unsafe { Self::try_from(millis).unwrap_unchecked() })
336+
}
337+
None => Err(()),
338+
}
339+
}
340+
}
341+
impl TryFrom<PollTimeout> for u16 {
342+
type Error = Option<<Self as TryFrom<i32>>::Error>;
343+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
344+
match x.timeout() {
345+
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
346+
None => Err(None),
347+
}
348+
}
349+
}
350+
impl TryFrom<PollTimeout> for u8 {
351+
type Error = Option<<Self as TryFrom<i32>>::Error>;
352+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
353+
match x.timeout() {
354+
Some(millis) => Ok(Self::try_from(millis).map_err(Some)?),
355+
None => Err(None),
356+
}
357+
}
358+
}
359+
impl From<PollTimeout> for i128 {
360+
fn from(x: PollTimeout) -> Self {
361+
x.timeout().unwrap_or(-1).into()
362+
}
363+
}
364+
impl From<PollTimeout> for i64 {
365+
fn from(x: PollTimeout) -> Self {
366+
x.timeout().unwrap_or(-1).into()
367+
}
368+
}
369+
impl From<PollTimeout> for i32 {
370+
fn from(x: PollTimeout) -> Self {
371+
x.timeout().unwrap_or(-1)
372+
}
373+
}
374+
impl TryFrom<PollTimeout> for i16 {
375+
type Error = <Self as TryFrom<i32>>::Error;
376+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
377+
Self::try_from(x.timeout().unwrap_or(-1))
378+
}
379+
}
380+
impl TryFrom<PollTimeout> for i8 {
381+
type Error = <Self as TryFrom<i32>>::Error;
382+
fn try_from(x: PollTimeout) -> std::result::Result<Self, Self::Error> {
383+
Self::try_from(x.timeout().unwrap_or(-1))
384+
}
385+
}
386+
135387
/// `poll` waits for one of a set of file descriptors to become ready to perform I/O.
136388
/// ([`poll(2)`](https://pubs.opengroup.org/onlinepubs/9699919799/functions/poll.html))
137389
///
@@ -148,16 +400,16 @@ libc_bitflags! {
148400
///
149401
/// Note that the timeout interval will be rounded up to the system clock
150402
/// granularity, and kernel scheduling delays mean that the blocking
151-
/// interval may overrun by a small amount. Specifying a negative value
152-
/// in timeout means an infinite timeout. Specifying a timeout of zero
153-
/// causes `poll()` to return immediately, even if no file descriptors are
154-
/// ready.
155-
pub fn poll(fds: &mut [PollFd], timeout: libc::c_int) -> Result<libc::c_int> {
403+
/// interval may overrun by a small amount. Specifying a [`PollTimeout::NONE`]
404+
/// in timeout means an infinite timeout. Specifying a timeout of
405+
/// [`PollTimeout::ZERO`] causes `poll()` to return immediately, even if no file
406+
/// descriptors are ready.
407+
pub fn poll(fds: &mut [PollFd], timeout: PollTimeout) -> Result<libc::c_int> {
156408
let res = unsafe {
157409
libc::poll(
158410
fds.as_mut_ptr() as *mut libc::pollfd,
159411
fds.len() as libc::nfds_t,
160-
timeout,
412+
timeout.into(),
161413
)
162414
};
163415

test/test_poll.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use nix::{
22
errno::Errno,
3-
poll::{poll, PollFd, PollFlags},
3+
poll::{poll, PollFd, PollFlags, PollTimeout},
44
unistd::{pipe, write},
55
};
66

@@ -22,14 +22,14 @@ fn test_poll() {
2222
let mut fds = [PollFd::new(r, PollFlags::POLLIN)];
2323

2424
// Poll an idle pipe. Should timeout
25-
let nfds = loop_while_eintr!(poll(&mut fds, 100));
25+
let nfds = loop_while_eintr!(poll(&mut fds, PollTimeout::from(100u8)));
2626
assert_eq!(nfds, 0);
2727
assert!(!fds[0].revents().unwrap().contains(PollFlags::POLLIN));
2828

2929
write(w, b".").unwrap();
3030

3131
// Poll a readable pipe. Should return an event.
32-
let nfds = poll(&mut fds, 100).unwrap();
32+
let nfds = poll(&mut fds, PollTimeout::from(100u8)).unwrap();
3333
assert_eq!(nfds, 1);
3434
assert!(fds[0].revents().unwrap().contains(PollFlags::POLLIN));
3535
}

0 commit comments

Comments
 (0)