Skip to content

Commit cc522cc

Browse files
committed
Desugar for await loops
1 parent 2adf5fa commit cc522cc

File tree

8 files changed

+146
-30
lines changed

8 files changed

+146
-30
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+100-28
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337337
),
338338
ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),
339339

340-
ExprKind::Paren(_) | ExprKind::ForLoop{..} => {
340+
ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
341341
unreachable!("already handled")
342342
}
343343

@@ -871,6 +871,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
871871
/// }
872872
/// ```
873873
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
874+
let expr = self.arena.alloc(self.lower_expr_mut(expr));
875+
self.make_lowered_await(await_kw_span, expr, FutureKind::Future)
876+
}
877+
878+
/// Takes an expr that has already been lowered and generates a desugared await loop around it
879+
fn make_lowered_await(
880+
&mut self,
881+
await_kw_span: Span,
882+
expr: &'hir hir::Expr<'hir>,
883+
await_kind: FutureKind,
884+
) -> hir::ExprKind<'hir> {
874885
let full_span = expr.span.to(await_kw_span);
875886

876887
let is_async_gen = match self.coroutine_kind {
@@ -884,13 +895,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
884895
}
885896
};
886897

887-
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, None);
898+
let features = match await_kind {
899+
FutureKind::Future => None,
900+
FutureKind::AsyncIterator => Some(self.allow_for_await.clone()),
901+
};
902+
let span = self.mark_span_with_reason(DesugaringKind::Await, await_kw_span, features);
888903
let gen_future_span = self.mark_span_with_reason(
889904
DesugaringKind::Await,
890905
full_span,
891906
Some(self.allow_gen_future.clone()),
892907
);
893-
let expr = self.lower_expr_mut(expr);
894908
let expr_hir_id = expr.hir_id;
895909

896910
// Note that the name of this binding must not be changed to something else because
@@ -930,11 +944,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
930944
hir::LangItem::GetContext,
931945
arena_vec![self; task_context],
932946
);
933-
let call = self.expr_call_lang_item_fn(
934-
span,
935-
hir::LangItem::FuturePoll,
936-
arena_vec![self; new_unchecked, get_context],
937-
);
947+
let call = match await_kind {
948+
FutureKind::Future => self.expr_call_lang_item_fn(
949+
span,
950+
hir::LangItem::FuturePoll,
951+
arena_vec![self; new_unchecked, get_context],
952+
),
953+
FutureKind::AsyncIterator => self.expr_call_lang_item_fn(
954+
span,
955+
hir::LangItem::AsyncIteratorPollNext,
956+
arena_vec![self; new_unchecked, get_context],
957+
),
958+
};
938959
self.arena.alloc(self.expr_unsafe(call))
939960
};
940961

@@ -1018,11 +1039,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
10181039
let awaitee_arm = self.arm(awaitee_pat, loop_expr);
10191040

10201041
// `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1021-
let into_future_expr = self.expr_call_lang_item_fn(
1022-
span,
1023-
hir::LangItem::IntoFutureIntoFuture,
1024-
arena_vec![self; expr],
1025-
);
1042+
let into_future_expr = match await_kind {
1043+
FutureKind::Future => self.expr_call_lang_item_fn(
1044+
span,
1045+
hir::LangItem::IntoFutureIntoFuture,
1046+
arena_vec![self; *expr],
1047+
),
1048+
// Not needed for `for await` because we expect to have already called
1049+
// `IntoAsyncIterator::into_async_iter` on it.
1050+
FutureKind::AsyncIterator => expr,
1051+
};
10261052

10271053
// match <into_future_expr> {
10281054
// mut __awaitee => loop { .. }
@@ -1670,7 +1696,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
16701696
head: &Expr,
16711697
body: &Block,
16721698
opt_label: Option<Label>,
1673-
_loop_kind: ForLoopKind,
1699+
loop_kind: ForLoopKind,
16741700
) -> hir::Expr<'hir> {
16751701
let head = self.lower_expr_mut(head);
16761702
let pat = self.lower_pat(pat);
@@ -1699,17 +1725,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
16991725
let (iter_pat, iter_pat_nid) =
17001726
self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);
17011727

1702-
// `match Iterator::next(&mut iter) { ... }`
17031728
let match_expr = {
17041729
let iter = self.expr_ident(head_span, iter, iter_pat_nid);
1705-
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1706-
let next_expr = self.expr_call_lang_item_fn(
1707-
head_span,
1708-
hir::LangItem::IteratorNext,
1709-
arena_vec![self; ref_mut_iter],
1710-
);
1730+
let next_expr = match loop_kind {
1731+
ForLoopKind::For => {
1732+
// `Iterator::next(&mut iter)`
1733+
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1734+
self.expr_call_lang_item_fn(
1735+
head_span,
1736+
hir::LangItem::IteratorNext,
1737+
arena_vec![self; ref_mut_iter],
1738+
)
1739+
}
1740+
ForLoopKind::ForAwait => {
1741+
// we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1742+
// to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1743+
// calls to `poll_next`. In user code, this would probably be a call to
1744+
// `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1745+
1746+
// `&mut iter`
1747+
let iter = self.expr_mut_addr_of(head_span, iter);
1748+
// `Pin::new_unchecked(...)`
1749+
let iter = self.arena.alloc(self.expr_call_lang_item_fn_mut(
1750+
head_span,
1751+
hir::LangItem::PinNewUnchecked,
1752+
arena_vec![self; iter],
1753+
));
1754+
// `unsafe { ... }`
1755+
let iter = self.arena.alloc(self.expr_unsafe(iter));
1756+
let kind = self.make_lowered_await(head_span, iter, FutureKind::AsyncIterator);
1757+
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
1758+
}
1759+
};
17111760
let arms = arena_vec![self; none_arm, some_arm];
17121761

1762+
// `match $next_expr { ... }`
17131763
self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
17141764
};
17151765
let match_stmt = self.stmt_expr(for_span, match_expr);
@@ -1729,13 +1779,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
17291779
// `mut iter => { ... }`
17301780
let iter_arm = self.arm(iter_pat, loop_expr);
17311781

1732-
// `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1733-
let into_iter_expr = {
1734-
self.expr_call_lang_item_fn(
1735-
head_span,
1736-
hir::LangItem::IntoIterIntoIter,
1737-
arena_vec![self; head],
1738-
)
1782+
let into_iter_expr = match loop_kind {
1783+
ForLoopKind::For => {
1784+
// `::std::iter::IntoIterator::into_iter(<head>)`
1785+
self.expr_call_lang_item_fn(
1786+
head_span,
1787+
hir::LangItem::IntoIterIntoIter,
1788+
arena_vec![self; head],
1789+
)
1790+
}
1791+
ForLoopKind::ForAwait => {
1792+
// `::core::async_iter::IntoAsyncIterator::into_async_iter(<head>)`
1793+
let iter = self.expr_call_lang_item_fn(
1794+
head_span,
1795+
hir::LangItem::IntoAsyncIterIntoIter,
1796+
arena_vec![self; head],
1797+
);
1798+
self.arena.alloc(self.expr_mut_addr_of(head_span, iter))
1799+
}
17391800
};
17401801

17411802
let match_expr = self.arena.alloc(self.expr_match(
@@ -2138,3 +2199,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
21382199
}
21392200
}
21402201
}
2202+
2203+
/// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2204+
/// of future we are awaiting.
2205+
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
2206+
enum FutureKind {
2207+
/// We are awaiting a normal future
2208+
Future,
2209+
/// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2210+
/// a `for await` loop)
2211+
AsyncIterator,
2212+
}

compiler/rustc_ast_lowering/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ struct LoweringContext<'a, 'hir> {
132132
allow_try_trait: Lrc<[Symbol]>,
133133
allow_gen_future: Lrc<[Symbol]>,
134134
allow_async_iterator: Lrc<[Symbol]>,
135+
allow_for_await: Lrc<[Symbol]>,
135136

136137
/// Mapping from generics `def_id`s to TAIT generics `def_id`s.
137138
/// For each captured lifetime (e.g., 'a), we create a new lifetime parameter that is a generic
@@ -176,6 +177,7 @@ impl<'a, 'hir> LoweringContext<'a, 'hir> {
176177
} else {
177178
[sym::gen_future].into()
178179
},
180+
allow_for_await: [sym::async_iterator].into(),
179181
// FIXME(gen_blocks): how does `closure_track_caller`/`async_fn_track_caller`
180182
// interact with `gen`/`async gen` blocks
181183
allow_async_iterator: [sym::gen_future, sym::async_iterator].into(),

compiler/rustc_builtin_macros/src/assert/context.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
303303
| ExprKind::Continue(_)
304304
| ExprKind::Err
305305
| ExprKind::Field(_, _)
306-
| ExprKind::ForLoop {..}
306+
| ExprKind::ForLoop { .. }
307307
| ExprKind::FormatArgs(_)
308308
| ExprKind::IncludedBytes(..)
309309
| ExprKind::InlineAsm(_)

compiler/rustc_feature/src/unstable.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ declare_features! (
358358
/// Allows `#[track_caller]` on async functions.
359359
(unstable, async_fn_track_caller, "1.73.0", Some(110011)),
360360
/// Allows `for await` loops.
361-
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", None),
361+
(unstable, async_for_loop, "CURRENT_RUSTC_VERSION", Some(118898)),
362362
/// Allows builtin # foo() syntax
363363
(unstable, builtin_syntax, "1.71.0", Some(110680)),
364364
/// Treat `extern "C"` function as nounwind.

compiler/rustc_hir/src/lang_items.rs

+3
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,9 @@ language_item_table! {
307307
Context, sym::Context, context, Target::Struct, GenericRequirement::None;
308308
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
309309

310+
AsyncIteratorPollNext, sym::async_iterator_poll_next, async_iterator_poll_next, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(1);
311+
IntoAsyncIterIntoIter, sym::into_async_iter_into_iter, into_async_iter_into_iter, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(1);
312+
310313
Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
311314
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;
312315
OptionNone, sym::None, option_none_variant, Target::Variant, GenericRequirement::None;

compiler/rustc_span/src/symbol.rs

+3
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,9 @@ symbols! {
426426
async_closure,
427427
async_fn_in_trait,
428428
async_fn_track_caller,
429+
async_for_loop,
429430
async_iterator,
431+
async_iterator_poll_next,
430432
atomic,
431433
atomic_mod,
432434
atomics,
@@ -893,6 +895,7 @@ symbols! {
893895
instruction_set,
894896
integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below
895897
integral,
898+
into_async_iter_into_iter,
896899
into_future,
897900
into_iter,
898901
intra_doc_pointers,

library/core/src/async_iter/async_iter.rs

+2
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub trait AsyncIterator {
4747
/// Rust's usual rules apply: calls must never cause undefined behavior
4848
/// (memory corruption, incorrect use of `unsafe` functions, or the like),
4949
/// regardless of the async iterator's state.
50+
#[cfg_attr(not(bootstrap), lang = "async_iterator_poll_next")]
5051
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>;
5152

5253
/// Returns the bounds on the remaining length of the async iterator.
@@ -144,6 +145,7 @@ pub trait IntoAsyncIterator {
144145
type IntoAsyncIter: AsyncIterator<Item = Self::Item>;
145146

146147
/// Converts `self` into an async iterator
148+
#[cfg_attr(not(bootstrap), lang = "into_async_iter_into_iter")]
147149
fn into_async_iter(self) -> Self::IntoAsyncIter;
148150
}
149151

tests/ui/async-await/for-await.rs

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// run-pass
2+
// edition: 2021
3+
#![feature(async_iterator, async_iter_from_iter, const_waker, async_for_loop)]
4+
5+
use std::future::Future;
6+
7+
// make sure a simple for await loop works
8+
async fn real_main() {
9+
let iter = core::async_iter::from_iter(0..3);
10+
let mut count = 0;
11+
for await i in iter {
12+
assert_eq!(i, count);
13+
count += 1;
14+
}
15+
assert_eq!(count, 3);
16+
}
17+
18+
fn main() {
19+
let future = real_main();
20+
let waker = noop_waker::NOOP_WAKER;
21+
let mut cx = &mut core::task::Context::from_waker(&waker);
22+
let mut future = core::pin::pin!(future);
23+
while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {}
24+
}
25+
26+
mod noop_waker {
27+
use std::task::{RawWaker, RawWakerVTable, Waker};
28+
29+
const VTABLE: RawWakerVTable =
30+
RawWakerVTable::new(|_| RawWaker::new(core::ptr::null(), &VTABLE), |_| (), |_| (), |_| ());
31+
32+
pub(super) const NOOP_WAKER: Waker =
33+
unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &VTABLE)) };
34+
}

0 commit comments

Comments
 (0)