Skip to content

Commit 8d10c52

Browse files
Build a shim to call async closures with different AsyncFn trait kinds
1 parent 6c75f74 commit 8d10c52

File tree

13 files changed

+175
-11
lines changed

13 files changed

+175
-11
lines changed

compiler/rustc_const_eval/src/interpret/terminator.rs

+1
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
541541
ty::InstanceDef::VTableShim(..)
542542
| ty::InstanceDef::ReifyShim(..)
543543
| ty::InstanceDef::ClosureOnceShim { .. }
544+
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
544545
| ty::InstanceDef::FnPtrShim(..)
545546
| ty::InstanceDef::DropGlue(..)
546547
| ty::InstanceDef::CloneShim(..)

compiler/rustc_middle/src/mir/mono.rs

+1
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ impl<'tcx> CodegenUnit<'tcx> {
402402
| InstanceDef::FnPtrShim(..)
403403
| InstanceDef::Virtual(..)
404404
| InstanceDef::ClosureOnceShim { .. }
405+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
405406
| InstanceDef::DropGlue(..)
406407
| InstanceDef::CloneShim(..)
407408
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/mir/visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,7 @@ macro_rules! make_mir_visitor {
345345
ty::InstanceDef::Virtual(_def_id, _) |
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348+
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
348349
ty::InstanceDef::DropGlue(_def_id, None) => {}
349350

350351
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

+22-1
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,25 @@ pub enum InstanceDef<'tcx> {
8282
/// details on that).
8383
Virtual(DefId, usize),
8484

85-
/// `<[FnMut closure] as FnOnce>::call_once`.
85+
/// `<[FnMut/Fn closure] as FnOnce>::call_once`.
8686
///
8787
/// The `DefId` is the ID of the `call_once` method in `FnOnce`.
88+
///
89+
/// This generates a body that will just borrow the (owned) self type,
90+
/// and dispatch to the `FnMut::call_mut` instance for the closure.
8891
ClosureOnceShim { call_once: DefId, track_caller: bool },
8992

93+
/// `<[FnMut/Fn coroutine-closure] as FnOnce>::call_once` or
94+
/// `<[Fn coroutine-closure] as FnMut>::call_mut`.
95+
///
96+
/// The body generated here differs significantly from the `ClosureOnceShim`,
97+
/// since we need to generate a distinct coroutine type that will move the
98+
/// closure's upvars *out* of the closure.
99+
ConstructCoroutineInClosureShim {
100+
coroutine_closure_def_id: DefId,
101+
target_kind: ty::ClosureKind,
102+
},
103+
90104
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
91105
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
92106
/// native support.
@@ -168,6 +182,10 @@ impl<'tcx> InstanceDef<'tcx> {
168182
| InstanceDef::Intrinsic(def_id)
169183
| InstanceDef::ThreadLocalShim(def_id)
170184
| InstanceDef::ClosureOnceShim { call_once: def_id, track_caller: _ }
185+
| ty::InstanceDef::ConstructCoroutineInClosureShim {
186+
coroutine_closure_def_id: def_id,
187+
target_kind: _,
188+
}
171189
| InstanceDef::DropGlue(def_id, _)
172190
| InstanceDef::CloneShim(def_id, _)
173191
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -187,6 +205,7 @@ impl<'tcx> InstanceDef<'tcx> {
187205
| InstanceDef::Virtual(..)
188206
| InstanceDef::Intrinsic(..)
189207
| InstanceDef::ClosureOnceShim { .. }
208+
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
190209
| InstanceDef::DropGlue(..)
191210
| InstanceDef::CloneShim(..)
192211
| InstanceDef::FnPtrAddrShim(..) => None,
@@ -282,6 +301,7 @@ impl<'tcx> InstanceDef<'tcx> {
282301
| InstanceDef::FnPtrShim(..)
283302
| InstanceDef::DropGlue(_, Some(_)) => false,
284303
InstanceDef::ClosureOnceShim { .. }
304+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
285305
| InstanceDef::DropGlue(..)
286306
| InstanceDef::Item(_)
287307
| InstanceDef::Intrinsic(..)
@@ -319,6 +339,7 @@ fn fmt_instance(
319339
InstanceDef::Virtual(_, num) => write!(f, " - virtual#{num}"),
320340
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
321341
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
342+
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
322343
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
323344
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
324345
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),

compiler/rustc_middle/src/ty/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2354,6 +2354,7 @@ impl<'tcx> TyCtxt<'tcx> {
23542354
| ty::InstanceDef::FnPtrShim(..)
23552355
| ty::InstanceDef::Virtual(..)
23562356
| ty::InstanceDef::ClosureOnceShim { .. }
2357+
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
23572358
| ty::InstanceDef::DropGlue(..)
23582359
| ty::InstanceDef::CloneShim(..)
23592360
| ty::InstanceDef::ThreadLocalShim(..)

compiler/rustc_mir_transform/src/inline.rs

+1
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ impl<'tcx> Inliner<'tcx> {
317317
| InstanceDef::ReifyShim(_)
318318
| InstanceDef::FnPtrShim(..)
319319
| InstanceDef::ClosureOnceShim { .. }
320+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
320321
| InstanceDef::DropGlue(..)
321322
| InstanceDef::CloneShim(..)
322323
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_mir_transform/src/inline/cycle.rs

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ pub(crate) fn mir_callgraph_reachable<'tcx>(
8787
| InstanceDef::ReifyShim(_)
8888
| InstanceDef::FnPtrShim(..)
8989
| InstanceDef::ClosureOnceShim { .. }
90+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
9091
| InstanceDef::ThreadLocalShim { .. }
9192
| InstanceDef::CloneShim(..) => {}
9293

compiler/rustc_mir_transform/src/shim.rs

+120-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ use rustc_hir::def_id::DefId;
33
use rustc_hir::lang_items::LangItem;
44
use rustc_middle::mir::*;
55
use rustc_middle::query::Providers;
6-
use rustc_middle::ty::GenericArgs;
76
use rustc_middle::ty::{self, CoroutineArgs, EarlyBinder, Ty, TyCtxt};
7+
use rustc_middle::ty::{GenericArgs, CAPTURE_STRUCT_LOCAL};
88
use rustc_target::abi::{FieldIdx, VariantIdx, FIRST_VARIANT};
99

1010
use rustc_index::{Idx, IndexVec};
@@ -66,6 +66,21 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceDef<'tcx>) -> Body<'
6666
build_call_shim(tcx, instance, Some(Adjustment::RefMut), CallKind::Direct(call_mut))
6767
}
6868

69+
ty::InstanceDef::ConstructCoroutineInClosureShim {
70+
coroutine_closure_def_id,
71+
target_kind,
72+
} => match target_kind {
73+
ty::ClosureKind::Fn => unreachable!("shouldn't be building shim for Fn"),
74+
ty::ClosureKind::FnMut => {
75+
let body = build_construct_coroutine_by_mut_shim(tcx, coroutine_closure_def_id);
76+
// No need to optimize the body, it has already been optimized.
77+
return body;
78+
}
79+
ty::ClosureKind::FnOnce => {
80+
build_construct_coroutine_by_move_shim(tcx, coroutine_closure_def_id)
81+
}
82+
},
83+
6984
ty::InstanceDef::DropGlue(def_id, ty) => {
7085
// FIXME(#91576): Drop shims for coroutines aren't subject to the MIR passes at the end
7186
// of this function. Is this intentional?
@@ -981,3 +996,107 @@ fn build_fn_ptr_addr_shim<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Ty<'t
981996
let source = MirSource::from_instance(ty::InstanceDef::FnPtrAddrShim(def_id, self_ty));
982997
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
983998
}
999+
1000+
fn build_construct_coroutine_by_move_shim<'tcx>(
1001+
tcx: TyCtxt<'tcx>,
1002+
coroutine_closure_def_id: DefId,
1003+
) -> Body<'tcx> {
1004+
let self_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
1005+
let ty::CoroutineClosure(_, args) = *self_ty.kind() else {
1006+
bug!();
1007+
};
1008+
1009+
let poly_sig = args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
1010+
tcx.mk_fn_sig(
1011+
[self_ty].into_iter().chain(sig.tupled_inputs_ty.tuple_fields()),
1012+
sig.to_coroutine_given_kind_and_upvars(
1013+
tcx,
1014+
args.as_coroutine_closure().parent_args(),
1015+
tcx.coroutine_for_closure(coroutine_closure_def_id),
1016+
ty::ClosureKind::FnOnce,
1017+
tcx.lifetimes.re_erased,
1018+
args.as_coroutine_closure().tupled_upvars_ty(),
1019+
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
1020+
),
1021+
sig.c_variadic,
1022+
sig.unsafety,
1023+
sig.abi,
1024+
)
1025+
});
1026+
let sig = tcx.liberate_late_bound_regions(coroutine_closure_def_id, poly_sig);
1027+
let ty::Coroutine(coroutine_def_id, coroutine_args) = *sig.output().kind() else {
1028+
bug!();
1029+
};
1030+
1031+
let span = tcx.def_span(coroutine_closure_def_id);
1032+
let locals = local_decls_for_sig(&sig, span);
1033+
1034+
let mut fields = vec![];
1035+
for idx in 1..sig.inputs().len() {
1036+
fields.push(Operand::Move(Local::from_usize(idx + 1).into()));
1037+
}
1038+
for (idx, ty) in args.as_coroutine_closure().upvar_tys().iter().enumerate() {
1039+
fields.push(Operand::Move(tcx.mk_place_field(
1040+
Local::from_usize(1).into(),
1041+
FieldIdx::from_usize(idx),
1042+
ty,
1043+
)));
1044+
}
1045+
1046+
let source_info = SourceInfo::outermost(span);
1047+
let rvalue = Rvalue::Aggregate(
1048+
Box::new(AggregateKind::Coroutine(coroutine_def_id, coroutine_args)),
1049+
IndexVec::from_raw(fields),
1050+
);
1051+
let stmt = Statement {
1052+
source_info,
1053+
kind: StatementKind::Assign(Box::new((Place::return_place(), rvalue))),
1054+
};
1055+
let statements = vec![stmt];
1056+
let start_block = BasicBlockData {
1057+
statements,
1058+
terminator: Some(Terminator { source_info, kind: TerminatorKind::Return }),
1059+
is_cleanup: false,
1060+
};
1061+
1062+
let source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
1063+
coroutine_closure_def_id,
1064+
target_kind: ty::ClosureKind::FnOnce,
1065+
});
1066+
1067+
new_body(source, IndexVec::from_elem_n(start_block, 1), locals, sig.inputs().len(), span)
1068+
}
1069+
1070+
fn build_construct_coroutine_by_mut_shim<'tcx>(
1071+
tcx: TyCtxt<'tcx>,
1072+
coroutine_closure_def_id: DefId,
1073+
) -> Body<'tcx> {
1074+
let mut body = tcx.optimized_mir(coroutine_closure_def_id).clone();
1075+
let coroutine_closure_ty = tcx.type_of(coroutine_closure_def_id).instantiate_identity();
1076+
let ty::CoroutineClosure(_, args) = *coroutine_closure_ty.kind() else {
1077+
bug!();
1078+
};
1079+
let args = args.as_coroutine_closure();
1080+
1081+
body.local_decls[RETURN_PLACE].ty =
1082+
tcx.instantiate_bound_regions_with_erased(args.coroutine_closure_sig().map_bound(|sig| {
1083+
sig.to_coroutine_given_kind_and_upvars(
1084+
tcx,
1085+
args.parent_args(),
1086+
tcx.coroutine_for_closure(coroutine_closure_def_id),
1087+
ty::ClosureKind::FnMut,
1088+
tcx.lifetimes.re_erased,
1089+
args.tupled_upvars_ty(),
1090+
args.coroutine_captures_by_ref_ty(),
1091+
)
1092+
}));
1093+
body.local_decls[CAPTURE_STRUCT_LOCAL].ty =
1094+
Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_closure_ty);
1095+
1096+
body.source = MirSource::from_instance(ty::InstanceDef::ConstructCoroutineInClosureShim {
1097+
coroutine_closure_def_id,
1098+
target_kind: ty::ClosureKind::FnMut,
1099+
});
1100+
1101+
body
1102+
}

compiler/rustc_monomorphize/src/collector.rs

+1
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,7 @@ fn visit_instance_use<'tcx>(
983983
| ty::InstanceDef::VTableShim(..)
984984
| ty::InstanceDef::ReifyShim(..)
985985
| ty::InstanceDef::ClosureOnceShim { .. }
986+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
986987
| ty::InstanceDef::Item(..)
987988
| ty::InstanceDef::FnPtrShim(..)
988989
| ty::InstanceDef::CloneShim(..)

compiler/rustc_monomorphize/src/partitioning.rs

+2
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,7 @@ fn characteristic_def_id_of_mono_item<'tcx>(
620620
| ty::InstanceDef::ReifyShim(..)
621621
| ty::InstanceDef::FnPtrShim(..)
622622
| ty::InstanceDef::ClosureOnceShim { .. }
623+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
623624
| ty::InstanceDef::Intrinsic(..)
624625
| ty::InstanceDef::DropGlue(..)
625626
| ty::InstanceDef::Virtual(..)
@@ -783,6 +784,7 @@ fn mono_item_visibility<'tcx>(
783784
| InstanceDef::Virtual(..)
784785
| InstanceDef::Intrinsic(..)
785786
| InstanceDef::ClosureOnceShim { .. }
787+
| InstanceDef::ConstructCoroutineInClosureShim { .. }
786788
| InstanceDef::DropGlue(..)
787789
| InstanceDef::CloneShim(..)
788790
| InstanceDef::FnPtrAddrShim(..) => return Visibility::Hidden,

compiler/rustc_smir/src/rustc_smir/convert/ty.rs

+1
Original file line numberDiff line numberDiff line change
@@ -800,6 +800,7 @@ impl<'tcx> Stable<'tcx> for ty::Instance<'tcx> {
800800
| ty::InstanceDef::ReifyShim(..)
801801
| ty::InstanceDef::FnPtrAddrShim(..)
802802
| ty::InstanceDef::ClosureOnceShim { .. }
803+
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
803804
| ty::InstanceDef::ThreadLocalShim(..)
804805
| ty::InstanceDef::DropGlue(..)
805806
| ty::InstanceDef::CloneShim(..)

compiler/rustc_ty_utils/src/abi.rs

+9-6
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,14 @@ fn fn_sig_for_fn_abi<'tcx>(
111111
kind: ty::BoundRegionKind::BrEnv,
112112
};
113113
let env_region = ty::Region::new_bound(tcx, ty::INNERMOST, br);
114-
let env_ty = tcx.closure_env_ty(
115-
Ty::new_coroutine_closure(tcx, def_id, args),
116-
args.as_coroutine_closure().kind(),
117-
env_region,
118-
);
114+
115+
let mut kind = args.as_coroutine_closure().kind();
116+
if let InstanceDef::ConstructCoroutineInClosureShim { target_kind, .. } = instance.def {
117+
kind = target_kind;
118+
}
119+
120+
let env_ty =
121+
tcx.closure_env_ty(Ty::new_coroutine_closure(tcx, def_id, args), kind, env_region);
119122

120123
let sig = sig.skip_binder();
121124
ty::Binder::bind_with_vars(
@@ -125,7 +128,7 @@ fn fn_sig_for_fn_abi<'tcx>(
125128
tcx,
126129
args.as_coroutine_closure().parent_args(),
127130
tcx.coroutine_for_closure(def_id),
128-
args.as_coroutine_closure().kind(),
131+
kind,
129132
env_region,
130133
args.as_coroutine_closure().tupled_upvars_ty(),
131134
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),

compiler/rustc_ty_utils/src/instance.rs

+14-3
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,21 @@ fn resolve_associated_item<'tcx>(
283283
tcx.item_name(trait_item_id)
284284
),
285285
}
286-
} else if tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id).is_some() {
286+
} else if let Some(target_kind) = tcx.async_fn_trait_kind_from_def_id(trait_ref.def_id)
287+
{
287288
match *rcvr_args.type_at(0).kind() {
288-
ty::CoroutineClosure(closure_def_id, args) => {
289-
Some(Instance::new(closure_def_id, args))
289+
ty::CoroutineClosure(coroutine_closure_def_id, args) => {
290+
if target_kind > args.as_coroutine_closure().kind() {
291+
Some(Instance {
292+
def: ty::InstanceDef::ConstructCoroutineInClosureShim {
293+
coroutine_closure_def_id,
294+
target_kind,
295+
},
296+
args,
297+
})
298+
} else {
299+
Some(Instance::new(coroutine_closure_def_id, args))
300+
}
290301
}
291302
_ => bug!(
292303
"no built-in definition for `{trait_ref}::{}` for non-lending-closure type",

0 commit comments

Comments
 (0)