|
| 1 | +use super::Zip as ZipTrait; |
| 2 | +use crate::stream::IntoStream; |
| 3 | +use crate::utils::{self, PollState, WakerList}; |
1 | 4 |
|
| 5 | +use core::array; |
| 6 | +use core::fmt; |
| 7 | +use core::mem::MaybeUninit; |
| 8 | +use core::pin::Pin; |
| 9 | +use core::task::{Context, Poll}; |
| 10 | +use std::mem; |
| 11 | + |
| 12 | +use futures_core::Stream; |
| 13 | +use pin_project::{pin_project, pinned_drop}; |
| 14 | + |
| 15 | +/// ‘Zips up’ two streams into a single stream of pairs. |
| 16 | +/// |
| 17 | +/// This `struct` is created by the [`merge`] method on the [`Zip`] trait. See its |
| 18 | +/// documentation for more. |
| 19 | +/// |
| 20 | +/// [`zip`]: trait.Zip.html#method.zip |
| 21 | +/// [`Zip`]: trait.Zip.html |
| 22 | +#[pin_project(PinnedDrop)] |
| 23 | +pub struct Zip<S, const N: usize> |
| 24 | +where |
| 25 | + S: Stream, |
| 26 | +{ |
| 27 | + #[pin] |
| 28 | + streams: [S; N], |
| 29 | + output: [MaybeUninit<<S as Stream>::Item>; N], |
| 30 | + wakers: WakerList, |
| 31 | + state: [PollState; N], |
| 32 | + done: bool, |
| 33 | +} |
| 34 | + |
| 35 | +impl<S, const N: usize> Zip<S, N> |
| 36 | +where |
| 37 | + S: Stream, |
| 38 | +{ |
| 39 | + pub(crate) fn new(streams: [S; N]) -> Self { |
| 40 | + Self { |
| 41 | + streams, |
| 42 | + output: array::from_fn(|_| MaybeUninit::uninit()), |
| 43 | + state: array::from_fn(|_| PollState::default()), |
| 44 | + wakers: WakerList::new(N), |
| 45 | + done: false, |
| 46 | + } |
| 47 | + } |
| 48 | +} |
| 49 | + |
| 50 | +impl<S, const N: usize> fmt::Debug for Zip<S, N> |
| 51 | +where |
| 52 | + S: Stream + fmt::Debug, |
| 53 | +{ |
| 54 | + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |
| 55 | + f.debug_list().entries(self.streams.iter()).finish() |
| 56 | + } |
| 57 | +} |
| 58 | + |
| 59 | +impl<S, const N: usize> Stream for Zip<S, N> |
| 60 | +where |
| 61 | + S: Stream, |
| 62 | +{ |
| 63 | + type Item = [S::Item; N]; |
| 64 | + |
| 65 | + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { |
| 66 | + let mut this = self.project(); |
| 67 | + |
| 68 | + assert!(!*this.done, "Stream should not be polled after completion"); |
| 69 | + |
| 70 | + let mut readiness = this.wakers.readiness().lock().unwrap(); |
| 71 | + readiness.set_waker(cx.waker()); |
| 72 | + for index in 0..N { |
| 73 | + if !readiness.any_ready() { |
| 74 | + // Nothing is ready yet |
| 75 | + return Poll::Pending; |
| 76 | + } else if this.state[index].is_done() { |
| 77 | + // We already have data stored for this stream |
| 78 | + continue; |
| 79 | + } else if !readiness.clear_ready(index) { |
| 80 | + // This waker isn't ready yet |
| 81 | + continue; |
| 82 | + } |
| 83 | + |
| 84 | + // unlock readiness so we don't deadlock when polling |
| 85 | + drop(readiness); |
| 86 | + |
| 87 | + // Obtain the intermediate waker. |
| 88 | + let mut cx = Context::from_waker(this.wakers.get(index).unwrap()); |
| 89 | + |
| 90 | + let stream = utils::get_pin_mut(this.streams.as_mut(), index).unwrap(); |
| 91 | + match stream.poll_next(&mut cx) { |
| 92 | + Poll::Ready(Some(item)) => { |
| 93 | + this.output[index] = MaybeUninit::new(item); |
| 94 | + this.state[index] = PollState::Done; |
| 95 | + |
| 96 | + let all_ready = this.state.iter().all(|state| state.is_done()); |
| 97 | + if all_ready { |
| 98 | + // Reset the future's state. |
| 99 | + readiness = this.wakers.readiness().lock().unwrap(); |
| 100 | + readiness.set_all_ready(); |
| 101 | + this.state.fill(PollState::Pending); |
| 102 | + |
| 103 | + // Take the output |
| 104 | + // |
| 105 | + // SAFETY: we just validated all our data is populated, meaning |
| 106 | + // we can assume this is initialized. |
| 107 | + let mut output = array::from_fn(|_| MaybeUninit::uninit()); |
| 108 | + mem::swap(this.output, &mut output); |
| 109 | + let output = unsafe { array_assume_init(output) }; |
| 110 | + return Poll::Ready(Some(output)); |
| 111 | + } |
| 112 | + } |
| 113 | + Poll::Ready(None) => { |
| 114 | + // If one stream returns `None`, we can no longer return |
| 115 | + // pairs - meaning the stream is over. |
| 116 | + *this.done = true; |
| 117 | + return Poll::Ready(None); |
| 118 | + } |
| 119 | + Poll::Pending => {} |
| 120 | + } |
| 121 | + |
| 122 | + // Lock readiness so we can use it again |
| 123 | + readiness = this.wakers.readiness().lock().unwrap(); |
| 124 | + } |
| 125 | + Poll::Pending |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +/// Drop the already initialized values on cancellation. |
| 130 | +#[pinned_drop] |
| 131 | +impl<S, const N: usize> PinnedDrop for Zip<S, N> |
| 132 | +where |
| 133 | + S: Stream, |
| 134 | +{ |
| 135 | + fn drop(self: Pin<&mut Self>) { |
| 136 | + let this = self.project(); |
| 137 | + |
| 138 | + for (state, output) in this.state.iter_mut().zip(this.output.iter_mut()) { |
| 139 | + if state.is_done() { |
| 140 | + // SAFETY: we've just filtered down to *only* the initialized values. |
| 141 | + // We can assume they're initialized, and this is where we drop them. |
| 142 | + unsafe { output.assume_init_drop() }; |
| 143 | + } |
| 144 | + } |
| 145 | + } |
| 146 | +} |
| 147 | + |
| 148 | +impl<S, const N: usize> ZipTrait for [S; N] |
| 149 | +where |
| 150 | + S: IntoStream, |
| 151 | +{ |
| 152 | + type Item = <Zip<S::IntoStream, N> as Stream>::Item; |
| 153 | + type Stream = Zip<S::IntoStream, N>; |
| 154 | + |
| 155 | + fn zip(self) -> Self::Stream { |
| 156 | + Zip::new(self.map(|i| i.into_stream())) |
| 157 | + } |
| 158 | +} |
| 159 | + |
| 160 | +#[cfg(test)] |
| 161 | +mod tests { |
| 162 | + use crate::stream::Zip; |
| 163 | + use futures_lite::future::block_on; |
| 164 | + use futures_lite::prelude::*; |
| 165 | + use futures_lite::stream; |
| 166 | + |
| 167 | + #[test] |
| 168 | + fn zip_array_3() { |
| 169 | + block_on(async { |
| 170 | + let a = stream::repeat(1).take(2); |
| 171 | + let b = stream::repeat(2).take(2); |
| 172 | + let c = stream::repeat(3).take(2); |
| 173 | + let mut s = Zip::zip([a, b, c]); |
| 174 | + |
| 175 | + assert_eq!(s.next().await, Some([1, 2, 3])); |
| 176 | + assert_eq!(s.next().await, Some([1, 2, 3])); |
| 177 | + assert_eq!(s.next().await, None); |
| 178 | + }) |
| 179 | + } |
| 180 | +} |
| 181 | + |
| 182 | +// Inlined version of the unstable `MaybeUninit::array_assume_init` feature. |
| 183 | +// FIXME: replace with `utils::array_assume_init` |
| 184 | +unsafe fn array_assume_init<T, const N: usize>(array: [MaybeUninit<T>; N]) -> [T; N] { |
| 185 | + // SAFETY: |
| 186 | + // * The caller guarantees that all elements of the array are initialized |
| 187 | + // * `MaybeUninit<T>` and T are guaranteed to have the same layout |
| 188 | + // * `MaybeUninit` does not drop, so there are no double-frees |
| 189 | + // And thus the conversion is safe |
| 190 | + let ret = unsafe { (&array as *const _ as *const [T; N]).read() }; |
| 191 | + mem::forget(array); |
| 192 | + ret |
| 193 | +} |
0 commit comments