diff --git a/crates/paralegal-flow/src/ana/inline_judge.rs b/crates/paralegal-flow/src/ana/inline_judge.rs index 183ee2d809..30b451205a 100644 --- a/crates/paralegal-flow/src/ana/inline_judge.rs +++ b/crates/paralegal-flow/src/ana/inline_judge.rs @@ -3,9 +3,13 @@ use std::rc::Rc; use flowistry_pdg_construction::{body_cache::BodyCache, CallInfo}; use paralegal_spdg::{utils::write_sep, Identifier}; use rustc_hash::FxHashSet; -use rustc_hir::def_id::{CrateNum, LOCAL_CRATE}; +use rustc_hir::{ + def::DefKind, + def_id::{CrateNum, DefId, LOCAL_CRATE}, +}; use rustc_middle::ty::{ - BoundVariableKind, ClauseKind, ImplPolarity, Instance, ParamEnv, TraitPredicate, + AssocKind, BoundVariableKind, Clause, ClauseKind, ImplPolarity, Instance, ParamEnv, + ProjectionPredicate, TraitPredicate, Ty, }; use rustc_span::{Span, Symbol}; use rustc_type_ir::TyKind; @@ -118,111 +122,164 @@ impl<'tcx> InlineJudge<'tcx> { call_span: Span, emit_err: bool, ) { - let sess = self.tcx().sess; - let predicates = self - .tcx() - .predicates_of(resolved.def_id()) - .instantiate(self.tcx(), resolved.args); - for (clause, span) in &predicates { - let err = move |s: &str| { - let msg = format!("Cannot verify that non-inlined function is safe due to: {s}"); - if emit_err { - let mut diagnostic = sess.struct_span_err(span, msg); - diagnostic.span_note(call_span, "Called from here"); - diagnostic.emit(); - } else { - let mut diagnostic = sess.struct_span_warn(span, msg); - diagnostic.span_note(call_span, "Called from here"); - diagnostic.emit(); - } - }; - let err_markers = |s: &str, markers: &[Identifier]| { - if !markers.is_empty() { - err(&format!( - "{s}: found marker(s) {}", - Print(|fmt| write_sep(fmt, ", ", markers, |elem, fmt| write!( - fmt, - "'{elem}'" - ))) - )); - } - }; - let kind = clause.kind(); - for bound in kind.bound_vars() { - match bound { - BoundVariableKind::Ty(t) => err(&format!("bound type {t:?}")), - BoundVariableKind::Const | BoundVariableKind::Region(_) => (), - } + SafetyChecker { + tcx: self.tcx(), + emit_err, + param_env, + resolved, + call_span, + marker_ctx: self.marker_ctx.clone(), + } + .check() + } +} + +struct SafetyChecker<'tcx> { + tcx: TyCtxt<'tcx>, + emit_err: bool, + param_env: ParamEnv<'tcx>, + resolved: Instance<'tcx>, + call_span: Span, + marker_ctx: MarkerCtx<'tcx>, +} + +impl<'tcx> SafetyChecker<'tcx> { + fn err(&self, s: &str, span: Span) { + let sess = self.tcx.sess; + let msg = format!("Cannot verify that non-inlined function is safe due to: {s}"); + if self.emit_err { + let mut diagnostic = sess.struct_span_err(span, msg); + diagnostic.span_note(self.call_span, "Called from here"); + diagnostic.emit(); + } else { + let mut diagnostic = sess.struct_span_warn(span, msg); + diagnostic.span_note(self.call_span, "Called from here"); + diagnostic.emit(); + } + } + + fn err_markers(&self, s: &str, markers: &[Identifier], span: Span) { + if !markers.is_empty() { + self.err( + &format!( + "{s}: found marker(s) {}", + Print(|fmt| write_sep(fmt, ", ", markers, |elem, fmt| write!(fmt, "'{elem}'"))) + ), + span, + ); + } + } + + fn check_projection_predicate(&self, predicate: &ProjectionPredicate<'tcx>, span: Span) { + if let Some(t) = predicate.term.ty() { + let t = self.tcx.normalize_erasing_regions(self.param_env, t); + let markers = self.marker_ctx.deep_type_markers(t); + if !markers.is_empty() { + let markers = markers.iter().map(|t| t.1).collect::>(); + self.err_markers( + &format!("type {t:?} is not approximation safe"), + &markers, + span, + ); } + } + } - match kind.skip_binder() { - ClauseKind::TypeOutlives(_) - | ClauseKind::WellFormed(_) - | ClauseKind::ConstArgHasType(..) - | ClauseKind::ConstEvaluatable(_) - | ClauseKind::RegionOutlives(_) => { - // These predicates do not allow for "code injection" since they do not concern things that can be marked. - } - ClauseKind::Projection(p) => { - // ProjectionPredicate(AliasTy { args: [Iter], def_id: IntoIterator::Item }, Term::Ty( as std::iter::Iterator>::Item)) - println!("Found projection {p:?}"); - if let Some(t) = p.term.ty() { - let t = self.tcx.normalize_erasing_regions(param_env, t); - let markers = self.marker_ctx().deep_type_markers(t); - if !markers.is_empty() { - let markers = markers.iter().map(|t| t.1).collect::>(); - err_markers(&format!("type {t:?} is not approximation safe"), &markers); + fn check_trait_predicate(&self, predicate: &TraitPredicate<'tcx>, span: Span) { + match predicate { + TraitPredicate { + polarity: ImplPolarity::Positive, + trait_ref, + } if !self.tcx.trait_is_auto(trait_ref.def_id) => { + let ref_1 = trait_ref.args[0]; + let Some(self_ty) = ref_1.as_type() else { + self.err("expected self type to be type, got {ref_1:?}", span); + return; + }; + + if self.tcx.is_fn_trait(trait_ref.def_id) { + let instance = match self_ty.kind() { + TyKind::Closure(id, args) | TyKind::FnDef(id, args) => { + Instance::resolve(self.tcx, ParamEnv::reveal_all(), *id, args) } - } - } - ClauseKind::Trait(TraitPredicate { - polarity: ImplPolarity::Positive, - trait_ref, - }) if !self.tcx().trait_is_auto(trait_ref.def_id) => { - let ref_1 = trait_ref.args[0]; - let Some(self_ty) = ref_1.as_type() else { - err("expected self type to be type, got {ref_1:?}"); - continue; - }; - - if self.tcx().is_fn_trait(trait_ref.def_id) { - let instance = match self_ty.kind() { - TyKind::Closure(id, args) | TyKind::FnDef(id, args) => { - Instance::resolve(self.tcx(), ParamEnv::reveal_all(), *id, args) - } - _ => { - err(&format!( + _ => { + self.err( + &format!( "fn-trait instance for {self_ty:?} not being a function of closure" - )); - continue; - } - } - .unwrap() - .unwrap(); - let markers = self.marker_ctx().get_reachable_markers(instance); - if !markers.is_empty() { - err_markers( - &format!("closure {instance:?} is not approximation safe"), - markers, + ), + span, ); + return; } - } else { - self.tcx() - .for_each_relevant_impl(trait_ref.def_id, self_ty, |impl_| { - for method in self.tcx().associated_item_def_ids(impl_) { - let markers = self.marker_ctx().get_reachable_markers(*method); - if !markers.is_empty() { - err_markers( - &format!("impl {impl_:?} for {self_ty:?}"), - markers, - ) - } - } - }) + } + .unwrap() + .unwrap(); + let markers = self.marker_ctx.get_reachable_markers(instance); + if !markers.is_empty() { + self.err_markers( + &format!("closure {instance:?} is not approximation safe"), + markers, + span, + ); + } + } else { + self.tcx + .for_each_relevant_impl(trait_ref.def_id, self_ty, |r#impl| { + self.check_impl(r#impl, self_ty, span) + }) + } + } + _ => (), + } + } + + fn check_impl(&self, r#impl: DefId, self_ty: Ty<'tcx>, span: Span) { + for item in self.tcx.associated_items(r#impl).in_definition_order() { + // NOTE: We don't need to check markers on types here, because they + // will be checked if there is a method that produces (or consumes) + // this type. + match item.kind { + AssocKind::Fn => { + let method = item.def_id; + let markers = self.marker_ctx.get_reachable_markers(method); + if !markers.is_empty() { + self.err_markers(&self.tcx.def_path_str(method), markers, span) } } - _ => (), + AssocKind::Const | AssocKind::Type => (), } } } + + fn check_predicate(&self, clause: Clause<'tcx>, span: Span) { + let kind = clause.kind(); + for bound in kind.bound_vars() { + match bound { + BoundVariableKind::Ty(t) => self.err(&format!("bound type {t:?}"), span), + BoundVariableKind::Const | BoundVariableKind::Region(_) => (), + } + } + + match &kind.skip_binder() { + ClauseKind::TypeOutlives(_) + | ClauseKind::WellFormed(_) + | ClauseKind::ConstArgHasType(..) + | ClauseKind::ConstEvaluatable(_) + | ClauseKind::RegionOutlives(_) => { + // These predicates do not allow for "code injection" since they do not concern things that can be marked. + } + ClauseKind::Projection(predicate) => self.check_projection_predicate(predicate, span), + ClauseKind::Trait(predicate) => self.check_trait_predicate(predicate, span), + } + } + + fn check(&self) { + let predicates = self + .tcx + .predicates_of(self.resolved.def_id()) + .instantiate(self.tcx, self.resolved.args); + for (clause, span) in &predicates { + self.check_predicate(clause, span) + } + } } diff --git a/crates/paralegal-flow/src/ana/mod.rs b/crates/paralegal-flow/src/ana/mod.rs index 55a8411cc7..146188933d 100644 --- a/crates/paralegal-flow/src/ana/mod.rs +++ b/crates/paralegal-flow/src/ana/mod.rs @@ -420,8 +420,6 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx> { let changes = CallChanges::default(); - let mut skip = SkipCall::Skip; - let skip = match self.judge.should_inline(&info) { InlineJudgement::NoInline => SkipCall::Skip, InlineJudgement::UseFlowModel(model) => {