|
| 1 | +use rustc_data_structures::fx::FxIndexSet; |
| 2 | +use rustc_hir as hir; |
| 3 | +use rustc_middle::mir::visit::MutVisitor; |
| 4 | +use rustc_middle::mir::{self, MirPass}; |
| 5 | +use rustc_middle::ty::{self, InstanceDef, Ty, TyCtxt}; |
| 6 | +use rustc_target::abi::FieldIdx; |
| 7 | + |
| 8 | +pub struct ByMoveBody; |
| 9 | + |
| 10 | +impl<'tcx> MirPass<'tcx> for ByMoveBody { |
| 11 | + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut mir::Body<'tcx>) { |
| 12 | + let Some(coroutine_def_id) = body.source.def_id().as_local() else { |
| 13 | + return; |
| 14 | + }; |
| 15 | + let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) = |
| 16 | + tcx.coroutine_kind(coroutine_def_id) |
| 17 | + else { |
| 18 | + return; |
| 19 | + }; |
| 20 | + let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty; |
| 21 | + let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!() }; |
| 22 | + if args.as_coroutine().kind_ty().to_opt_closure_kind().unwrap() == ty::ClosureKind::FnOnce { |
| 23 | + return; |
| 24 | + } |
| 25 | + |
| 26 | + let mut by_ref_fields = FxIndexSet::default(); |
| 27 | + let by_move_upvars = Ty::new_tup_from_iter( |
| 28 | + tcx, |
| 29 | + tcx.closure_captures(coroutine_def_id).iter().enumerate().map(|(idx, capture)| { |
| 30 | + if capture.is_by_ref() { |
| 31 | + by_ref_fields.insert(FieldIdx::from_usize(idx)); |
| 32 | + } |
| 33 | + capture.place.ty() |
| 34 | + }), |
| 35 | + ); |
| 36 | + let by_move_coroutine_ty = Ty::new_coroutine( |
| 37 | + tcx, |
| 38 | + coroutine_def_id.to_def_id(), |
| 39 | + ty::CoroutineArgs::new( |
| 40 | + tcx, |
| 41 | + ty::CoroutineArgsParts { |
| 42 | + parent_args: args.as_coroutine().parent_args(), |
| 43 | + kind_ty: Ty::from_closure_kind(tcx, ty::ClosureKind::FnOnce), |
| 44 | + resume_ty: args.as_coroutine().resume_ty(), |
| 45 | + yield_ty: args.as_coroutine().yield_ty(), |
| 46 | + return_ty: args.as_coroutine().return_ty(), |
| 47 | + witness: args.as_coroutine().witness(), |
| 48 | + tupled_upvars_ty: by_move_upvars, |
| 49 | + }, |
| 50 | + ) |
| 51 | + .args, |
| 52 | + ); |
| 53 | + |
| 54 | + let mut by_move_body = body.clone(); |
| 55 | + MakeByMoveBody { tcx, by_ref_fields, by_move_coroutine_ty }.visit_body(&mut by_move_body); |
| 56 | + by_move_body.source = mir::MirSource { |
| 57 | + instance: InstanceDef::CoroutineByMoveShim { |
| 58 | + coroutine_def_id: coroutine_def_id.to_def_id(), |
| 59 | + }, |
| 60 | + promoted: None, |
| 61 | + }; |
| 62 | + |
| 63 | + body.coroutine.as_mut().unwrap().by_move_body = Some(by_move_body); |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +struct MakeByMoveBody<'tcx> { |
| 68 | + tcx: TyCtxt<'tcx>, |
| 69 | + by_ref_fields: FxIndexSet<FieldIdx>, |
| 70 | + by_move_coroutine_ty: Ty<'tcx>, |
| 71 | +} |
| 72 | + |
| 73 | +impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> { |
| 74 | + fn tcx(&self) -> TyCtxt<'tcx> { |
| 75 | + self.tcx |
| 76 | + } |
| 77 | + |
| 78 | + fn visit_place( |
| 79 | + &mut self, |
| 80 | + place: &mut mir::Place<'tcx>, |
| 81 | + context: mir::visit::PlaceContext, |
| 82 | + location: mir::Location, |
| 83 | + ) { |
| 84 | + if place.local == ty::CAPTURE_STRUCT_LOCAL |
| 85 | + && !place.projection.is_empty() |
| 86 | + && let mir::ProjectionElem::Field(idx, ty) = place.projection[0] |
| 87 | + && self.by_ref_fields.contains(&idx) |
| 88 | + { |
| 89 | + let (begin, end) = place.projection[1..].split_first().unwrap(); |
| 90 | + assert_eq!(*begin, mir::ProjectionElem::Deref); |
| 91 | + *place = mir::Place { |
| 92 | + local: place.local, |
| 93 | + projection: self.tcx.mk_place_elems_from_iter( |
| 94 | + [mir::ProjectionElem::Field(idx, ty.builtin_deref(true).unwrap().ty)] |
| 95 | + .into_iter() |
| 96 | + .chain(end.iter().copied()), |
| 97 | + ), |
| 98 | + }; |
| 99 | + } |
| 100 | + self.super_place(place, context, location); |
| 101 | + } |
| 102 | + |
| 103 | + fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) { |
| 104 | + if local == ty::CAPTURE_STRUCT_LOCAL { |
| 105 | + local_decl.ty = self.by_move_coroutine_ty; |
| 106 | + } |
| 107 | + } |
| 108 | +} |
0 commit comments