diff --git a/crates/flowistry_pdg_construction/src/construct.rs b/crates/flowistry_pdg_construction/src/construct.rs index 29f09d1d71..0d67618f63 100644 --- a/crates/flowistry_pdg_construction/src/construct.rs +++ b/crates/flowistry_pdg_construction/src/construct.rs @@ -160,6 +160,11 @@ impl<'tcx> MemoPdgConstructor<'tcx> { pub fn body_cache(&self) -> &Rc> { &self.body_cache } + + /// Used for testing. + pub fn take_call_changes_policy(&mut self) -> Option + 'tcx>> { + self.call_change_callback.take() + } } type LocalAnalysisResults<'tcx, 'mir> = Results<'tcx, &'mir LocalAnalysis<'tcx, 'mir>>; diff --git a/crates/flowistry_pdg_construction/tests/pdg.rs b/crates/flowistry_pdg_construction/tests/pdg.rs index 7ce7ddc908..dc74854a10 100644 --- a/crates/flowistry_pdg_construction/tests/pdg.rs +++ b/crates/flowistry_pdg_construction/tests/pdg.rs @@ -5,14 +5,14 @@ extern crate rustc_hir; extern crate rustc_middle; extern crate rustc_span; -use std::collections::HashSet; +use std::{collections::HashSet, rc::Rc}; use either::Either; use flowistry::mir::FlowistryInput; use flowistry_pdg_construction::{ body_cache::{dump_mir_and_borrowck_facts, BodyCache}, graph::{DepEdge, DepGraph}, - CallChangeCallbackFn, CallChanges, MemoPdgConstructor, SkipCall, + CallChangeCallback, CallChangeCallbackFn, CallChanges, MemoPdgConstructor, SkipCall, }; use itertools::Itertools; use rustc_hir::def_id::LocalDefId; @@ -34,6 +34,22 @@ fn get_main(tcx: TyCtxt<'_>) -> LocalDefId { .expect("Missing main") } +struct LocalLoadingOnly<'tcx>(Option + 'tcx>>); + +impl<'tcx> CallChangeCallback<'tcx> for LocalLoadingOnly<'tcx> { + fn on_inline(&self, info: flowistry_pdg_construction::CallInfo<'tcx>) -> CallChanges { + let is_local = info.callee.def_id().is_local(); + let mut changes = self + .0 + .as_ref() + .map_or_else(CallChanges::default, |cb| cb.on_inline(info)); + if !is_local { + changes = changes.with_skip(SkipCall::Skip); + } + changes + } +} + fn pdg( input: impl Into, configure: impl for<'tcx> FnOnce(TyCtxt<'tcx>, &mut MemoPdgConstructor<'tcx>) + Send, @@ -47,6 +63,8 @@ fn pdg( let def_id = get_main(tcx); let mut memo = MemoPdgConstructor::new(tcx); configure(tcx, &mut memo); + let policy = memo.take_call_changes_policy(); + memo.with_call_change_callback(LocalLoadingOnly(policy)); let pdg = memo.construct_graph(def_id); tests(tcx, memo.body_cache(), pdg) })