Skip to content

Commit 071b4a4

Browse files
committed
refactor rotation into standalone type, fix try_join sequencing test
1 parent a2f40aa commit 071b4a4

File tree

5 files changed

+128
-189
lines changed

5 files changed

+128
-189
lines changed

tokio/src/macros/join.rs

Lines changed: 46 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ doc! {macro_rules! join {
111111
#[cfg(not(doc))]
112112
doc! {macro_rules! join {
113113
(@ {
114-
// Whether to rotate which future is polled first every poll,
115-
// by incrementing a skip counter
116-
rotate_poll_order=$rotate_poll_order:literal;
114+
// Type of rotator that controls which inner future to start with
115+
// when polling our output future.
116+
rotator=$rotator:ty;
117117

118118
// One `_` for each branch in the `join!` macro. This is not used once
119119
// normalization is complete.
@@ -143,82 +143,33 @@ doc! {macro_rules! join {
143143
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
144144
let mut futures = &mut futures;
145145

146+
const COUNT: u32 = $($total)*;
147+
146148
// Each time the future created by poll_fn is polled,
147149
// if not running in biased mode,
148150
// a different future will be polled first
149151
// to ensure every future passed to join! gets a chance to make progress even if
150152
// one of the futures consumes the whole budget.
151-
//
152-
// This is number of futures that will be skipped in the first loop
153-
// iteration the next time poll.
154-
//
155-
// If running in biased mode, this variable will be optimized out since we don't pass
156-
// it to the poll_fn.
157-
let mut skip_next_time: u32 = 0;
158-
159-
match $rotate_poll_order {
160-
true => poll_fn(move |cx| {
161-
const COUNT: u32 = $($total)*;
162-
let mut is_pending = false;
163-
let mut to_run = COUNT;
164-
165-
// The number of futures that will be skipped in the first loop iteration.
166-
let mut skip = skip_next_time;
167-
// Upkeep for next poll, rotate first polled future
168-
skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };
169-
170-
// This loop runs twice and the first `skip` futures
171-
// are not polled in the first iteration.
172-
loop {
173-
$(
174-
if skip == 0 {
175-
if to_run == 0 {
176-
// Every future has been polled
177-
break;
178-
}
179-
to_run -= 1;
180-
181-
// Extract the future for this branch from the tuple.
182-
let ( $($skip,)* fut, .. ) = &mut *futures;
183-
184-
// Safety: future is stored on the stack above
185-
// and never moved.
186-
let mut fut = unsafe { Pin::new_unchecked(fut) };
187-
188-
// Try polling
189-
if fut.poll(cx).is_pending() {
190-
is_pending = true;
191-
}
192-
} else {
193-
// Future skipped, one less future to skip in the next iteration
194-
skip -= 1;
195-
}
196-
)*
153+
let mut rotator = <$rotator>::default();
154+
155+
poll_fn(move |cx| {
156+
let mut is_pending = false;
157+
let mut to_run = COUNT;
158+
159+
// The number of futures that will be skipped in the first loop iteration.
160+
let mut skip = rotator.num_skip();
161+
162+
// This loop runs twice and the first `skip` futures
163+
// are not polled in the first iteration.
164+
loop {
165+
$(
166+
if skip == 0 {
167+
if to_run == 0 {
168+
// Every future has been polled
169+
break;
197170
}
171+
to_run -= 1;
198172

199-
if is_pending {
200-
Pending
201-
} else {
202-
Ready(($({
203-
// Extract the future for this branch from the tuple.
204-
let ( $($skip,)* fut, .. ) = &mut futures;
205-
206-
// Safety: future is stored on the stack above
207-
// and never moved.
208-
let mut fut = unsafe { Pin::new_unchecked(fut) };
209-
210-
fut.take_output().expect("expected completed future")
211-
},)*))
212-
}
213-
}).await,
214-
// don't rotate the poll order so no skipping
215-
false => poll_fn(move |cx| {
216-
const COUNT: u32 = $($total)*;
217-
let mut is_pending = false;
218-
let mut to_run = COUNT;
219-
220-
// no loop since we don't skip the first time through
221-
$(
222173
// Extract the future for this branch from the tuple.
223174
let ( $($skip,)* fut, .. ) = &mut *futures;
224175

@@ -230,39 +181,43 @@ doc! {macro_rules! join {
230181
if fut.poll(cx).is_pending() {
231182
is_pending = true;
232183
}
233-
)*
234-
235-
if is_pending {
236-
Pending
237184
} else {
238-
Ready(($({
239-
// Extract the future for this branch from the tuple.
240-
let ( $($skip,)* fut, .. ) = &mut futures;
185+
// Future skipped, one less future to skip in the next iteration
186+
skip -= 1;
187+
}
188+
)*
189+
}
190+
191+
if is_pending {
192+
Pending
193+
} else {
194+
Ready(($({
195+
// Extract the future for this branch from the tuple.
196+
let ( $($skip,)* fut, .. ) = &mut futures;
241197

242-
// Safety: future is stored on the stack above
243-
// and never moved.
244-
let mut fut = unsafe { Pin::new_unchecked(fut) };
198+
// Safety: future is stored on the stack above
199+
// and never moved.
200+
let mut fut = unsafe { Pin::new_unchecked(fut) };
245201

246-
fut.take_output().expect("expected completed future")
247-
},)*))
248-
}
249-
}).await
250-
}
202+
fut.take_output().expect("expected completed future")
203+
},)*))
204+
}
205+
}).await
251206
}};
252207

253208
// ===== Normalize =====
254209

255-
(@ { rotate_poll_order=$rotate_poll_order:literal; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
256-
$crate::join!(@{ rotate_poll_order=$rotate_poll_order; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
210+
(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
211+
$crate::join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
257212
};
258213

259214
// ===== Entry point =====
260215
( biased; $($e:expr),+ $(,)?) => {
261-
$crate::join!(@{ rotate_poll_order=false; () (0) } $($e,)*)
216+
$crate::join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
262217
};
263218

264219
( $($e:expr),+ $(,)?) => {
265-
$crate::join!(@{ rotate_poll_order=true; () (0) } $($e,)*)
220+
$crate::join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
266221
};
267222

268223
(biased;) => { async {}.await };

tokio/src/macros/support.rs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,34 @@ cfg_macros! {
2828
pub use std::future::{Future, IntoFuture};
2929
pub use std::pin::Pin;
3030
pub use std::task::{Context, Poll};
31+
32+
#[doc(hidden)]
33+
#[derive(Default, Debug)]
34+
pub struct Rotator<const COUNT: u32> {
35+
next: u32,
36+
}
37+
38+
impl<const COUNT: u32> Rotator<COUNT> {
39+
#[doc(hidden)]
40+
#[inline]
41+
pub fn num_skip(&mut self) -> u32 {
42+
let num_skip = self.next;
43+
self.next += 1;
44+
if self.next == COUNT {
45+
self.next = 0;
46+
}
47+
num_skip
48+
}
49+
}
50+
51+
#[doc(hidden)]
52+
#[derive(Default, Debug)]
53+
pub struct BiasedRotator {}
54+
55+
impl BiasedRotator {
56+
#[doc(hidden)]
57+
#[inline]
58+
pub fn num_skip(&mut self) -> u32 {
59+
0
60+
}
61+
}

tokio/src/macros/try_join.rs

Lines changed: 38 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -164,9 +164,9 @@ doc! {macro_rules! try_join {
164164
#[cfg(not(doc))]
165165
doc! {macro_rules! try_join {
166166
(@ {
167-
// Whether to rotate which future is polled first every poll,
168-
// by incrementing a skip counter
169-
rotate_poll_order=$rotate_poll_order:literal;
167+
// Type of rotator that controls which inner future to start with
168+
// when polling our output future.
169+
rotator=$rotator:ty;
170170

171171
// One `_` for each branch in the `try_join!` macro. This is not used once
172172
// normalization is complete.
@@ -196,102 +196,52 @@ doc! {macro_rules! try_join {
196196
// <https://internals.rust-lang.org/t/surprising-soundness-trouble-around-pollfn/17484>
197197
let mut futures = &mut futures;
198198

199+
const COUNT: u32 = $($total)*;
200+
199201
// Each time the future created by poll_fn is polled,
200202
// if not running in biased mode,
201203
// a different future will be polled first
202204
// to ensure every future passed to join! gets a chance to make progress even if
203205
// one of the futures consumes the whole budget.
204-
//
205-
// This is number of futures that will be skipped in the first loop
206-
// iteration the next time.
207-
//
208-
// If running in biased mode, this variable will be optimized out since we don't pass
209-
// it to the poll_fn.
210-
let mut skip_next_time: u32 = 0;
206+
let mut rotator = <$rotator>::default();
211207

212-
match $rotate_poll_order {
213-
true => poll_fn(move |cx| {
214-
const COUNT: u32 = $($total)*;
215-
let mut is_pending = false;
216-
let mut to_run = COUNT;
208+
poll_fn(move |cx| {
209+
let mut is_pending = false;
210+
let mut to_run = COUNT;
217211

218-
// The number of futures that will be skipped in the first loop iteration.
219-
let mut skip = skip_next_time;
220-
// Upkeep for next poll, rotate first polled future
221-
skip_next_time = if skip + 1 == COUNT { 0 } else { skip + 1 };
212+
// The number of futures that will be skipped in the first loop iteration.
213+
let mut skip = rotator.num_skip();
222214

223-
// This loop runs twice and the first `skip` futures
224-
// are not polled in the first iteration.
225-
loop {
226-
$(
227-
if skip == 0 {
228-
if to_run == 0 {
229-
// Every future has been polled
230-
break;
231-
}
232-
to_run -= 1;
215+
// This loop runs twice and the first `skip` futures
216+
// are not polled in the first iteration.
217+
loop {
218+
$(
219+
if skip == 0 {
220+
if to_run == 0 {
221+
// Every future has been polled
222+
break;
223+
}
224+
to_run -= 1;
233225

234-
// Extract the future for this branch from the tuple.
235-
let ( $($skip,)* fut, .. ) = &mut *futures;
226+
// Extract the future for this branch from the tuple.
227+
let ( $($skip,)* fut, .. ) = &mut *futures;
236228

237-
// Safety: future is stored on the stack above
238-
// and never moved.
239-
let mut fut = unsafe { Pin::new_unchecked(fut) };
229+
// Safety: future is stored on the stack above
230+
// and never moved.
231+
let mut fut = unsafe { Pin::new_unchecked(fut) };
240232

241-
// Try polling
242-
if fut.as_mut().poll(cx).is_pending() {
243-
is_pending = true;
244-
} else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
245-
return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
246-
}
247-
} else {
248-
// Future skipped, one less future to skip in the next iteration
249-
skip -= 1;
233+
// Try polling
234+
if fut.as_mut().poll(cx).is_pending() {
235+
is_pending = true;
236+
} else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
237+
return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
250238
}
251-
)*
252-
}
253-
254-
if is_pending {
255-
Pending
256239
} else {
257-
Ready(Ok(($({
258-
// Extract the future for this branch from the tuple.
259-
let ( $($skip,)* fut, .. ) = &mut futures;
260-
261-
// Safety: future is stored on the stack above
262-
// and never moved.
263-
let mut fut = unsafe { Pin::new_unchecked(fut) };
264-
265-
fut
266-
.take_output()
267-
.expect("expected completed future")
268-
.ok()
269-
.expect("expected Ok(_)")
270-
},)*)))
271-
}
272-
}).await,
273-
// don't rotate the poll order so no skipping
274-
false => poll_fn(move |cx| {
275-
const COUNT: u32 = $($total)*;
276-
let mut is_pending = false;
277-
let mut to_run = COUNT;
278-
279-
// no loop since we don't skip the first time through
280-
$(
281-
// Extract the future for this branch from the tuple.
282-
let ( $($skip,)* fut, .. ) = &mut *futures;
283-
284-
// Safety: future is stored on the stack above
285-
// and never moved.
286-
let mut fut = unsafe { Pin::new_unchecked(fut) };
287-
288-
// Try polling
289-
if fut.as_mut().poll(cx).is_pending() {
290-
is_pending = true;
291-
} else if fut.as_mut().output_mut().expect("expected completed future").is_err() {
292-
return Ready(Err(fut.take_output().expect("expected completed future").err().unwrap()))
240+
// Future skipped, one less future to skip in the next iteration
241+
skip -= 1;
293242
}
294243
)*
244+
}
295245

296246
if is_pending {
297247
Pending
@@ -312,22 +262,21 @@ doc! {macro_rules! try_join {
312262
},)*)))
313263
}
314264
}).await
315-
}
316265
}};
317266

318267
// ===== Normalize =====
319268

320-
(@ { rotate_poll_order=$rotate_poll_order:literal; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
321-
$crate::try_join!(@{ rotate_poll_order=$rotate_poll_order; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
269+
(@ { rotator=$rotator:ty; ( $($s:tt)* ) ( $($n:tt)* ) $($t:tt)* } $e:expr, $($r:tt)* ) => {
270+
$crate::try_join!(@{ rotator=$rotator; ($($s)* _) ($($n)* + 1) $($t)* ($($s)*) $e, } $($r)*)
322271
};
323272

324273
// ===== Entry point =====
325274
( biased; $($e:expr),+ $(,)?) => {
326-
$crate::try_join!(@{ rotate_poll_order=false; () (0) } $($e,)*)
275+
$crate::try_join!(@{ rotator=$crate::macros::support::BiasedRotator; () (0) } $($e,)*)
327276
};
328277

329278
( $($e:expr),+ $(,)?) => {
330-
$crate::try_join!(@{ rotate_poll_order=true; () (0) } $($e,)*)
279+
$crate::try_join!(@{ rotator=$crate::macros::support::Rotator<COUNT>; () (0) } $($e,)*)
331280
};
332281

333282
(biased;) => { async { Ok(()) }.await };

0 commit comments

Comments
 (0)