Skip to content

Commit

Permalink
More flow model fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
JustusAdam committed Aug 30, 2024
1 parent f9cd8ce commit 22abdc2
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 59 deletions.
26 changes: 22 additions & 4 deletions crates/paralegal-flow/src/ana/graph_converter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ impl<'a, 'tcx, C: Extend<DefId>> GraphConverter<'tcx, 'a, C> {
generator.tcx,
def_id,
generator.pdg_constructor.body_cache(),
generator.marker_ctx().clone(),
),
stats,
})
Expand Down Expand Up @@ -599,12 +600,12 @@ mod call_string_resolver {
utils::{manufacture_substs_for, try_monomorphize, try_resolve_function},
};
use paralegal_spdg::Endpoint;
use rustc_middle::ty::Instance;
use rustc_middle::{mir::TerminatorKind, ty::Instance};
use rustc_utils::cache::Cache;

use crate::{Either, TyCtxt};
use crate::{Either, MarkerCtx, TyCtxt};

use super::{map_either, match_async_trait_assign, AsFnAndArgs};
use super::{func_of_term, map_either, match_async_trait_assign, AsFnAndArgs};

/// Cached resolution of [`CallString`]s to [`FnResolution`]s.
///
Expand All @@ -615,6 +616,7 @@ mod call_string_resolver {
tcx: TyCtxt<'tcx>,
entrypoint_is_async: bool,
body_cache: &'a BodyCache<'tcx>,
marker_context: MarkerCtx<'tcx>,
}

impl<'tcx, 'a> CallStringResolver<'tcx, 'a> {
Expand Down Expand Up @@ -645,12 +647,14 @@ mod call_string_resolver {
tcx: TyCtxt<'tcx>,
entrypoint: Endpoint,
body_cache: &'a BodyCache<'tcx>,
marker_context: MarkerCtx<'tcx>,
) -> Self {
Self {
cache: Default::default(),
tcx,
entrypoint_is_async: super::entrypoint_is_async(body_cache, tcx, entrypoint),
body_cache,
marker_context,
}
}

Expand Down Expand Up @@ -681,7 +685,21 @@ mod call_string_resolver {
},
);
let res = match normalized {
Either::Right(term) => term.as_instance_and_args(tcx).unwrap().0,
Either::Right(term) => {
let (def_id, args) = func_of_term(tcx, &term).unwrap();
let instance = Instance::expect_resolve(tcx, param_env, def_id, args);
if let Some(model) = self.marker_context.has_flow_model(def_id) {
let TerminatorKind::Call { args, .. } = &term.kind else {
unreachable!()
};
model
.apply(tcx, instance, param_env, args, term.source_info.span)
.unwrap()
.0
} else {
term.as_instance_and_args(tcx).unwrap().0
}
}
Either::Left(stmt) => {
let (def_id, generics) = match_async_trait_assign(&stmt).unwrap();
try_resolve_function(tcx, def_id, param_env, generics).unwrap()
Expand Down
11 changes: 9 additions & 2 deletions crates/paralegal-flow/src/ana/inline_judge.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::rc::Rc;

use flowistry_pdg_construction::{body_cache::BodyCache, CallInfo};
use flowistry_pdg_construction::{body_cache::BodyCache, utils::is_async, CallInfo};
use paralegal_spdg::{utils::write_sep, Identifier};
use rustc_hash::FxHashSet;
use rustc_hir::def_id::{CrateNum, DefId, LOCAL_CRATE};
Expand Down Expand Up @@ -90,7 +90,14 @@ impl<'tcx> InlineJudge<'tcx> {
let marker_target = info.async_parent.unwrap_or(info.callee);
let marker_target_def_id = marker_target.def_id();
if let Some(model) = self.marker_ctx().has_flow_model(marker_target_def_id) {
return InlineJudgement::UseFlowModel(model);
// If we're replacing an async function skip the poll call.
//
// I tried to have it replace the poll call only but that didn't seem to work.
return if info.async_parent.is_some() {
InlineJudgement::AbstractViaType
} else {
InlineJudgement::UseFlowModel(model)
};
}
let is_marked = self.marker_ctx.is_marked(marker_target_def_id);
let judgement = match self.opts.anactrl().inlining_depth() {
Expand Down
133 changes: 94 additions & 39 deletions crates/paralegal-flow/src/ana/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ use petgraph::visit::GraphBase;

use rustc_hir::{self as hir, def, def_id::DefId};
use rustc_middle::{
mir::Location,
mir::{Location, Operand},
ty::{Instance, ParamEnv, TyCtxt},
};
use rustc_span::{FileNameDisplayPreference, Span as RustSpan, Symbol};
use rustc_span::{ErrorGuaranteed, FileNameDisplayPreference, Span as RustSpan, Symbol};

mod graph_converter;
mod inline_judge;
Expand Down Expand Up @@ -412,52 +412,107 @@ struct MyCallback<'tcx> {
tcx: TyCtxt<'tcx>,
}

impl FlowModel {
fn apply<'tcx>(
&self,
tcx: TyCtxt<'tcx>,
function: Instance<'tcx>,
param_env: ParamEnv<'tcx>,
arguments: &[Operand<'tcx>],
at: RustSpan,
) -> Result<(Instance<'tcx>, CallingConvention<'tcx>), ErrorGuaranteed> {
match self {
FlowModel::SubClosure { generic_name } | FlowModel::SubFuture { generic_name } => {
let name = Symbol::intern(&generic_name);
let generics = tcx.generics_of(function.def_id());
let Some(param_index) = (0..generics.count()).find(|&idx| {
let param = generics.param_at(idx, tcx);
param.name == name
}) else {
return Err(tcx.sess.span_err(
at,
format!("Function has no parameter named {generic_name}"),
));
};
let ty = function.args[param_index].expect_ty();
let (def_id, args) =
flowistry_pdg_construction::utils::type_as_fn(tcx, ty).unwrap();
let instance = Instance::resolve(tcx, param_env, def_id, args)
.unwrap()
.unwrap();

let expect_indirect = match self {
FlowModel::SubClosure { .. } => {
use rustc_hir::def::DefKind;
match tcx.def_kind(def_id) {
DefKind::Closure => true,
DefKind::Fn => false,
kind => {
return Err(tcx.sess.span_err(
at,
format!("Expected `fn` or `closure` def kind, got {kind:?}"),
))
}
}
}
FlowModel::SubFuture { .. } => {
assert!(tcx.generator_is_async(def_id));
true
}
};
let poll = tcx.lang_items().poll();
let calling_convention = if expect_indirect {
let clj = match arguments {
[clj] => clj,
[gen, _]
if tcx.def_kind(function.def_id()) == hir::def::DefKind::AssocFn
&& tcx.associated_item(function.def_id()).trait_item_def_id
== poll =>
{
gen
}
_ => return Err(tcx.sess.span_err(
at,
format!(
"this function ({:?}) should have only one argument but it has {}",
function.def_id(),
arguments.len()
),
)),
};
CallingConvention::Indirect {
closure_arg: clj.clone(),
// This is incorrect, but we only support
// non-argument closures at the moment so this
// will never be used.
tupled_arguments: clj.clone(),
}
} else {
CallingConvention::Direct(arguments.into())
};
Ok((instance, calling_convention))
}
}
}
}

impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> {
fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx> {
let changes = CallChanges::default();

let skip = match self.judge.should_inline(&info) {
InlineJudgement::AbstractViaType => SkipCall::Skip,
InlineJudgement::UseFlowModel(model) => {
// Set in case of errors
assert!(matches!(
model,
FlowModel::SubClosure | FlowModel::SubFuture
));
if let [clj] = &info.arguments {
let ty = clj.ty(info.caller_body, self.tcx);
let (def_id, args) =
flowistry_pdg_construction::utils::type_as_fn(self.tcx, ty).unwrap();
let instance = Instance::resolve(self.tcx, info.param_env, def_id, args)
.unwrap()
.unwrap();
let num_inputs = instance.sig(self.tcx).unwrap().inputs().len();
match model {
FlowModel::SubClosure => {
use rustc_hir::def::DefKind;
match self.tcx.def_kind(def_id) {
DefKind::Closure => assert_eq!(num_inputs, 1),
DefKind::Fn => assert_eq!(num_inputs, 0),
kind => assert!(
false,
"Expected `fn` or `closure` def kind, got {kind:?}"
),
}
}
FlowModel::SubFuture => {
assert_eq!(num_inputs, 1);
assert!(self.tcx.generator_is_async(def_id))
}
};
if let Ok((instance, calling_convention)) = model.apply(
self.tcx,
info.callee,
info.param_env,
info.arguments,
info.span,
) {
SkipCall::Replace {
instance,
calling_convention: CallingConvention::Indirect {
closure_arg: clj.clone(),
// This is incorrect, but we only support
// non-argument closures at the moment so this
// will never be used.
tupled_arguments: clj.clone(),
},
calling_convention,
}
} else {
SkipCall::Skip
Expand Down
13 changes: 3 additions & 10 deletions crates/paralegal-flow/src/ann/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
ann::{Annotation, MarkerAnnotation},
args::{Args, FlowModel},
utils::{
resolve::expect_resolve_string_to_def_id, ty_of_const, FunctionKind, InstanceExt,
func_of_term, resolve::expect_resolve_string_to_def_id, FunctionKind, InstanceExt,
IntoDefId, TyExt,
},
Either, HashMap, HashSet,
Expand All @@ -24,7 +24,7 @@ use flowistry_pdg_construction::{
body_cache::{local_or_remote_paths, BodyCache},
determine_async,
encoder::ParalegalDecoder,
utils::{is_virtual, try_monomorphize, try_resolve_function, type_as_fn},
utils::{is_virtual, try_monomorphize, try_resolve_function},
};
use paralegal_spdg::Identifier;

Expand Down Expand Up @@ -276,14 +276,7 @@ impl<'tcx> MarkerCtx<'tcx> {
" Finding reachable markers for terminator {:?}",
terminator.kind
);
let mir::TerminatorKind::Call { func, .. } = &terminator.kind else {
return v.into_iter();
};
let Some(const_) = func.constant() else {
return v.into_iter();
};
let ty = ty_of_const(const_);
let Some((def_id, gargs)) = type_as_fn(self.tcx(), ty) else {
let Some((def_id, gargs)) = func_of_term(self.tcx(), terminator) else {
return v.into_iter();
};
let res = if expect_resolve {
Expand Down
6 changes: 4 additions & 2 deletions crates/paralegal-flow/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,11 +553,13 @@ pub struct DepConfig {
#[derive(serde::Serialize, serde::Deserialize, Debug, Clone)]
#[serde(tag = "mode", rename_all = "kebab-case")]
pub enum FlowModel {
#[serde(rename_all = "kebab-case")]
/// Replaces the result of a call to a higher-order function with a call to
/// the input closure.
SubClosure,
SubClosure { generic_name: String },
#[serde(rename_all = "kebab-case")]
/// Replaces the result of a higher-order future by an input future.
SubFuture,
SubFuture { generic_name: String },
}

/// Additional configuration for the build process/rustc
Expand Down
16 changes: 14 additions & 2 deletions crates/paralegal-flow/src/utils/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ use rustc_hir::{
BodyId,
};
use rustc_middle::{
mir::{self, Constant, Location, Place, ProjectionElem},
ty::{self, Instance, Ty},
mir::{self, Constant, Location, Place, ProjectionElem, Terminator},
ty::{self, GenericArgsRef, Instance, Ty},
};
use rustc_span::{symbol::Ident, Span as RustSpan, Span};
use rustc_target::spec::abi::Abi;
Expand Down Expand Up @@ -275,6 +275,18 @@ impl FunctionKind {
}
}

pub fn func_of_term<'tcx>(
tcx: TyCtxt<'tcx>,
terminator: &Terminator<'tcx>,
) -> Option<(DefId, GenericArgsRef<'tcx>)> {
let mir::TerminatorKind::Call { func, .. } = &terminator.kind else {
return None;
};
let const_ = func.constant()?;
let ty = ty_of_const(const_);
type_as_fn(tcx, ty)
}

/// A simplified version of the argument list that is stored in a
/// `TerminatorKind::Call`.
///
Expand Down
1 change: 1 addition & 0 deletions crates/paralegal-flow/tests/flow-models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ fn simple_source_target_flow(graph: CtrlRef<'_>) {
}

define_test!(block_fn: graph -> {
simple_source_target_flow(graph)
});

define_test!(block_closure: graph -> {
Expand Down
3 changes: 3 additions & 0 deletions crates/paralegal-flow/tests/flow-models/Paralegal.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,12 @@ rust-features = ["saturating_int_impl"]

[flow-models."std::thread::spawn"]
mode = "sub-closure"
generic-name = "F"

[flow-models."tokio::spawn"]
mode = "sub-future"
generic-name = "F"

[flow-models."actix_web::web::block"]
mode = "sub-closure"
generic-name = "F"

0 comments on commit 22abdc2

Please sign in to comment.