Skip to content

Commit 0406518

Browse files
Construct body for by-move coroutine closure output
1 parent fb8bfd0 commit 0406518

File tree

23 files changed

+229
-15
lines changed

23 files changed

+229
-15
lines changed

compiler/rustc_const_eval/src/interpret/terminator.rs

+1
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,7 @@ impl<'mir, 'tcx: 'mir, M: Machine<'mir, 'tcx>> InterpCx<'mir, 'tcx, M> {
542542
| ty::InstanceDef::ReifyShim(..)
543543
| ty::InstanceDef::ClosureOnceShim { .. }
544544
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
545+
| ty::InstanceDef::CoroutineByMoveShim { .. }
545546
| ty::InstanceDef::FnPtrShim(..)
546547
| ty::InstanceDef::DropGlue(..)
547548
| ty::InstanceDef::CloneShim(..)

compiler/rustc_hir_typeck/src/callee.rs

+1
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
182182
coroutine_closure_sig.to_coroutine(
183183
self.tcx,
184184
closure_args.parent_args(),
185+
closure_args.kind_ty(),
185186
self.tcx.coroutine_for_closure(def_id),
186187
tupled_upvars_ty,
187188
),

compiler/rustc_hir_typeck/src/closure.rs

+11
Original file line numberDiff line numberDiff line change
@@ -175,10 +175,20 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
175175
interior,
176176
));
177177

178+
let kind_ty = match kind {
179+
hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure) => self
180+
.next_ty_var(TypeVariableOrigin {
181+
kind: TypeVariableOriginKind::ClosureSynthetic,
182+
span: expr_span,
183+
}),
184+
_ => tcx.types.unit,
185+
};
186+
178187
let coroutine_args = ty::CoroutineArgs::new(
179188
tcx,
180189
ty::CoroutineArgsParts {
181190
parent_args,
191+
kind_ty,
182192
resume_ty,
183193
yield_ty,
184194
return_ty: liberated_sig.output(),
@@ -256,6 +266,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
256266
sig.to_coroutine(
257267
tcx,
258268
parent_args,
269+
closure_kind_ty,
259270
tcx.coroutine_for_closure(expr_def_id),
260271
coroutine_upvars_ty,
261272
)

compiler/rustc_hir_typeck/src/upvar.rs

+10
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
393393
args.as_coroutine_closure().coroutine_captures_by_ref_ty(),
394394
coroutine_captures_by_ref_ty,
395395
);
396+
397+
let ty::Coroutine(_, args) = *self.typeck_results.borrow().expr_ty(body.value).kind()
398+
else {
399+
bug!();
400+
};
401+
self.demand_eqtype(
402+
span,
403+
args.as_coroutine().kind_ty(),
404+
Ty::from_closure_kind(self.tcx, closure_kind),
405+
);
396406
}
397407

398408
self.log_closure_min_capture_info(closure_def_id, span);

compiler/rustc_middle/src/mir/mod.rs

+5
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,10 @@ pub struct CoroutineInfo<'tcx> {
260260
/// Coroutine drop glue. This field is populated after the state transform pass.
261261
pub coroutine_drop: Option<Body<'tcx>>,
262262

263+
/// The body of the coroutine, modified to take its upvars by move.
264+
/// TODO:
265+
pub by_move_body: Option<Body<'tcx>>,
266+
263267
/// The layout of a coroutine. This field is populated after the state transform pass.
264268
pub coroutine_layout: Option<CoroutineLayout<'tcx>>,
265269

@@ -279,6 +283,7 @@ impl<'tcx> CoroutineInfo<'tcx> {
279283
coroutine_kind,
280284
yield_ty: Some(yield_ty),
281285
resume_ty: Some(resume_ty),
286+
by_move_body: None,
282287
coroutine_drop: None,
283288
coroutine_layout: None,
284289
}

compiler/rustc_middle/src/mir/mono.rs

+1
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ impl<'tcx> CodegenUnit<'tcx> {
403403
| InstanceDef::Virtual(..)
404404
| InstanceDef::ClosureOnceShim { .. }
405405
| InstanceDef::ConstructCoroutineInClosureShim { .. }
406+
| InstanceDef::CoroutineByMoveShim { .. }
406407
| InstanceDef::DropGlue(..)
407408
| InstanceDef::CloneShim(..)
408409
| InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/mir/visit.rs

+1
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,7 @@ macro_rules! make_mir_visitor {
346346
ty::InstanceDef::ThreadLocalShim(_def_id) |
347347
ty::InstanceDef::ClosureOnceShim { call_once: _def_id, track_caller: _ } |
348348
ty::InstanceDef::ConstructCoroutineInClosureShim { coroutine_closure_def_id: _def_id, target_kind: _ } |
349+
ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: _def_id } |
349350
ty::InstanceDef::DropGlue(_def_id, None) => {}
350351

351352
ty::InstanceDef::FnPtrShim(_def_id, ty) |

compiler/rustc_middle/src/ty/instance.rs

+20-1
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ pub enum InstanceDef<'tcx> {
101101
target_kind: ty::ClosureKind,
102102
},
103103

104+
/// TODO:
105+
CoroutineByMoveShim { coroutine_def_id: DefId },
106+
104107
/// Compiler-generated accessor for thread locals which returns a reference to the thread local
105108
/// the `DefId` defines. This is used to export thread locals from dylibs on platforms lacking
106109
/// native support.
@@ -186,6 +189,7 @@ impl<'tcx> InstanceDef<'tcx> {
186189
coroutine_closure_def_id: def_id,
187190
target_kind: _,
188191
}
192+
| ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id: def_id }
189193
| InstanceDef::DropGlue(def_id, _)
190194
| InstanceDef::CloneShim(def_id, _)
191195
| InstanceDef::FnPtrAddrShim(def_id, _) => def_id,
@@ -206,6 +210,7 @@ impl<'tcx> InstanceDef<'tcx> {
206210
| InstanceDef::Intrinsic(..)
207211
| InstanceDef::ClosureOnceShim { .. }
208212
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
213+
| ty::InstanceDef::CoroutineByMoveShim { .. }
209214
| InstanceDef::DropGlue(..)
210215
| InstanceDef::CloneShim(..)
211216
| InstanceDef::FnPtrAddrShim(..) => None,
@@ -302,6 +307,7 @@ impl<'tcx> InstanceDef<'tcx> {
302307
| InstanceDef::DropGlue(_, Some(_)) => false,
303308
InstanceDef::ClosureOnceShim { .. }
304309
| InstanceDef::ConstructCoroutineInClosureShim { .. }
310+
| InstanceDef::CoroutineByMoveShim { .. }
305311
| InstanceDef::DropGlue(..)
306312
| InstanceDef::Item(_)
307313
| InstanceDef::Intrinsic(..)
@@ -340,6 +346,7 @@ fn fmt_instance(
340346
InstanceDef::FnPtrShim(_, ty) => write!(f, " - shim({ty})"),
341347
InstanceDef::ClosureOnceShim { .. } => write!(f, " - shim"),
342348
InstanceDef::ConstructCoroutineInClosureShim { .. } => write!(f, " - shim"),
349+
InstanceDef::CoroutineByMoveShim { .. } => write!(f, " - shim"),
343350
InstanceDef::DropGlue(_, None) => write!(f, " - shim(None)"),
344351
InstanceDef::DropGlue(_, Some(ty)) => write!(f, " - shim(Some({ty}))"),
345352
InstanceDef::CloneShim(_, ty) => write!(f, " - shim({ty})"),
@@ -631,7 +638,19 @@ impl<'tcx> Instance<'tcx> {
631638
};
632639

633640
if tcx.lang_items().get(coroutine_callable_item) == Some(trait_item_id) {
634-
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args: args })
641+
let ty::Coroutine(_, id_args) = *tcx.type_of(coroutine_def_id).skip_binder().kind()
642+
else {
643+
bug!()
644+
};
645+
646+
if args.as_coroutine().kind_ty() == id_args.as_coroutine().kind_ty() {
647+
Some(Instance { def: ty::InstanceDef::Item(coroutine_def_id), args })
648+
} else {
649+
Some(Instance {
650+
def: ty::InstanceDef::CoroutineByMoveShim { coroutine_def_id },
651+
args,
652+
})
653+
}
635654
} else {
636655
// All other methods should be defaulted methods of the built-in trait.
637656
// This is important for `Iterator`'s combinators, but also useful for

compiler/rustc_middle/src/ty/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2355,6 +2355,7 @@ impl<'tcx> TyCtxt<'tcx> {
23552355
| ty::InstanceDef::Virtual(..)
23562356
| ty::InstanceDef::ClosureOnceShim { .. }
23572357
| ty::InstanceDef::ConstructCoroutineInClosureShim { .. }
2358+
| ty::InstanceDef::CoroutineByMoveShim { .. }
23582359
| ty::InstanceDef::DropGlue(..)
23592360
| ty::InstanceDef::CloneShim(..)
23602361
| ty::InstanceDef::ThreadLocalShim(..)

compiler/rustc_middle/src/ty/sty.rs

+35-12
Original file line numberDiff line numberDiff line change
@@ -475,13 +475,15 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
475475
self,
476476
tcx: TyCtxt<'tcx>,
477477
parent_args: &'tcx [GenericArg<'tcx>],
478+
kind_ty: Ty<'tcx>,
478479
coroutine_def_id: DefId,
479480
tupled_upvars_ty: Ty<'tcx>,
480481
) -> Ty<'tcx> {
481482
let coroutine_args = ty::CoroutineArgs::new(
482483
tcx,
483484
ty::CoroutineArgsParts {
484485
parent_args,
486+
kind_ty,
485487
resume_ty: self.resume_ty,
486488
yield_ty: self.yield_ty,
487489
return_ty: self.return_ty,
@@ -512,7 +514,13 @@ impl<'tcx> CoroutineClosureSignature<'tcx> {
512514
env_region,
513515
);
514516

515-
self.to_coroutine(tcx, parent_args, coroutine_def_id, tupled_upvars_ty)
517+
self.to_coroutine(
518+
tcx,
519+
parent_args,
520+
Ty::from_closure_kind(tcx, closure_kind),
521+
coroutine_def_id,
522+
tupled_upvars_ty,
523+
)
516524
}
517525

518526
/// Given a closure kind, compute the tupled upvars that the given coroutine would return.
@@ -564,6 +572,8 @@ pub struct CoroutineArgs<'tcx> {
564572
pub struct CoroutineArgsParts<'tcx> {
565573
/// This is the args of the typeck root.
566574
pub parent_args: &'tcx [GenericArg<'tcx>],
575+
// TODO: why
576+
pub kind_ty: Ty<'tcx>,
567577
pub resume_ty: Ty<'tcx>,
568578
pub yield_ty: Ty<'tcx>,
569579
pub return_ty: Ty<'tcx>,
@@ -582,6 +592,7 @@ impl<'tcx> CoroutineArgs<'tcx> {
582592
pub fn new(tcx: TyCtxt<'tcx>, parts: CoroutineArgsParts<'tcx>) -> CoroutineArgs<'tcx> {
583593
CoroutineArgs {
584594
args: tcx.mk_args_from_iter(parts.parent_args.iter().copied().chain([
595+
parts.kind_ty.into(),
585596
parts.resume_ty.into(),
586597
parts.yield_ty.into(),
587598
parts.return_ty.into(),
@@ -595,16 +606,23 @@ impl<'tcx> CoroutineArgs<'tcx> {
595606
/// The ordering assumed here must match that used by `CoroutineArgs::new` above.
596607
fn split(self) -> CoroutineArgsParts<'tcx> {
597608
match self.args[..] {
598-
[ref parent_args @ .., resume_ty, yield_ty, return_ty, witness, tupled_upvars_ty] => {
599-
CoroutineArgsParts {
600-
parent_args,
601-
resume_ty: resume_ty.expect_ty(),
602-
yield_ty: yield_ty.expect_ty(),
603-
return_ty: return_ty.expect_ty(),
604-
witness: witness.expect_ty(),
605-
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
606-
}
607-
}
609+
[
610+
ref parent_args @ ..,
611+
kind_ty,
612+
resume_ty,
613+
yield_ty,
614+
return_ty,
615+
witness,
616+
tupled_upvars_ty,
617+
] => CoroutineArgsParts {
618+
parent_args,
619+
kind_ty: kind_ty.expect_ty(),
620+
resume_ty: resume_ty.expect_ty(),
621+
yield_ty: yield_ty.expect_ty(),
622+
return_ty: return_ty.expect_ty(),
623+
witness: witness.expect_ty(),
624+
tupled_upvars_ty: tupled_upvars_ty.expect_ty(),
625+
},
608626
_ => bug!("coroutine args missing synthetics"),
609627
}
610628
}
@@ -614,6 +632,11 @@ impl<'tcx> CoroutineArgs<'tcx> {
614632
self.split().parent_args
615633
}
616634

635+
// TODO:
636+
pub fn kind_ty(self) -> Ty<'tcx> {
637+
self.split().kind_ty
638+
}
639+
617640
/// This describes the types that can be contained in a coroutine.
618641
/// It will be a type variable initially and unified in the last stages of typeck of a body.
619642
/// It contains a tuple of all the types that could end up on a coroutine frame.
@@ -2381,7 +2404,7 @@ impl<'tcx> Ty<'tcx> {
23812404
) -> Ty<'tcx> {
23822405
debug_assert_eq!(
23832406
coroutine_args.len(),
2384-
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 5,
2407+
tcx.generics_of(tcx.typeck_root_def_id(def_id)).count() + 6,
23852408
"coroutine constructed with incorrect number of substitutions"
23862409
);
23872410
Ty::new(tcx, Coroutine(def_id, coroutine_args))

compiler/rustc_mir_transform/src/coroutine.rs

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@
5050
//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
5151
//! Otherwise it drops all the values in scope at the last suspension point.
5252
53+
mod by_move_body;
54+
pub use by_move_body::ByMoveBody;
55+
5356
use crate::abort_unwinding_calls;
5457
use crate::deref_separator::deref_finder;
5558
use crate::errors;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use rustc_data_structures::fx::FxIndexSet;
2+
use rustc_hir as hir;
3+
use rustc_middle::mir::visit::MutVisitor;
4+
use rustc_middle::mir::{self, MirPass};
5+
use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt};
6+
use rustc_target::abi::FieldIdx;
7+
8+
pub struct ByMoveBody;
9+
10+
impl<'tcx> MirPass<'tcx> for ByMoveBody {
11+
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) {
12+
let Some(coroutine_def_id) = body.source.def_id().as_local() else {
13+
return;
14+
};
15+
let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
16+
tcx.coroutine_kind(coroutine_def_id)
17+
else {
18+
return;
19+
};
20+
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
21+
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() };
22+
if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce {
23+
return;
24+
}
25+
26+
let mut by_ref_fields = FxIndexSet::default();
27+
let by_move_upvars = Ty::new_tup_from_iter(
28+
tcx,
29+
tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| {
30+
if capture.is_by_ref() {
31+
by_ref_fields.insert(FieldIdx::from_usize(idx));
32+
}
33+
capture.place.ty()
34+
}),
35+
);
36+
let by_move_coroutine_ty = Ty::new_coroutine(
37+
tcx,
38+
coroutine_def_id.to_def_id(),
39+
ty::CoroutineArgs::new(
40+
tcx,
41+
ty::CoroutineArgsParts {
42+
parent_args: args.as_coroutine().parent_args(),
43+
kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce),
44+
resume_ty: args.as_coroutine().resume_ty(),
45+
yield_ty: args.as_coroutine().yield_ty(),
46+
return_ty: args.as_coroutine().return_ty(),
47+
witness: args.as_coroutine().witness(),
48+
tupled_upvars_ty: by_move_upvars,
49+
},
50+
)
51+
.args,
52+
);
53+
54+
let mut by_move_body = body.clone();
55+
MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body);
56+
by_move_body.source = mir::MirSource {
57+
instance: InstanceDef::CoroutineByMoveShim {
58+
coroutine_def_id: coroutine_def_id.to_def_id(),
59+
},
60+
promoted: None,
61+
};
62+
63+
body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body);
64+
}
65+
}
66+
67+
struct MakeByMoveBody<'tcx> {
68+
tcx: TyCtxt<'tcx>,
69+
by_ref_fields: FxIndexSet<FieldIdx>,
70+
by_move_coroutine_ty: Ty<'tcx>,
71+
}
72+
73+
impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
74+
fn tcx(&self) -> TyCtxt<'tcx> {
75+
self.tcx
76+
}
77+
78+
fn visit_place(
79+
&mut self,
80+
place: &mut mir::Place<'tcx>,
81+
context: mir::visit::PlaceContext,
82+
location: mir::Location,
83+
) {
84+
if place.local == ty::CAPTURE_STRUCT_LOCAL
85+
&& !place.projection.is_empty()
86+
&& let mir::ProjectionElem::Field(idx, ty) = place.projection[0]
87+
&& self.by_ref_fields.contains(&idx)
88+
{
89+
let (begin, end) = place.projection[1..].split_first().unwrap();
90+
assert_eq!(*begin, mir::ProjectionElem::Deref);
91+
*place = mir::Place {
92+
local: place.local,
93+
projection: self.tcx.mk_place_elems_from_iter(
94+
[mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)]
95+
.into_iter()
96+
.chain(end.iter().copied()),
97+
),
98+
};
99+
}
100+
self.super_place(place, context, location);
101+
}
102+
103+
fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
104+
if local == ty::CAPTURE_STRUCT_LOCAL {
105+
local_decl.ty = self.by_move_coroutine_ty;
106+
}
107+
}
108+
}

compiler/rustc_mir_transform/src/inline.rs

+1
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ impl<'tcx> Inliner<'tcx> {
318318
| InstanceDef::FnPtrShim(..)
319319
| InstanceDef::ClosureOnceShim { .. }
320320
| InstanceDef::ConstructCoroutineInClosureShim { .. }
321+
| InstanceDef::CoroutineByMoveShim { .. }
321322
| InstanceDef::DropGlue(..)
322323
| InstanceDef::CloneShim(..)
323324
| InstanceDef::ThreadLocalShim(..)

0 commit comments

Comments
 (0)