Skip to content

Commit 6c75f74

Browse files
Teach typeck/borrowck/solvers how to deal with async closures
1 parent 21906bd commit 6c75f74

File tree

34 files changed

+1215
-58
lines changed

34 files changed

+1215
-58
lines changed

compiler/rustc_ast_lowering/src/expr.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::assert_matches::assert_matches;
2+
13
use super::errors::{
24
AsyncCoroutinesNotSupported, AwaitOnlyInAsyncFnAndBlocks, BaseExpressionDoubleDot,
35
ClosureCannotBeStatic, CoroutineTooManyParameters,
@@ -1028,6 +1030,12 @@ impl<'hir> LoweringContext<'_, 'hir> {
10281030
) -> hir::ExprKind<'hir> {
10291031
let (binder_clause, generic_params) = self.lower_closure_binder(binder);
10301032

1033+
assert_matches!(
1034+
coroutine_kind,
1035+
CoroutineKind::Async { .. },
1036+
"only async closures are supported currently"
1037+
);
1038+
10311039
let body = self.with_new_scopes(fn_decl_span, |this| {
10321040
let inner_decl =
10331041
FnDecl { inputs: decl.inputs.clone(), output: FnRetTy::Default(fn_decl_span) };

compiler/rustc_ast_lowering/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#![allow(internal_features)]
3434
#![feature(rustdoc_internals)]
3535
#![doc(rust_logo)]
36+
#![feature(assert_matches)]
3637
#![feature(if_let_guard)]
3738
#![feature(box_patterns)]
3839
#![feature(let_chains)]

compiler/rustc_borrowck/src/diagnostics/region_name.rs

+11-10
Original file line numberDiff line numberDiff line change
@@ -324,31 +324,32 @@ impl<'tcx> MirBorrowckCtxt<'_, 'tcx> {
324324
ty::BoundRegionKind::BrEnv => {
325325
let def_ty = self.regioncx.universal_regions().defining_ty;
326326

327-
let DefiningTy::Closure(_, args) = def_ty else {
328-
// Can't have BrEnv in functions, constants or coroutines.
329-
bug!("BrEnv outside of closure.");
327+
let closure_kind = match def_ty {
328+
DefiningTy::Closure(_, args) => args.as_closure().kind(),
329+
DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().kind(),
330+
_ => {
331+
// Can't have BrEnv in functions, constants or coroutines.
332+
bug!("BrEnv outside of closure.");
333+
}
330334
};
331335
let hir::ExprKind::Closure(&hir::Closure { fn_decl_span, .. }) =
332336
tcx.hir().expect_expr(self.mir_hir_id()).kind
333337
else {
334338
bug!("Closure is not defined by a closure expr");
335339
};
336340
let region_name = self.synthesize_region_name();
337-
338-
let closure_kind_ty = args.as_closure().kind_ty();
339-
let note = match closure_kind_ty.to_opt_closure_kind() {
340-
Some(ty::ClosureKind::Fn) => {
341+
let note = match closure_kind {
342+
ty::ClosureKind::Fn => {
341343
"closure implements `Fn`, so references to captured variables \
342344
can't escape the closure"
343345
}
344-
Some(ty::ClosureKind::FnMut) => {
346+
ty::ClosureKind::FnMut => {
345347
"closure implements `FnMut`, so references to captured variables \
346348
can't escape the closure"
347349
}
348-
Some(ty::ClosureKind::FnOnce) => {
350+
ty::ClosureKind::FnOnce => {
349351
bug!("BrEnv in a `FnOnce` closure");
350352
}
351-
None => bug!("Closure kind not inferred in borrow check"),
352353
};
353354

354355
Some(RegionName {

compiler/rustc_borrowck/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#![allow(internal_features)]
44
#![feature(rustdoc_internals)]
55
#![doc(rust_logo)]
6+
#![feature(assert_matches)]
67
#![feature(associated_type_bounds)]
78
#![feature(box_patterns)]
89
#![feature(let_chains)]

compiler/rustc_borrowck/src/type_check/input_output.rs

+68-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@
77
//! `RETURN_PLACE` the MIR arguments) are always fully normalized (and
88
//! contain revealed `impl Trait` values).
99
10+
use std::assert_matches::assert_matches;
11+
1012
use itertools::Itertools;
11-
use rustc_infer::infer::BoundRegionConversionTime;
13+
use rustc_hir as hir;
14+
use rustc_infer::infer::type_variable::{TypeVariableOrigin, TypeVariableOriginKind};
15+
use rustc_infer::infer::{BoundRegionConversionTime, RegionVariableOrigin};
1216
use rustc_middle::mir::*;
1317
use rustc_middle::ty::{self, Ty};
1418
use rustc_span::Span;
1519

16-
use crate::universal_regions::UniversalRegions;
20+
use crate::renumber::RegionCtxt;
21+
use crate::universal_regions::{DefiningTy, UniversalRegions};
1722

1823
use super::{Locations, TypeChecker};
1924

@@ -23,9 +28,11 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
2328
#[instrument(skip(self, body), level = "debug")]
2429
pub(super) fn check_signature_annotation(&mut self, body: &Body<'tcx>) {
2530
let mir_def_id = body.source.def_id().expect_local();
31+
2632
if !self.tcx().is_closure_or_coroutine(mir_def_id.to_def_id()) {
2733
return;
2834
}
35+
2936
let user_provided_poly_sig = self.tcx().closure_user_provided_sig(mir_def_id);
3037

3138
// Instantiate the canonicalized variables from user-provided signature
@@ -34,12 +41,70 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
3441
// so that they represent the view from "inside" the closure.
3542
let user_provided_sig = self
3643
.instantiate_canonical_with_fresh_inference_vars(body.span, &user_provided_poly_sig);
37-
let user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
44+
let mut user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
3845
body.span,
3946
BoundRegionConversionTime::FnCall,
4047
user_provided_sig,
4148
);
4249

50+
// FIXME(async_closures): We must apply the same transformation to our
51+
// signature here as we do during closure checking.
52+
if let DefiningTy::CoroutineClosure(_, args) =
53+
self.borrowck_context.universal_regions.defining_ty
54+
{
55+
assert_matches!(
56+
self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(mir_def_id)),
57+
Some(hir::CoroutineKind::Desugared(
58+
hir::CoroutineDesugaring::Async,
59+
hir::CoroutineSource::Closure
60+
)),
61+
"this needs to be modified if we're lowering non-async closures"
62+
);
63+
let args = args.as_coroutine_closure();
64+
let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
65+
self.tcx(),
66+
args.kind(),
67+
Ty::new_tup(self.tcx(), user_provided_sig.inputs()),
68+
args.tupled_upvars_ty(),
69+
args.coroutine_captures_by_ref_ty(),
70+
self.infcx.next_region_var(RegionVariableOrigin::MiscVariable(body.span), || {
71+
RegionCtxt::Unknown
72+
}),
73+
);
74+
75+
let next_ty_var = || {
76+
self.infcx.next_ty_var(TypeVariableOrigin {
77+
span: body.span,
78+
kind: TypeVariableOriginKind::MiscVariable,
79+
})
80+
};
81+
let output_ty = Ty::new_coroutine(
82+
self.tcx(),
83+
self.tcx().coroutine_for_closure(mir_def_id),
84+
ty::CoroutineArgs::new(
85+
self.tcx(),
86+
ty::CoroutineArgsParts {
87+
parent_args: args.parent_args(),
88+
kind_ty: Ty::from_closure_kind(self.tcx(), args.kind()),
89+
resume_ty: next_ty_var(),
90+
yield_ty: next_ty_var(),
91+
witness: next_ty_var(),
92+
return_ty: user_provided_sig.output(),
93+
tupled_upvars_ty: tupled_upvars_ty,
94+
},
95+
)
96+
.args,
97+
);
98+
99+
user_provided_sig = self.tcx().mk_fn_sig(
100+
user_provided_sig.inputs().iter().copied(),
101+
output_ty,
102+
user_provided_sig.c_variadic,
103+
user_provided_sig.unsafety,
104+
user_provided_sig.abi,
105+
);
106+
}
107+
43108
let is_coroutine_with_implicit_resume_ty = self.tcx().is_coroutine(mir_def_id.to_def_id())
44109
&& user_provided_sig.inputs().is_empty();
45110

compiler/rustc_borrowck/src/universal_regions.rs

+57-1
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ pub enum DefiningTy<'tcx> {
9797
/// `ClosureArgs::coroutine_return_ty`.
9898
Coroutine(DefId, GenericArgsRef<'tcx>),
9999

100+
/// The MIR is a special kind of closure that returns coroutines.
101+
/// TODO: describe how to make the sig...
102+
CoroutineClosure(DefId, GenericArgsRef<'tcx>),
103+
100104
/// The MIR is a fn item with the given `DefId` and args. The signature
101105
/// of the function can be bound then with the `fn_sig` query.
102106
FnDef(DefId, GenericArgsRef<'tcx>),
@@ -119,6 +123,7 @@ impl<'tcx> DefiningTy<'tcx> {
119123
pub fn upvar_tys(self) -> &'tcx ty::List<Ty<'tcx>> {
120124
match self {
121125
DefiningTy::Closure(_, args) => args.as_closure().upvar_tys(),
126+
DefiningTy::CoroutineClosure(_, args) => args.as_coroutine_closure().upvar_tys(),
122127
DefiningTy::Coroutine(_, args) => args.as_coroutine().upvar_tys(),
123128
DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => {
124129
ty::List::empty()
@@ -131,7 +136,9 @@ impl<'tcx> DefiningTy<'tcx> {
131136
/// user's code.
132137
pub fn implicit_inputs(self) -> usize {
133138
match self {
134-
DefiningTy::Closure(..) | DefiningTy::Coroutine(..) => 1,
139+
DefiningTy::Closure(..)
140+
| DefiningTy::CoroutineClosure(..)
141+
| DefiningTy::Coroutine(..) => 1,
135142
DefiningTy::FnDef(..) | DefiningTy::Const(..) | DefiningTy::InlineConst(..) => 0,
136143
}
137144
}
@@ -147,6 +154,7 @@ impl<'tcx> DefiningTy<'tcx> {
147154
pub fn def_id(&self) -> DefId {
148155
match *self {
149156
DefiningTy::Closure(def_id, ..)
157+
| DefiningTy::CoroutineClosure(def_id, ..)
150158
| DefiningTy::Coroutine(def_id, ..)
151159
| DefiningTy::FnDef(def_id, ..)
152160
| DefiningTy::Const(def_id, ..)
@@ -355,6 +363,9 @@ impl<'tcx> UniversalRegions<'tcx> {
355363
err.note(format!("late-bound region is {:?}", self.to_region_vid(r)));
356364
});
357365
}
366+
DefiningTy::CoroutineClosure(..) => {
367+
todo!()
368+
}
358369
DefiningTy::Coroutine(def_id, args) => {
359370
let v = with_no_trimmed_paths!(
360371
args[tcx.generics_of(def_id).parent_count..]
@@ -568,6 +579,9 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
568579
match *defining_ty.kind() {
569580
ty::Closure(def_id, args) => DefiningTy::Closure(def_id, args),
570581
ty::Coroutine(def_id, args) => DefiningTy::Coroutine(def_id, args),
582+
ty::CoroutineClosure(def_id, args) => {
583+
DefiningTy::CoroutineClosure(def_id, args)
584+
}
571585
ty::FnDef(def_id, args) => DefiningTy::FnDef(def_id, args),
572586
_ => span_bug!(
573587
tcx.def_span(self.mir_def),
@@ -623,6 +637,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
623637
let identity_args = GenericArgs::identity_for_item(tcx, typeck_root_def_id);
624638
let fr_args = match defining_ty {
625639
DefiningTy::Closure(_, args)
640+
| DefiningTy::CoroutineClosure(_, args)
626641
| DefiningTy::Coroutine(_, args)
627642
| DefiningTy::InlineConst(_, args) => {
628643
// In the case of closures, we rely on the fact that
@@ -702,6 +717,47 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
702717
ty::Binder::dummy(inputs_and_output)
703718
}
704719

720+
DefiningTy::CoroutineClosure(def_id, args) => {
721+
assert_eq!(self.mir_def.to_def_id(), def_id);
722+
let closure_sig = args.as_coroutine_closure().coroutine_closure_sig();
723+
let bound_vars = tcx.mk_bound_variable_kinds_from_iter(
724+
closure_sig
725+
.bound_vars()
726+
.iter()
727+
.chain(iter::once(ty::BoundVariableKind::Region(ty::BrEnv))),
728+
);
729+
let br = ty::BoundRegion {
730+
var: ty::BoundVar::from_usize(bound_vars.len() - 1),
731+
kind: ty::BrEnv,
732+
};
733+
let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
734+
let closure_kind = args.as_coroutine_closure().kind();
735+
736+
let closure_ty = tcx.closure_env_ty(
737+
Ty::new_coroutine_closure(tcx, def_id, args),
738+
closure_kind,
739+
env_region,
740+
);
741+
742+
let inputs = closure_sig.skip_binder().tupled_inputs_ty.tuple_fields();
743+
let output = closure_sig.skip_binder().to_coroutine_given_kind_and_upvars(
744+
tcx,
745+
args.as_coroutine_closure().parent_args(),
746+
tcx.coroutine_for_closure(def_id),
747+
closure_kind,
748+
env_region,
749+
args.as_coroutine_closure().tupled_upvars_ty(),
750+
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
751+
);
752+
753+
ty::Binder::bind_with_vars(
754+
tcx.mk_type_list_from_iter(
755+
iter::once(closure_ty).chain(inputs).chain(iter::once(output)),
756+
),
757+
bound_vars,
758+
)
759+
}
760+
705761
DefiningTy::FnDef(def_id, _) => {
706762
let sig = tcx.fn_sig(def_id).instantiate_identity();
707763
let sig = indices.fold_to_region_vids(tcx, sig);

compiler/rustc_const_eval/src/transform/validate.rs

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ impl<'tcx> MirPass<'tcx> for Validator {
5858
let body_abi = match body_ty.kind() {
5959
ty::FnDef(..) => body_ty.fn_sig(tcx).abi(),
6060
ty::Closure(..) => Abi::RustCall,
61+
ty::CoroutineClosure(..) => Abi::RustCall,
6162
ty::Coroutine(..) => Abi::Rust,
6263
_ => {
6364
span_bug!(body.span, "unexpected body ty: {:?} phase {:?}", body_ty, mir_phase)

compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ language_item_table! {
209209
AsyncFn, sym::async_fn, async_fn_trait, Target::Trait, GenericRequirement::Exact(1);
210210
AsyncFnMut, sym::async_fn_mut, async_fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
211211
AsyncFnOnce, sym::async_fn_once, async_fn_once_trait, Target::Trait, GenericRequirement::Exact(1);
212+
AsyncFnKindHelper, sym::async_fn_kind_helper,async_fn_kind_helper, Target::Trait, GenericRequirement::Exact(1);
212213

213214
FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;
214215

compiler/rustc_hir_analysis/src/collect.rs

+31
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pub fn provide(providers: &mut Providers) {
8181
impl_trait_ref,
8282
impl_polarity,
8383
coroutine_kind,
84+
coroutine_for_closure,
8485
collect_mod_item_types,
8586
is_type_alias_impl_trait,
8687
..*providers
@@ -1531,6 +1532,36 @@ fn coroutine_kind(tcx: TyCtxt<'_>, def_id: LocalDefId) -> Option<hir::CoroutineK
15311532
}
15321533
}
15331534

1535+
fn coroutine_for_closure(tcx: TyCtxt<'_>, def_id: LocalDefId) -> DefId {
1536+
let Node::Expr(&hir::Expr {
1537+
kind:
1538+
hir::ExprKind::Closure(&rustc_hir::Closure {
1539+
kind: hir::ClosureKind::CoroutineClosure(_),
1540+
body,
1541+
..
1542+
}),
1543+
..
1544+
}) = tcx.hir_node_by_def_id(def_id)
1545+
else {
1546+
bug!()
1547+
};
1548+
1549+
let &hir::Expr {
1550+
kind:
1551+
hir::ExprKind::Closure(&rustc_hir::Closure {
1552+
def_id,
1553+
kind: hir::ClosureKind::Coroutine(_),
1554+
..
1555+
}),
1556+
..
1557+
} = tcx.hir().body(body).value
1558+
else {
1559+
bug!()
1560+
};
1561+
1562+
def_id.to_def_id()
1563+
}
1564+
15341565
fn is_type_alias_impl_trait<'tcx>(tcx: TyCtxt<'tcx>, def_id: LocalDefId) -> bool {
15351566
match tcx.hir_node_by_def_id(def_id) {
15361567
Node::Item(hir::Item { kind: hir::ItemKind::OpaqueTy(opaque), .. }) => {

0 commit comments

Comments
 (0)