Skip to content

Commit f466e61

Browse files
MajorBreakfastcramertj
authored andcommitted
Improvements to Shared
1 parent 4fd55ab commit f466e61

File tree

2 files changed

+81
-87
lines changed

2 files changed

+81
-87
lines changed

futures-util/src/future/mod.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -586,8 +586,9 @@ pub trait FutureExt: Future {
586586
/// Create a cloneable handle to this future where all handles will resolve
587587
/// to the same result.
588588
///
589-
/// The shared() method provides a method to convert any future into a
590-
/// cloneable future. It enables a future to be polled by multiple threads.
589+
/// The `shared` combinator method provides a method to convert any future
590+
/// into a cloneable future. It enables a future to be polled by multiple
591+
/// threads.
591592
///
592593
/// This method is only available when the `std` feature of this
593594
/// library is activated, and it is activated by default.

futures-util/src/future/shared.rs

Lines changed: 78 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1+
use crate::task::local_waker_ref_from_nonlocal;
12
use futures_core::future::Future;
23
use futures_core::task::{LocalWaker, Poll, Wake, Waker};
34
use slab::Slab;
4-
use std::fmt;
55
use std::cell::UnsafeCell;
6+
use std::fmt;
67
use std::marker::Unpin;
78
use std::pin::Pin;
8-
use std::sync::{Arc, Mutex};
99
use std::sync::atomic::AtomicUsize;
1010
use std::sync::atomic::Ordering::SeqCst;
11-
use std::task::local_waker_from_nonlocal;
11+
use std::sync::{Arc, Mutex};
1212

1313
/// A future that is cloneable and can be polled in multiple threads.
14-
/// Use [`FutureExt::shared()`](crate::FutureExt::shared) method to convert any future into a `Shared` future.
14+
/// Use the [`shared`](crate::FutureExt::shared) combinator method to convert
15+
/// any future into a `Shared` future.
1516
#[must_use = "futures do nothing unless polled"]
1617
pub struct Shared<Fut: Future> {
1718
inner: Option<Arc<Inner<Fut>>>,
@@ -43,8 +44,7 @@ impl<Fut: Future> fmt::Debug for Shared<Fut> {
4344

4445
impl<Fut: Future> fmt::Debug for Inner<Fut> {
4546
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
46-
fmt.debug_struct("Inner")
47-
.finish()
47+
fmt.debug_struct("Inner").finish()
4848
}
4949
}
5050

@@ -75,14 +75,16 @@ const NULL_WAKER_KEY: usize = usize::max_value();
7575

7676
impl<Fut: Future> Shared<Fut> {
7777
pub(super) fn new(future: Fut) -> Shared<Fut> {
78+
let inner = Inner {
79+
future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
80+
notifier: Arc::new(Notifier {
81+
state: AtomicUsize::new(IDLE),
82+
wakers: Mutex::new(Some(Slab::new())),
83+
}),
84+
};
85+
7886
Shared {
79-
inner: Some(Arc::new(Inner {
80-
future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
81-
notifier: Arc::new(Notifier {
82-
state: AtomicUsize::new(IDLE),
83-
wakers: Mutex::new(Some(Slab::new())),
84-
}),
85-
})),
87+
inner: Some(Arc::new(inner)),
8688
waker_key: NULL_WAKER_KEY,
8789
}
8890
}
@@ -93,36 +95,37 @@ where
9395
Fut: Future,
9496
Fut::Output: Clone,
9597
{
96-
/// Returns Some containing a reference to this [`Shared`](crate::future::Shared)'s output if it has
97-
/// already been computed by a clone or [`None`](std::option::Option::None) if it hasn't been computed yet
98-
/// or if this [`Shared`](std::option::Option::Some) already returned its output from poll.
98+
/// Returns [`Some`] containing a reference to this [`Shared`]'s output if
99+
/// it has already been computed by a clone or [`None`] if it hasn't been
100+
/// computed yet or this [`Shared`] already returned its output from
101+
/// [`poll`](Future::poll).
99102
pub fn peek(&self) -> Option<&Fut::Output> {
100-
match self.inner.as_ref().map(|inner| inner.notifier.state.load(SeqCst)) {
101-
Some(COMPLETE) => unsafe { Some(self.inner().get_output()) },
102-
Some(POISONED) => panic!("inner future panicked during poll"),
103-
_ => None,
103+
if let Some(inner) = self.inner.as_ref() {
104+
match inner.notifier.state.load(SeqCst) {
105+
COMPLETE => unsafe { return Some(inner.output()) },
106+
POISONED => panic!("inner future panicked during poll"),
107+
_ => {}
108+
}
104109
}
105-
}
106-
107-
fn inner(&self) -> &Arc<Inner<Fut>> {
108-
Self::inner_(&self.inner)
109-
}
110-
111-
fn inner_(inner: &Option<Arc<Inner<Fut>>>) -> &Arc<Inner<Fut>> {
112-
inner.as_ref().expect("Shared future polled again after completion")
110+
None
113111
}
114112

115113
/// Registers the current task to receive a wakeup when `Inner` is awoken.
116114
fn set_waker(&mut self, lw: &LocalWaker) {
117115
// Acquire the lock first before checking COMPLETE to ensure there
118116
// isn't a race.
119-
let mut wakers = Self::inner_(&self.inner).notifier.wakers.lock().unwrap();
120-
let wakers = if let Some(wakers) = wakers.as_mut() {
117+
let mut wakers_guard = if let Some(inner) = self.inner.as_ref() {
118+
inner.notifier.wakers.lock().unwrap()
119+
} else {
120+
return;
121+
};
122+
123+
let wakers = if let Some(wakers) = wakers_guard.as_mut() {
121124
wakers
122125
} else {
123-
// The value is already available, so there's no need to set the waker.
124-
return
126+
return;
125127
};
128+
126129
if self.waker_key == NULL_WAKER_KEY {
127130
self.waker_key = wakers.insert(Some(lw.clone().into_waker()));
128131
} else {
@@ -143,29 +146,27 @@ where
143146

144147
/// Safety: callers must first ensure that `self.inner.state`
145148
/// is `COMPLETE`
146-
unsafe fn take_or_clone_output(inner: Arc<Inner<Fut>>) -> Fut::Output {
149+
unsafe fn take_or_clone_output(&mut self) -> Fut::Output {
150+
let inner = self.inner.take().unwrap();
151+
147152
match Arc::try_unwrap(inner) {
148-
Ok(inner) => {
149-
match inner.future_or_output.into_inner() {
150-
FutureOrOutput::Output(item) => item,
151-
FutureOrOutput::Future(_) => unreachable!(),
152-
}
153-
}
154-
Err(inner) => inner.get_output().clone(),
153+
Ok(inner) => match inner.future_or_output.into_inner() {
154+
FutureOrOutput::Output(item) => item,
155+
FutureOrOutput::Future(_) => unreachable!(),
156+
},
157+
Err(inner) => inner.output().clone(),
155158
}
156159
}
157-
158160
}
159161

160-
161162
impl<Fut> Inner<Fut>
162163
where
163164
Fut: Future,
164165
Fut::Output: Clone,
165166
{
166167
/// Safety: callers must first ensure that `self.inner.state`
167168
/// is `COMPLETE`
168-
unsafe fn get_output(&self) -> &Fut::Output {
169+
unsafe fn output(&self) -> &Fut::Output {
169170
match &*self.future_or_output.get() {
170171
FutureOrOutput::Output(ref item) => &item,
171172
FutureOrOutput::Future(_) => unreachable!(),
@@ -182,30 +183,33 @@ where
182183
fn poll(mut self: Pin<&mut Self>, lw: &LocalWaker) -> Poll<Self::Output> {
183184
let this = &mut *self;
184185

185-
// Assert that we aren't completed
186-
this.inner();
187-
188186
this.set_waker(lw);
189187

190-
match this.inner().notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
188+
let inner = if let Some(inner) = this.inner.as_ref() {
189+
inner
190+
} else {
191+
panic!("Shared future polled again after completion");
192+
};
193+
194+
match inner.notifier.state.compare_and_swap(IDLE, POLLING, SeqCst) {
191195
IDLE => {
192196
// Lock acquired, fall through
193197
}
194198
POLLING | REPOLL => {
195199
// Another task is currently polling, at this point we just want
196-
// to ensure that our task handle is currently registered
200+
// to ensure that the waker for this task is registered
197201

198202
return Poll::Pending;
199203
}
200204
COMPLETE => {
201-
let inner = self.inner.take().unwrap();
202-
return unsafe { Poll::Ready(Self::take_or_clone_output(inner)) };
205+
// Safety: We're in the COMPLETE state
206+
return unsafe { Poll::Ready(this.take_or_clone_output()) };
203207
}
204208
POISONED => panic!("inner future panicked during poll"),
205209
_ => unreachable!(),
206210
}
207211

208-
let waker = local_waker_from_nonlocal(this.inner().notifier.clone());
212+
let waker = local_waker_ref_from_nonlocal(&inner.notifier);
209213
let lw = &waker;
210214

211215
struct Reset<'a>(&'a AtomicUsize);
@@ -220,36 +224,29 @@ where
220224
}
221225
}
222226

227+
let _reset = Reset(&inner.notifier.state);
223228

224229
let output = loop {
225-
let inner = this.inner();
226-
let _reset = Reset(&inner.notifier.state);
227-
228-
// Poll the future
229-
let res = unsafe {
230-
if let FutureOrOutput::Future(future) =
231-
&mut *inner.future_or_output.get()
232-
{
233-
Pin::new_unchecked(future).poll(lw)
234-
} else {
235-
unreachable!()
230+
let future = unsafe {
231+
match &mut *inner.future_or_output.get() {
232+
FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
233+
_ => unreachable!(),
236234
}
237235
};
238-
match res {
236+
237+
let poll = future.poll(&lw);
238+
239+
match poll {
239240
Poll::Pending => {
240-
// Not ready, try to release the handle
241-
match inner
242-
.notifier
243-
.state
244-
.compare_and_swap(POLLING, IDLE, SeqCst)
245-
{
241+
let state = &inner.notifier.state;
242+
match state.compare_and_swap(POLLING, IDLE, SeqCst) {
246243
POLLING => {
247244
// Success
248245
return Poll::Pending;
249246
}
250247
REPOLL => {
251-
// Gotta poll again!
252-
let prev = inner.notifier.state.swap(POLLING, SeqCst);
248+
// Was woken since: Gotta poll again!
249+
let prev = state.swap(POLLING, SeqCst);
253250
assert_eq!(prev, REPOLL);
254251
}
255252
_ => unreachable!(),
@@ -259,31 +256,27 @@ where
259256
}
260257
};
261258

262-
if Arc::get_mut(this.inner.as_mut().unwrap()).is_some() {
263-
this.inner.take();
264-
return Poll::Ready(output);
265-
}
266-
267-
let inner = this.inner();
268-
269-
let _reset = Reset(&inner.notifier.state);
270-
271259
unsafe {
272260
*inner.future_or_output.get() =
273-
FutureOrOutput::Output(output.clone());
261+
FutureOrOutput::Output(output);
274262
}
275263

276-
// Complete the future
277-
let mut lock = inner.notifier.wakers.lock().unwrap();
278264
inner.notifier.state.store(COMPLETE, SeqCst);
279-
let wakers = &mut lock.take().unwrap();
265+
266+
// Wake all tasks and drop the slab
267+
let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
268+
let wakers = &mut wakers_guard.take().unwrap();
280269
for (_key, opt_waker) in wakers {
281270
if let Some(waker) = opt_waker.take() {
282271
waker.wake();
283272
}
284273
}
285274

286-
Poll::Ready(output)
275+
drop(_reset); // Make borrow checker happy
276+
drop(wakers_guard);
277+
278+
// Safety: We're in the COMPLETE state
279+
unsafe { Poll::Ready(this.take_or_clone_output()) }
287280
}
288281
}
289282

0 commit comments

Comments
 (0)