Skip to content

Commit 6027978

Browse files
committed
Desugar for await loops
1 parent fb7192d commit 6027978

File tree

7 files changed

+87
-15
lines changed

7 files changed

+87
-15
lines changed

compiler/rustc_ast/src/mut_visit.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1388,7 +1388,7 @@ pub fn noop_visit_expr<T: MutVisitor>(
13881388
vis.visit_block(body);
13891389
visit_opt(label, |label| vis.visit_label(label));
13901390
}
1391-
ExprKind::ForLoop{pat, iter, body, label, is_await: _} => {
1391+
ExprKind::ForLoop { pat, iter, body, label, is_await: _ } => {
13921392
vis.visit_pat(pat);
13931393
vis.visit_expr(iter);
13941394
vis.visit_block(body);

compiler/rustc_ast_lowering/src/expr.rs

+48-13
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
331331
),
332332
ExprKind::Try(sub_expr) => self.lower_expr_try(e.span, sub_expr),
333333

334-
ExprKind::Paren(_) | ExprKind::ForLoop{..} => {
334+
ExprKind::Paren(_) | ExprKind::ForLoop { .. } => {
335335
unreachable!("already handled")
336336
}
337337

@@ -784,6 +784,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
784784
/// }
785785
/// ```
786786
fn lower_expr_await(&mut self, await_kw_span: Span, expr: &Expr) -> hir::ExprKind<'hir> {
787+
let expr = self.lower_expr_mut(expr);
788+
self.make_lowered_await(await_kw_span, &expr)
789+
}
790+
791+
/// Takes an expr that has already been lowered and generates a desugared await loop around it
792+
fn make_lowered_await(
793+
&mut self,
794+
await_kw_span: Span,
795+
expr: &hir::Expr<'hir>,
796+
) -> hir::ExprKind<'hir> {
787797
let full_span = expr.span.to(await_kw_span);
788798
match self.coroutine_kind {
789799
Some(hir::CoroutineKind::Async(_)) => {}
@@ -800,7 +810,6 @@ impl<'hir> LoweringContext<'_, 'hir> {
800810
full_span,
801811
Some(self.allow_gen_future.clone()),
802812
);
803-
let expr = self.lower_expr_mut(expr);
804813
let expr_hir_id = expr.hir_id;
805814

806815
// Note that the name of this binding must not be changed to something else because
@@ -924,7 +933,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
924933
let into_future_expr = self.expr_call_lang_item_fn(
925934
span,
926935
hir::LangItem::IntoFutureIntoFuture,
927-
arena_vec![self; expr],
936+
arena_vec![self; *expr],
928937
);
929938

930939
// match <into_future_expr> {
@@ -1554,7 +1563,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
15541563
head: &Expr,
15551564
body: &Block,
15561565
opt_label: Option<Label>,
1557-
_is_await: bool,
1566+
is_await: bool,
15581567
) -> hir::Expr<'hir> {
15591568
let head = self.lower_expr_mut(head);
15601569
let pat = self.lower_pat(pat);
@@ -1583,15 +1592,33 @@ impl<'hir> LoweringContext<'_, 'hir> {
15831592
let (iter_pat, iter_pat_nid) =
15841593
self.pat_ident_binding_mode(head_span, iter, hir::BindingAnnotation::MUT);
15851594

1586-
// `match Iterator::next(&mut iter) { ... }`
15871595
let match_expr = {
15881596
let iter = self.expr_ident(head_span, iter, iter_pat_nid);
1589-
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1590-
let next_expr = self.expr_call_lang_item_fn(
1591-
head_span,
1592-
hir::LangItem::IteratorNext,
1593-
arena_vec![self; ref_mut_iter],
1594-
);
1597+
let next_expr = if is_await {
1598+
// `async_iter_next(unsafe { Pin::new_unchecked(&mut iter) }).await`
1599+
let future = self.expr_mut_addr_of(head_span, iter);
1600+
let future = self.arena.alloc(self.expr_call_lang_item_fn_mut(
1601+
head_span,
1602+
hir::LangItem::PinNewUnchecked,
1603+
arena_vec![self; future],
1604+
));
1605+
let future = self.expr_unsafe(future);
1606+
let future = self.expr_call_lang_item_fn(
1607+
head_span,
1608+
hir::LangItem::AsyncIterNext,
1609+
arena_vec![self; future],
1610+
);
1611+
let kind = self.make_lowered_await(head_span, future);
1612+
self.arena.alloc(hir::Expr { hir_id: self.next_id(), kind, span: head_span })
1613+
} else {
1614+
// `match Iterator::next(&mut iter) { ... }`
1615+
let ref_mut_iter = self.expr_mut_addr_of(head_span, iter);
1616+
self.expr_call_lang_item_fn(
1617+
head_span,
1618+
hir::LangItem::IteratorNext,
1619+
arena_vec![self; ref_mut_iter],
1620+
)
1621+
};
15951622
let arms = arena_vec![self; none_arm, some_arm];
15961623

15971624
self.expr_match(head_span, next_expr, arms, hir::MatchSource::ForLoopDesugar)
@@ -1613,8 +1640,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
16131640
// `mut iter => { ... }`
16141641
let iter_arm = self.arm(iter_pat, loop_expr);
16151642

1616-
// `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1617-
let into_iter_expr = {
1643+
let into_iter_expr = if is_await {
1644+
// `::core::async_iter::IntoAsyncIterator::into_async_iter(<head>)`
1645+
let iter = self.expr_call_lang_item_fn(
1646+
head_span,
1647+
hir::LangItem::IntoAsyncIterIntoIter,
1648+
arena_vec![self; head],
1649+
);
1650+
self.arena.alloc(self.expr_mut_addr_of(head_span, iter))
1651+
} else {
1652+
// `::std::iter::IntoIterator::into_iter(<head>)`
16181653
self.expr_call_lang_item_fn(
16191654
head_span,
16201655
hir::LangItem::IntoIterIntoIter,

compiler/rustc_builtin_macros/src/assert/context.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ impl<'cx, 'a> Context<'cx, 'a> {
303303
| ExprKind::Continue(_)
304304
| ExprKind::Err
305305
| ExprKind::Field(_, _)
306-
| ExprKind::ForLoop {..}
306+
| ExprKind::ForLoop { .. }
307307
| ExprKind::FormatArgs(_)
308308
| ExprKind::IncludedBytes(..)
309309
| ExprKind::InlineAsm(_)

compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,7 @@ language_item_table! {
303303
FuturePoll, sym::poll, future_poll_fn, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::None;
304304

305305
AsyncIterNext, sym::async_iter_next, async_iter_next, Target::Fn, GenericRequirement::Exact(2);
306+
IntoAsyncIterIntoIter, sym::into_async_iter_into_iter, into_async_iter_into_iter, Target::Method(MethodKind::Trait { body: false }), GenericRequirement::Exact(1);
306307

307308
Option, sym::Option, option_type, Target::Enum, GenericRequirement::None;
308309
OptionSome, sym::Some, option_some_variant, Target::Variant, GenericRequirement::None;

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ symbols! {
889889
instruction_set,
890890
integer_: "integer", // underscore to avoid clashing with the function `sym::integer` below
891891
integral,
892+
into_async_iter_into_iter,
892893
into_future,
893894
into_iter,
894895
intra_doc_pointers,

library/core/src/async_iter/async_iter.rs

+1
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ pub trait IntoAsyncIterator {
119119
type IntoAsyncIter: AsyncIterator<Item = Self::Item>;
120120

121121
/// Converts `self` into an async iterator
122+
#[cfg_attr(not(bootstrap), lang = "into_async_iter_into_iter")]
122123
fn into_async_iter(self) -> Self::IntoAsyncIter;
123124
}
124125

tests/ui/async-await/for-await.rs

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// run-pass
2+
// edition: 2021
3+
#![feature(async_iterator, async_iter_from_iter, const_waker)]
4+
5+
use std::future::Future;
6+
7+
// make sure a simple for await loop works
8+
async fn real_main() {
9+
let iter = core::async_iter::from_iter(0..3);
10+
let mut count = 0;
11+
for await i in iter {
12+
assert_eq!(i, count);
13+
count += 1;
14+
}
15+
assert_eq!(count, 3);
16+
}
17+
18+
fn main() {
19+
let future = real_main();
20+
let waker = noop_waker::NOOP_WAKER;
21+
let mut cx = &mut core::task::Context::from_waker(&waker);
22+
let mut future = core::pin::pin!(future);
23+
while let core::task::Poll::Pending = future.as_mut().poll(&mut cx) {}
24+
}
25+
26+
mod noop_waker {
27+
use std::task::{RawWaker, RawWakerVTable, Waker};
28+
29+
const VTABLE: RawWakerVTable =
30+
RawWakerVTable::new(|_| RawWaker::new(core::ptr::null(), &VTABLE), |_| (), |_| (), |_| ());
31+
32+
pub(super) const NOOP_WAKER: Waker =
33+
unsafe { Waker::from_raw(RawWaker::new(core::ptr::null(), &VTABLE)) };
34+
}

0 commit comments

Comments
 (0)