diff --git a/crates/flowistry_pdg_construction/src/callback.rs b/crates/flowistry_pdg_construction/src/callback.rs index 0d2822de43..bcd78b64eb 100644 --- a/crates/flowistry_pdg_construction/src/callback.rs +++ b/crates/flowistry_pdg_construction/src/callback.rs @@ -4,6 +4,7 @@ use flowistry_pdg::{rustc_portable::Location, CallString}; use rustc_hir::def_id::DefId; use rustc_middle::ty::Instance; +use rustc_span::Span; pub trait CallChangeCallback<'tcx> { fn on_inline(&self, info: CallInfo<'tcx>) -> CallChanges; @@ -63,7 +64,7 @@ pub struct CallInfo<'tcx> { /// Would the PDG for this function be served from the cache. pub is_cached: bool, - pub original_called_fn: DefId, + pub span: Span, } /// User-provided changes to the default PDG construction behavior for function calls. diff --git a/crates/flowistry_pdg_construction/src/local_analysis.rs b/crates/flowistry_pdg_construction/src/local_analysis.rs index 542b533f50..7124ae4558 100644 --- a/crates/flowistry_pdg_construction/src/local_analysis.rs +++ b/crates/flowistry_pdg_construction/src/local_analysis.rs @@ -446,7 +446,6 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> { let call_changes = self.call_change_callback().map(|callback| { let info = CallInfo { - original_called_fn: called_def_id, callee: resolved_fn, call_string: self.make_call_string(location), is_cached, @@ -463,6 +462,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> { } else { None }, + span, }; callback.on_inline(info) }); diff --git a/crates/paralegal-flow/src/ana/inline_judge.rs b/crates/paralegal-flow/src/ana/inline_judge.rs index 908615061a..955a224ad7 100644 --- a/crates/paralegal-flow/src/ana/inline_judge.rs +++ b/crates/paralegal-flow/src/ana/inline_judge.rs @@ -43,7 +43,7 @@ impl<'tcx> InlineJudge<'tcx> { .chain(Some(LOCAL_CRATE)) .collect::>(); let marker_ctx = - MarkerDatabase::init(tcx, opts, body_cache, included_crates.iter().copied()).into(); + MarkerDatabase::init(tcx, opts, body_cache, included_crates.clone()).into(); Self { marker_ctx, included_crates, @@ -82,12 +82,11 @@ impl<'tcx> InlineJudge<'tcx> { &self.marker_ctx } - pub fn ensure_is_safe_to_approximate(&self, original_target: DefId, resolved: Instance<'tcx>) { - println!("Ensuring approximation safety for {resolved:?}"); + pub fn ensure_is_safe_to_approximate(&self, resolved: Instance<'tcx>) { let sess = self.tcx().sess; let predicates = self .tcx() - .predicates_of(original_target) + .predicates_of(resolved.def_id()) .instantiate(self.tcx(), resolved.args); for (clause, span) in &predicates { let err = |s: &str| { diff --git a/crates/paralegal-flow/src/ana/mod.rs b/crates/paralegal-flow/src/ana/mod.rs index dbdbe81df2..260491411a 100644 --- a/crates/paralegal-flow/src/ana/mod.rs +++ b/crates/paralegal-flow/src/ana/mod.rs @@ -426,8 +426,8 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { }; if skip { - self.judge - .ensure_is_safe_to_approximate(info.original_called_fn, info.callee); + println!("Ensuring approximate safety of {:?}", info.callee); + self.judge.ensure_is_safe_to_approximate(info.callee); changes = changes.with_skip(SkipCall::Skip); } else { // record_inlining( diff --git a/crates/paralegal-flow/src/ann/db.rs b/crates/paralegal-flow/src/ann/db.rs index 30d1b6bbb2..1753ec0b0c 100644 --- a/crates/paralegal-flow/src/ann/db.rs +++ b/crates/paralegal-flow/src/ann/db.rs @@ -28,7 +28,7 @@ use flowistry_pdg_construction::{ }; use paralegal_spdg::Identifier; -use rustc_hash::FxHashMap; +use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir::def_id::DefId; use rustc_hir::{def::DefKind, def_id::CrateNum}; use rustc_middle::{ @@ -182,6 +182,18 @@ impl<'tcx> MarkerCtx<'tcx> { pub fn get_reachable_markers(&self, res: impl Into>) -> &[Identifier] { let res = res.into(); + let def_id = res.def_id(); + if self.is_marked(def_id) { + trace!(" Is marked"); + return &[]; + } + if is_virtual(self.tcx(), def_id) { + trace!(" Is virtual"); + return &[]; + } + if !self.0.included_crates.contains(&def_id.krate) { + return &[]; + } self.db() .reachable_markers .get_maybe_recursive(res, |_| self.compute_reachable_markers(res)) @@ -219,14 +231,6 @@ impl<'tcx> MarkerCtx<'tcx> { /// computes it. fn compute_reachable_markers(&self, res: MaybeMonomorphized<'tcx>) -> Box<[Identifier]> { trace!("Computing reachable markers for {res:?}"); - if self.is_marked(res.def_id()) { - trace!(" Is marked"); - return Box::new([]); - } - if is_virtual(self.tcx(), res.def_id()) { - trace!(" Is virtual"); - return Box::new([]); - } let Some(body) = self.0.body_cache.get(res.def_id()) else { trace!(" Cannot find body"); return Box::new([]); @@ -543,6 +547,7 @@ pub struct MarkerDatabase<'tcx> { _config: &'static MarkerControl, type_markers: Cache, Box>, body_cache: Rc>, + included_crates: FxHashSet, } impl<'tcx> MarkerDatabase<'tcx> { @@ -551,16 +556,17 @@ impl<'tcx> MarkerDatabase<'tcx> { tcx: TyCtxt<'tcx>, args: &'static Args, body_cache: Rc>, - included_crates: impl IntoIterator, + included_crates: FxHashSet, ) -> Self { Self { tcx, - annotations: load_annotations(tcx, included_crates), + annotations: load_annotations(tcx, included_crates.iter().copied()), external_annotations: resolve_external_markers(args, tcx), reachable_markers: Default::default(), _config: args.marker_control(), type_markers: Default::default(), body_cache, + included_crates, } } }