@@ -337,7 +337,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
337
337
) ,
338
338
ExprKind :: Try ( sub_expr) => self . lower_expr_try ( e. span , sub_expr) ,
339
339
340
- ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
340
+ ExprKind :: Paren ( _) | ExprKind :: ForLoop { .. } => {
341
341
unreachable ! ( "already handled" )
342
342
}
343
343
@@ -871,6 +871,17 @@ impl<'hir> LoweringContext<'_, 'hir> {
871
871
/// }
872
872
/// ```
873
873
fn lower_expr_await ( & mut self , await_kw_span : Span , expr : & Expr ) -> hir:: ExprKind < ' hir > {
874
+ let expr = self . arena . alloc ( self . lower_expr_mut ( expr) ) ;
875
+ self . make_lowered_await ( await_kw_span, expr, FutureKind :: Future )
876
+ }
877
+
878
+ /// Takes an expr that has already been lowered and generates a desugared await loop around it
879
+ fn make_lowered_await (
880
+ & mut self ,
881
+ await_kw_span : Span ,
882
+ expr : & ' hir hir:: Expr < ' hir > ,
883
+ await_kind : FutureKind ,
884
+ ) -> hir:: ExprKind < ' hir > {
874
885
let full_span = expr. span . to ( await_kw_span) ;
875
886
876
887
let is_async_gen = match self . coroutine_kind {
@@ -884,13 +895,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
884
895
}
885
896
} ;
886
897
887
- let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, None ) ;
898
+ let features = match await_kind {
899
+ FutureKind :: Future => None ,
900
+ FutureKind :: AsyncIterator => Some ( self . allow_for_await . clone ( ) ) ,
901
+ } ;
902
+ let span = self . mark_span_with_reason ( DesugaringKind :: Await , await_kw_span, features) ;
888
903
let gen_future_span = self . mark_span_with_reason (
889
904
DesugaringKind :: Await ,
890
905
full_span,
891
906
Some ( self . allow_gen_future . clone ( ) ) ,
892
907
) ;
893
- let expr = self . lower_expr_mut ( expr) ;
894
908
let expr_hir_id = expr. hir_id ;
895
909
896
910
// Note that the name of this binding must not be changed to something else because
@@ -930,11 +944,18 @@ impl<'hir> LoweringContext<'_, 'hir> {
930
944
hir:: LangItem :: GetContext ,
931
945
arena_vec ! [ self ; task_context] ,
932
946
) ;
933
- let call = self . expr_call_lang_item_fn (
934
- span,
935
- hir:: LangItem :: FuturePoll ,
936
- arena_vec ! [ self ; new_unchecked, get_context] ,
937
- ) ;
947
+ let call = match await_kind {
948
+ FutureKind :: Future => self . expr_call_lang_item_fn (
949
+ span,
950
+ hir:: LangItem :: FuturePoll ,
951
+ arena_vec ! [ self ; new_unchecked, get_context] ,
952
+ ) ,
953
+ FutureKind :: AsyncIterator => self . expr_call_lang_item_fn (
954
+ span,
955
+ hir:: LangItem :: AsyncIteratorPollNext ,
956
+ arena_vec ! [ self ; new_unchecked, get_context] ,
957
+ ) ,
958
+ } ;
938
959
self . arena . alloc ( self . expr_unsafe ( call) )
939
960
} ;
940
961
@@ -1018,11 +1039,16 @@ impl<'hir> LoweringContext<'_, 'hir> {
1018
1039
let awaitee_arm = self . arm ( awaitee_pat, loop_expr) ;
1019
1040
1020
1041
// `match ::std::future::IntoFuture::into_future(<expr>) { ... }`
1021
- let into_future_expr = self . expr_call_lang_item_fn (
1022
- span,
1023
- hir:: LangItem :: IntoFutureIntoFuture ,
1024
- arena_vec ! [ self ; expr] ,
1025
- ) ;
1042
+ let into_future_expr = match await_kind {
1043
+ FutureKind :: Future => self . expr_call_lang_item_fn (
1044
+ span,
1045
+ hir:: LangItem :: IntoFutureIntoFuture ,
1046
+ arena_vec ! [ self ; * expr] ,
1047
+ ) ,
1048
+ // Not needed for `for await` because we expect to have already called
1049
+ // `IntoAsyncIterator::into_async_iter` on it.
1050
+ FutureKind :: AsyncIterator => expr,
1051
+ } ;
1026
1052
1027
1053
// match <into_future_expr> {
1028
1054
// mut __awaitee => loop { .. }
@@ -1670,7 +1696,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
1670
1696
head : & Expr ,
1671
1697
body : & Block ,
1672
1698
opt_label : Option < Label > ,
1673
- _loop_kind : ForLoopKind ,
1699
+ loop_kind : ForLoopKind ,
1674
1700
) -> hir:: Expr < ' hir > {
1675
1701
let head = self . lower_expr_mut ( head) ;
1676
1702
let pat = self . lower_pat ( pat) ;
@@ -1699,17 +1725,41 @@ impl<'hir> LoweringContext<'_, 'hir> {
1699
1725
let ( iter_pat, iter_pat_nid) =
1700
1726
self . pat_ident_binding_mode ( head_span, iter, hir:: BindingAnnotation :: MUT ) ;
1701
1727
1702
- // `match Iterator::next(&mut iter) { ... }`
1703
1728
let match_expr = {
1704
1729
let iter = self . expr_ident ( head_span, iter, iter_pat_nid) ;
1705
- let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1706
- let next_expr = self . expr_call_lang_item_fn (
1707
- head_span,
1708
- hir:: LangItem :: IteratorNext ,
1709
- arena_vec ! [ self ; ref_mut_iter] ,
1710
- ) ;
1730
+ let next_expr = match loop_kind {
1731
+ ForLoopKind :: For => {
1732
+ // `Iterator::next(&mut iter)`
1733
+ let ref_mut_iter = self . expr_mut_addr_of ( head_span, iter) ;
1734
+ self . expr_call_lang_item_fn (
1735
+ head_span,
1736
+ hir:: LangItem :: IteratorNext ,
1737
+ arena_vec ! [ self ; ref_mut_iter] ,
1738
+ )
1739
+ }
1740
+ ForLoopKind :: ForAwait => {
1741
+ // we'll generate `unsafe { Pin::new_unchecked(&mut iter) })` and then pass this
1742
+ // to make_lowered_await with `FutureKind::AsyncIterator` which will generator
1743
+ // calls to `poll_next`. In user code, this would probably be a call to
1744
+ // `Pin::as_mut` but here it's easy enough to do `new_unchecked`.
1745
+
1746
+ // `&mut iter`
1747
+ let iter = self . expr_mut_addr_of ( head_span, iter) ;
1748
+ // `Pin::new_unchecked(...)`
1749
+ let iter = self . arena . alloc ( self . expr_call_lang_item_fn_mut (
1750
+ head_span,
1751
+ hir:: LangItem :: PinNewUnchecked ,
1752
+ arena_vec ! [ self ; iter] ,
1753
+ ) ) ;
1754
+ // `unsafe { ... }`
1755
+ let iter = self . arena . alloc ( self . expr_unsafe ( iter) ) ;
1756
+ let kind = self . make_lowered_await ( head_span, iter, FutureKind :: AsyncIterator ) ;
1757
+ self . arena . alloc ( hir:: Expr { hir_id : self . next_id ( ) , kind, span : head_span } )
1758
+ }
1759
+ } ;
1711
1760
let arms = arena_vec ! [ self ; none_arm, some_arm] ;
1712
1761
1762
+ // `match $next_expr { ... }`
1713
1763
self . expr_match ( head_span, next_expr, arms, hir:: MatchSource :: ForLoopDesugar )
1714
1764
} ;
1715
1765
let match_stmt = self . stmt_expr ( for_span, match_expr) ;
@@ -1729,13 +1779,24 @@ impl<'hir> LoweringContext<'_, 'hir> {
1729
1779
// `mut iter => { ... }`
1730
1780
let iter_arm = self . arm ( iter_pat, loop_expr) ;
1731
1781
1732
- // `match ::std::iter::IntoIterator::into_iter(<head>) { ... }`
1733
- let into_iter_expr = {
1734
- self . expr_call_lang_item_fn (
1735
- head_span,
1736
- hir:: LangItem :: IntoIterIntoIter ,
1737
- arena_vec ! [ self ; head] ,
1738
- )
1782
+ let into_iter_expr = match loop_kind {
1783
+ ForLoopKind :: For => {
1784
+ // `::std::iter::IntoIterator::into_iter(<head>)`
1785
+ self . expr_call_lang_item_fn (
1786
+ head_span,
1787
+ hir:: LangItem :: IntoIterIntoIter ,
1788
+ arena_vec ! [ self ; head] ,
1789
+ )
1790
+ }
1791
+ ForLoopKind :: ForAwait => {
1792
+ // `::core::async_iter::IntoAsyncIterator::into_async_iter(<head>)`
1793
+ let iter = self . expr_call_lang_item_fn (
1794
+ head_span,
1795
+ hir:: LangItem :: IntoAsyncIterIntoIter ,
1796
+ arena_vec ! [ self ; head] ,
1797
+ ) ;
1798
+ self . arena . alloc ( self . expr_mut_addr_of ( head_span, iter) )
1799
+ }
1739
1800
} ;
1740
1801
1741
1802
let match_expr = self . arena . alloc ( self . expr_match (
@@ -2138,3 +2199,14 @@ impl<'hir> LoweringContext<'_, 'hir> {
2138
2199
}
2139
2200
}
2140
2201
}
2202
+
2203
+ /// Used by [`LoweringContext::make_lowered_await`] to customize the desugaring based on what kind
2204
+ /// of future we are awaiting.
2205
+ #[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
2206
+ enum FutureKind {
2207
+ /// We are awaiting a normal future
2208
+ Future ,
2209
+ /// We are awaiting something that's known to be an AsyncIterator (i.e. we are in the header of
2210
+ /// a `for await` loop)
2211
+ AsyncIterator ,
2212
+ }
0 commit comments