diff --git a/src/librustc_mir/transform/validate.rs b/src/librustc_mir/transform/validate.rs index 8150c328316cb..0145d132ac79a 100644 --- a/src/librustc_mir/transform/validate.rs +++ b/src/librustc_mir/transform/validate.rs @@ -1,15 +1,20 @@ //! Validates the MIR to ensure that invariants are upheld. use super::{MirPass, MirSource}; +use rustc_hir::lang_items::FnOnceTraitLangItem; +use rustc_hir::Constness; +use rustc_infer::infer::TyCtxtInferExt; use rustc_middle::mir::visit::Visitor; use rustc_middle::{ mir::{ BasicBlock, Body, Location, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }, - ty::{self, ParamEnv, TyCtxt}, + ty::{self, ParamEnv, ToPredicate, Ty, TyCtxt}, }; use rustc_span::def_id::DefId; +use rustc_trait_selection::traits; +use rustc_trait_selection::traits::query::evaluate_obligation::InferCtxtExt; #[derive(Copy, Clone, Debug)] enum EdgeKind { @@ -26,12 +31,16 @@ impl<'tcx> MirPass<'tcx> for Validator { fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut Body<'tcx>) { let def_id = source.def_id(); let param_env = tcx.param_env(def_id); - TypeChecker { when: &self.when, def_id, body, tcx, param_env }.visit_body(body); + let validating_shim = + if let ty::InstanceDef::Item(_) = source.instance { false } else { true }; + TypeChecker { when: &self.when, def_id, body, tcx, param_env, validating_shim } + .visit_body(body); } } struct TypeChecker<'a, 'tcx> { when: &'a str, + validating_shim: bool, def_id: DefId, body: &'a Body<'tcx>, tcx: TyCtxt<'tcx>, @@ -83,6 +92,66 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> { self.fail(location, format!("encountered jump to invalid basic block {:?}", bb)) } } + + fn check_ty_callable(&self, location: Location, ty: Ty<'tcx>) { + if let ty::FnPtr(..) | ty::FnDef(..) = ty.kind { + // We have a `FnPtr` or `FnDef` which is trivially safe to call. + // + // By this point, calls to closures should already have been lowered to calls to + // `Fn*::call*` so we do not consider them callable. + } else if self.validating_shim && ty == self.tcx.types.self_param { + // FIXME(#69925): we shouldn't be special-casing for call-shims as we'd hope they + // have concrete substs by this point. + // + // We haven't got a `FnPtr` or `FnDef` but if we're looking at a MIR shim, this could + // be due to a `Self` type still hanging about. To avoid rejecting these shims we + // any type in MIR shims as callable so long as: + // 1. it's `Self` + // 2. it implements `FnOnce` + let fn_once_trait = self.tcx.require_lang_item(FnOnceTraitLangItem, None); + let item_def_id = self + .tcx + .associated_items(fn_once_trait) + .in_definition_order() + .next() + .unwrap() + .def_id; + self.tcx.infer_ctxt().enter(|infcx| { + let trait_ref = ty::TraitRef { + def_id: fn_once_trait, + substs: self.tcx.mk_substs_trait( + ty, + infcx.fresh_substs_for_item(self.body.span, item_def_id), + ), + }; + let predicate = ty::PredicateKind::Trait( + ty::Binder::bind(ty::TraitPredicate { trait_ref }), + Constness::NotConst, + ) + .to_predicate(self.tcx); + let obligation = traits::Obligation::new( + traits::ObligationCause::dummy(), + self.param_env, + predicate, + ); + if !infcx.predicate_may_hold(&obligation) { + self.fail( + location, + format!( + "encountered {} in `Call` terminator of shim \ + which does not implement `FnOnce`", + ty, + ), + ); + } + }); + } else { + self.fail( + location, + format!("encountered non-callable type {} in `Call` terminator", ty), + ); + } + } } impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { @@ -151,14 +220,7 @@ impl<'a, 'tcx> Visitor<'tcx> for TypeChecker<'a, 'tcx> { } } TerminatorKind::Call { func, destination, cleanup, .. } => { - let func_ty = func.ty(&self.body.local_decls, self.tcx); - match func_ty.kind { - ty::FnPtr(..) | ty::FnDef(..) => {} - _ => self.fail( - location, - format!("encountered non-callable type {} in `Call` terminator", func_ty), - ), - } + self.check_ty_callable(location, &func.ty(&self.body.local_decls, self.tcx)); if let Some((_, target)) = destination { self.check_edge(location, *target, EdgeKind::Normal); }