Skip to content

Commit

Permalink
Implemented basic flow models
Browse files Browse the repository at this point in the history
  • Loading branch information
JustusAdam committed Jul 31, 2024
1 parent 1ba3fd6 commit 681c3b5
Show file tree
Hide file tree
Showing 14 changed files with 314 additions and 81 deletions.
42 changes: 29 additions & 13 deletions crates/flowistry_pdg_construction/src/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

use flowistry_pdg::{rustc_portable::Location, CallString};

use rustc_middle::ty::Instance;
use rustc_middle::{
mir::{self, Operand},
ty::{Instance, ParamEnv},
};
use rustc_span::Span;

use crate::calling_convention::CallingConvention;

pub trait CallChangeCallback<'tcx> {
fn on_inline(&self, info: CallInfo<'tcx>) -> CallChanges;
fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx>;

fn on_inline_miss(
&self,
Expand All @@ -20,17 +25,17 @@ pub trait CallChangeCallback<'tcx> {
}

pub struct CallChangeCallbackFn<'tcx> {
f: Box<dyn Fn(CallInfo<'tcx>) -> CallChanges + 'tcx>,
f: Box<dyn Fn(CallInfo<'tcx, '_>) -> CallChanges<'tcx> + 'tcx>,
}

impl<'tcx> CallChangeCallbackFn<'tcx> {
pub fn new(f: impl Fn(CallInfo<'tcx>) -> CallChanges + 'tcx) -> Self {
pub fn new(f: impl Fn(CallInfo<'tcx, '_>) -> CallChanges<'tcx> + 'tcx) -> Self {
Self { f: Box::new(f) }
}
}

impl<'tcx> CallChangeCallback<'tcx> for CallChangeCallbackFn<'tcx> {
fn on_inline(&self, info: CallInfo<'tcx>) -> CallChanges {
fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx> {
(self.f)(info)
}
}
Expand All @@ -40,7 +45,7 @@ pub enum InlineMissReason {
Async(String),
}

impl Default for CallChanges {
impl<'tcx> Default for CallChanges<'tcx> {
fn default() -> Self {
CallChanges {
skip: SkipCall::NoSkip,
Expand All @@ -49,7 +54,7 @@ impl Default for CallChanges {
}

/// Information about the function being called.
pub struct CallInfo<'tcx> {
pub struct CallInfo<'tcx, 'mir> {
/// The potentially-monomorphized resolution of the callee.
pub callee: Instance<'tcx>,

Expand All @@ -64,29 +69,40 @@ pub struct CallInfo<'tcx> {
pub is_cached: bool,

pub span: Span,

pub arguments: &'mir [Operand<'tcx>],

pub caller_body: &'mir mir::Body<'tcx>,
pub param_env: ParamEnv<'tcx>,
}

/// User-provided changes to the default PDG construction behavior for function calls.
///
/// Construct [`CallChanges`] via [`CallChanges::default`].
#[derive(Debug)]
pub struct CallChanges {
pub(crate) skip: SkipCall,
pub struct CallChanges<'tcx> {
pub(crate) skip: SkipCall<'tcx>,
}

/// Whether or not to skip recursing into a function call during PDG construction.
#[derive(Debug)]
pub enum SkipCall {
pub enum SkipCall<'tcx> {
/// Skip the function, and perform a modular approxmation of its effects.
Skip,

/// Recurse into the function as normal.
NoSkip,

/// Replace with a call to this other function and arguments.
Replace {
instance: Instance<'tcx>,
calling_convention: CallingConvention<'tcx>,
},
}

impl CallChanges {
/// Inidicate whether or not to skip recursing into the given function.
pub fn with_skip(self, skip: SkipCall) -> Self {
impl<'tcx> CallChanges<'tcx> {
/// Indicate whether or not to skip recursing into the given function.
pub fn with_skip(self, skip: SkipCall<'tcx>) -> Self {
CallChanges { skip }
}
}
23 changes: 13 additions & 10 deletions crates/flowistry_pdg_construction/src/calling_convention.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::borrow::Cow;

use flowistry_pdg::rustc_portable::DefId;
use log::trace;
use rustc_abi::FieldIdx;
Expand All @@ -9,26 +11,27 @@ use rustc_middle::{

use crate::{async_support::AsyncInfo, local_analysis::CallKind, utils};

pub enum CallingConvention<'tcx, 'a> {
Direct(&'a [Operand<'tcx>]),
#[derive(Debug)]
pub enum CallingConvention<'tcx> {
Direct(Box<[Operand<'tcx>]>),
Indirect {
closure_arg: &'a Operand<'tcx>,
tupled_arguments: &'a Operand<'tcx>,
closure_arg: Operand<'tcx>,
tupled_arguments: Operand<'tcx>,
},
Async(Place<'tcx>),
}

impl<'tcx, 'a> CallingConvention<'tcx, 'a> {
impl<'tcx> CallingConvention<'tcx> {
pub fn from_call_kind(
kind: &CallKind<'tcx>,
args: &'a [Operand<'tcx>],
) -> CallingConvention<'tcx, 'a> {
args: Cow<'_, [Operand<'tcx>]>,
) -> CallingConvention<'tcx> {
match kind {
CallKind::AsyncPoll(poll) => CallingConvention::Async(poll.generator_data),
CallKind::Direct => CallingConvention::Direct(args),
CallKind::Direct => CallingConvention::Direct(args.into()),
CallKind::Indirect => CallingConvention::Indirect {
closure_arg: &args[0],
tupled_arguments: &args[1],
closure_arg: args[0].clone(),
tupled_arguments: args[1].clone(),
},
}
}
Expand Down
15 changes: 9 additions & 6 deletions crates/flowistry_pdg_construction/src/construct.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::rc::Rc;
use std::{borrow::Cow, rc::Rc};

use either::Either;
use flowistry::mir::FlowistryInput;
Expand Down Expand Up @@ -259,8 +259,8 @@ impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, LocalAnalysisResults<'tcx, 'mir>>
if matches!(
constructor.determine_call_handling(
location,
func,
args,
Cow::Borrowed(func),
Cow::Borrowed(args),
terminator.source_info.span
),
Some(CallHandling::Ready { .. })
Expand Down Expand Up @@ -331,9 +331,12 @@ impl<'tcx> PartialGraph<'tcx> {
function: constructor.def_id,
};

let Some(handling) =
constructor.determine_call_handling(location, func, args, terminator.source_info.span)
else {
let Some(handling) = constructor.determine_call_handling(
location,
Cow::Borrowed(func),
Cow::Borrowed(args),
terminator.source_info.span,
) else {
return false;
};

Expand Down
2 changes: 1 addition & 1 deletion crates/flowistry_pdg_construction/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ use rustc_middle::ty::TyCtxt;
mod approximation;
mod async_support;
pub mod body_cache;
mod calling_convention;
pub mod calling_convention;
mod construct;
pub mod encoder;
pub mod graph;
Expand Down
51 changes: 32 additions & 19 deletions crates/flowistry_pdg_construction/src/local_analysis.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{collections::HashSet, iter, rc::Rc};
use std::{borrow::Cow, collections::HashSet, iter, rc::Rc};

use flowistry::mir::{placeinfo::PlaceInfo, FlowistryInput};
use flowistry_pdg::{CallString, GlobalLocation, RichLocation};
Expand Down Expand Up @@ -352,8 +352,8 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
pub(crate) fn determine_call_handling<'b>(
&'b self,
location: Location,
func: &Operand<'tcx>,
args: &'b [Operand<'tcx>],
func: Cow<'_, Operand<'tcx>>,
args: Cow<'b, [Operand<'tcx>]>,
span: Span,
) -> Option<CallHandling<'tcx, 'b>> {
let tcx = self.tcx();
Expand All @@ -363,7 +363,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
self.tcx().def_path_str(self.def_id)
);

let Some((called_def_id, generic_args)) = self.operand_to_def_id(func) else {
let Some((called_def_id, generic_args)) = self.operand_to_def_id(&func) else {
tcx.sess
.span_err(span, "Operand is cannot be interpreted as function");
return None;
Expand All @@ -372,7 +372,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {

// Monomorphize the called function with the known generic_args.
let param_env = tcx.param_env(self.def_id);
let Some(resolved_fn) =
let Some(mut resolved_fn) =
utils::try_resolve_function(self.tcx(), called_def_id, param_env, generic_args)
else {
let dynamics = generic_args.iter()
Expand Down Expand Up @@ -426,9 +426,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
return Some(CallHandling::ApproxAsyncSM(handler));
};

let call_kind = self.classify_call_kind(called_def_id, resolved_fn, args, span);

let calling_convention = CallingConvention::from_call_kind(&call_kind, args);
let call_kind = self.classify_call_kind(called_def_id, resolved_fn, &args, span);

trace!(
" Handling call! with kind {}",
Expand All @@ -449,7 +447,7 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
callee: resolved_fn,
call_string: self.make_call_string(location),
is_cached,
async_parent: if let CallKind::AsyncPoll(poll) = call_kind {
async_parent: if let CallKind::AsyncPoll(poll) = &call_kind {
// Special case for async. We ask for skipping not on the closure, but
// on the "async" function that created it. This is needed for
// consistency in skipping. Normally, when "poll" is inlined, mutations
Expand All @@ -463,6 +461,9 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
None
},
span,
arguments: &args,
caller_body: &self.mono_body,
param_env,
};
callback.on_inline(info)
});
Expand All @@ -484,16 +485,26 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
.then_some(CallHandling::ApproxAsyncFn);
}

if matches!(
call_changes,
let calling_convention = match call_changes {
Some(CallChanges {
skip: SkipCall::Skip,
..
})
) {
trace!(" Bailing because user callback said to bail");
return None;
}
}) => {
trace!(" Bailing because user callback said to bail");
return None;
}
Some(CallChanges {
skip:
SkipCall::Replace {
instance,
calling_convention,
},
}) => {
trace!(" Replacing call as instructed by user");
resolved_fn = instance;
calling_convention
}
_ => CallingConvention::from_call_kind(&call_kind, args),
};
let Some(descriptor) = self.memo.construct_for(resolved_fn) else {
return None;
};
Expand All @@ -518,7 +529,9 @@ impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> {
// Note: my comments here will use "child" to refer to the callee and
// "parent" to refer to the caller, since the words are most visually distinct.

let Some(preamble) = self.determine_call_handling(location, func, args, span) else {
let Some(preamble) =
self.determine_call_handling(location, Cow::Borrowed(func), Cow::Borrowed(args), span)
else {
return false;
};

Expand Down Expand Up @@ -764,7 +777,7 @@ pub enum CallKind<'tcx> {
pub(crate) enum CallHandling<'tcx, 'a> {
ApproxAsyncFn,
Ready {
calling_convention: CallingConvention<'tcx, 'a>,
calling_convention: CallingConvention<'tcx>,
descriptor: &'a PartialGraph<'tcx>,
},
ApproxAsyncSM(ApproximationHandler<'tcx, 'a>),
Expand Down
49 changes: 38 additions & 11 deletions crates/paralegal-flow/src/ana/inline_judge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::rc::Rc;

use flowistry_pdg_construction::{body_cache::BodyCache, CallInfo};
use paralegal_spdg::{utils::write_sep, Identifier};
use rustc_hash::FxHashSet;
use rustc_hash::{FxHashMap, FxHashSet};
use rustc_hir::def_id::{CrateNum, DefId, LOCAL_CRATE};
use rustc_middle::ty::{
BoundVariableKind, ClauseKind, ImplPolarity, Instance, ParamEnv, TraitPredicate,
Expand All @@ -11,9 +11,14 @@ use rustc_span::Symbol;
use rustc_type_ir::TyKind;

use crate::{
ana::Print, ann::db::MarkerDatabase, args::InliningDepth, AnalysisCtrl, Args, MarkerCtx, TyCtxt,
ana::Print,
ann::db::MarkerDatabase,
args::{FlowModel, InliningDepth},
AnalysisCtrl, Args, MarkerCtx, TyCtxt,
};

use super::resolve::expect_resolve_string_to_def_id;

/// The interpretation of marker placement as it pertains to inlining and inline
/// elision.
///
Expand All @@ -27,6 +32,22 @@ pub struct InlineJudge<'tcx> {
tcx: TyCtxt<'tcx>,
}

pub enum InlineJudgement {
Inline,
UseFlowModel(&'static FlowModel),
NoInline,
}

impl From<bool> for InlineJudgement {
fn from(value: bool) -> Self {
if value {
InlineJudgement::Inline
} else {
InlineJudgement::NoInline
}
}
}

impl<'tcx> InlineJudge<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>, body_cache: Rc<BodyCache<'tcx>>, opts: &'static Args) -> Self {
let included_crate_names = opts
Expand Down Expand Up @@ -61,23 +82,29 @@ impl<'tcx> InlineJudge<'tcx> {
}

/// Should we perform inlining on this function?
pub fn should_inline(&self, info: &CallInfo<'tcx>) -> bool {
pub fn should_inline(&self, info: &CallInfo<'tcx, '_>) -> InlineJudgement {
let marker_target = info.async_parent.unwrap_or(info.callee);
let marker_target_def_id = marker_target.def_id();
if let Some(model) = self.marker_ctx().has_flow_model(marker_target_def_id) {
return InlineJudgement::UseFlowModel(model);
}
let is_marked = self.marker_ctx.is_marked(marker_target_def_id);
let should_inline = match self.analysis_control.inlining_depth() {
_ if !self.included_crates.contains(&marker_target_def_id.krate) || is_marked => false,
let judgement = match self.analysis_control.inlining_depth() {
_ if !self.included_crates.contains(&marker_target_def_id.krate) || is_marked => {
InlineJudgement::NoInline
}
InliningDepth::Adaptive => self
.marker_ctx
.has_transitive_reachable_markers(marker_target),
InliningDepth::Shallow => false,
InliningDepth::Unconstrained => true,
.has_transitive_reachable_markers(marker_target)
.into(),
InliningDepth::Shallow => InlineJudgement::NoInline,
InliningDepth::Unconstrained => InlineJudgement::Inline,
};
if !should_inline {
if !matches!(judgement, InlineJudgement::NoInline) {
//println!("Ensuring approximate safety of {:?}", info.callee);
self.ensure_is_safe_to_approximate(resolved, !is_marked)
self.ensure_is_safe_to_approximate(info.callee, !is_marked)
}
should_inline
judgement
}

pub fn marker_ctx(&self) -> &MarkerCtx<'tcx> {
Expand Down
Loading

0 comments on commit 681c3b5

Please sign in to comment.