diff --git a/crates/paralegal-flow/src/ana/graph_converter.rs b/crates/paralegal-flow/src/ana/graph_converter.rs index 4ef7818f9a..f9d5ee37b2 100644 --- a/crates/paralegal-flow/src/ana/graph_converter.rs +++ b/crates/paralegal-flow/src/ana/graph_converter.rs @@ -100,6 +100,7 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { generator.tcx, def_id, generator.pdg_constructor.body_cache(), + generator.marker_ctx().clone(), ), stats, }) @@ -599,12 +600,12 @@ mod call_string_resolver { utils::{manufacture_substs_for, try_monomorphize, try_resolve_function}, }; use paralegal_spdg::Endpoint; - use rustc_middle::ty::Instance; + use rustc_middle::{mir::TerminatorKind, ty::Instance}; use rustc_utils::cache::Cache; - use crate::{Either, TyCtxt}; + use crate::{Either, MarkerCtx, TyCtxt}; - use super::{map_either, match_async_trait_assign, AsFnAndArgs}; + use super::{func_of_term, map_either, match_async_trait_assign, AsFnAndArgs}; /// Cached resolution of [`CallString`]s to [`FnResolution`]s. /// @@ -615,6 +616,7 @@ mod call_string_resolver { tcx: TyCtxt<'tcx>, entrypoint_is_async: bool, body_cache: &'a BodyCache<'tcx>, + marker_context: MarkerCtx<'tcx>, } impl<'tcx, 'a> CallStringResolver<'tcx, 'a> { @@ -645,12 +647,14 @@ mod call_string_resolver { tcx: TyCtxt<'tcx>, entrypoint: Endpoint, body_cache: &'a BodyCache<'tcx>, + marker_context: MarkerCtx<'tcx>, ) -> Self { Self { cache: Default::default(), tcx, entrypoint_is_async: super::entrypoint_is_async(body_cache, tcx, entrypoint), body_cache, + marker_context, } } @@ -681,7 +685,21 @@ mod call_string_resolver { }, ); let res = match normalized { - Either::Right(term) => term.as_instance_and_args(tcx).unwrap().0, + Either::Right(term) => { + let (def_id, args) = func_of_term(tcx, &term).unwrap(); + let instance = Instance::expect_resolve(tcx, param_env, def_id, args); + if let Some(model) = self.marker_context.has_flow_model(def_id) { + let TerminatorKind::Call { args, .. } = &term.kind else { + unreachable!() + }; + model + .apply(tcx, instance, param_env, args, term.source_info.span) + .unwrap() + .0 + } else { + term.as_instance_and_args(tcx).unwrap().0 + } + } Either::Left(stmt) => { let (def_id, generics) = match_async_trait_assign(&stmt).unwrap(); try_resolve_function(tcx, def_id, param_env, generics).unwrap() diff --git a/crates/paralegal-flow/src/ana/inline_judge.rs b/crates/paralegal-flow/src/ana/inline_judge.rs index 198dc1066e..2d79d3984c 100644 --- a/crates/paralegal-flow/src/ana/inline_judge.rs +++ b/crates/paralegal-flow/src/ana/inline_judge.rs @@ -1,6 +1,6 @@ use std::rc::Rc; -use flowistry_pdg_construction::{body_cache::BodyCache, CallInfo}; +use flowistry_pdg_construction::{body_cache::BodyCache, utils::is_async, CallInfo}; use paralegal_spdg::{utils::write_sep, Identifier}; use rustc_hash::FxHashSet; use rustc_hir::def_id::{CrateNum, DefId, LOCAL_CRATE}; @@ -90,7 +90,14 @@ impl<'tcx> InlineJudge<'tcx> { let marker_target = info.async_parent.unwrap_or(info.callee); let marker_target_def_id = marker_target.def_id(); if let Some(model) = self.marker_ctx().has_flow_model(marker_target_def_id) { - return InlineJudgement::UseFlowModel(model); + // If we're replacing an async function skip the poll call. + // + // I tried to have it replace the poll call only but that didn't seem to work. + return if info.async_parent.is_some() { + InlineJudgement::AbstractViaType + } else { + InlineJudgement::UseFlowModel(model) + }; } let is_marked = self.marker_ctx.is_marked(marker_target_def_id); let judgement = match self.opts.anactrl().inlining_depth() { diff --git a/crates/paralegal-flow/src/ana/mod.rs b/crates/paralegal-flow/src/ana/mod.rs index 30a9a63fa3..9a7f0d22cc 100644 --- a/crates/paralegal-flow/src/ana/mod.rs +++ b/crates/paralegal-flow/src/ana/mod.rs @@ -29,10 +29,10 @@ use petgraph::visit::GraphBase; use rustc_hir::{self as hir, def, def_id::DefId}; use rustc_middle::{ - mir::Location, + mir::{Location, Operand}, ty::{Instance, ParamEnv, TyCtxt}, }; -use rustc_span::{FileNameDisplayPreference, Span as RustSpan, Symbol}; +use rustc_span::{ErrorGuaranteed, FileNameDisplayPreference, Span as RustSpan, Symbol}; mod graph_converter; mod inline_judge; @@ -412,6 +412,90 @@ struct MyCallback<'tcx> { tcx: TyCtxt<'tcx>, } +impl FlowModel { + fn apply<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + function: Instance<'tcx>, + param_env: ParamEnv<'tcx>, + arguments: &[Operand<'tcx>], + at: RustSpan, + ) -> Result<(Instance<'tcx>, CallingConvention<'tcx>), ErrorGuaranteed> { + match self { + FlowModel::SubClosure { generic_name } | FlowModel::SubFuture { generic_name } => { + let name = Symbol::intern(&generic_name); + let generics = tcx.generics_of(function.def_id()); + let Some(param_index) = (0..generics.count()).find(|&idx| { + let param = generics.param_at(idx, tcx); + param.name == name + }) else { + return Err(tcx.sess.span_err( + at, + format!("Function has no parameter named {generic_name}"), + )); + }; + let ty = function.args[param_index].expect_ty(); + let (def_id, args) = + flowistry_pdg_construction::utils::type_as_fn(tcx, ty).unwrap(); + let instance = Instance::resolve(tcx, param_env, def_id, args) + .unwrap() + .unwrap(); + + let expect_indirect = match self { + FlowModel::SubClosure { .. } => { + use rustc_hir::def::DefKind; + match tcx.def_kind(def_id) { + DefKind::Closure => true, + DefKind::Fn => false, + kind => { + return Err(tcx.sess.span_err( + at, + format!("Expected `fn` or `closure` def kind, got {kind:?}"), + )) + } + } + } + FlowModel::SubFuture { .. } => { + assert!(tcx.generator_is_async(def_id)); + true + } + }; + let poll = tcx.lang_items().poll(); + let calling_convention = if expect_indirect { + let clj = match arguments { + [clj] => clj, + [gen, _] + if tcx.def_kind(function.def_id()) == hir::def::DefKind::AssocFn + && tcx.associated_item(function.def_id()).trait_item_def_id + == poll => + { + gen + } + _ => return Err(tcx.sess.span_err( + at, + format!( + "this function ({:?}) should have only one argument but it has {}", + function.def_id(), + arguments.len() + ), + )), + }; + CallingConvention::Indirect { + closure_arg: clj.clone(), + // This is incorrect, but we only support + // non-argument closures at the moment so this + // will never be used. + tupled_arguments: clj.clone(), + } + } else { + CallingConvention::Direct(arguments.into()) + }; + Ok((instance, calling_convention)) + } + } + } +} + impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx> { let changes = CallChanges::default(); @@ -419,45 +503,16 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { let skip = match self.judge.should_inline(&info) { InlineJudgement::AbstractViaType => SkipCall::Skip, InlineJudgement::UseFlowModel(model) => { - // Set in case of errors - assert!(matches!( - model, - FlowModel::SubClosure | FlowModel::SubFuture - )); - if let [clj] = &info.arguments { - let ty = clj.ty(info.caller_body, self.tcx); - let (def_id, args) = - flowistry_pdg_construction::utils::type_as_fn(self.tcx, ty).unwrap(); - let instance = Instance::resolve(self.tcx, info.param_env, def_id, args) - .unwrap() - .unwrap(); - let num_inputs = instance.sig(self.tcx).unwrap().inputs().len(); - match model { - FlowModel::SubClosure => { - use rustc_hir::def::DefKind; - match self.tcx.def_kind(def_id) { - DefKind::Closure => assert_eq!(num_inputs, 1), - DefKind::Fn => assert_eq!(num_inputs, 0), - kind => assert!( - false, - "Expected `fn` or `closure` def kind, got {kind:?}" - ), - } - } - FlowModel::SubFuture => { - assert_eq!(num_inputs, 1); - assert!(self.tcx.generator_is_async(def_id)) - } - }; + if let Ok((instance, calling_convention)) = model.apply( + self.tcx, + info.callee, + info.param_env, + info.arguments, + info.span, + ) { SkipCall::Replace { instance, - calling_convention: CallingConvention::Indirect { - closure_arg: clj.clone(), - // This is incorrect, but we only support - // non-argument closures at the moment so this - // will never be used. - tupled_arguments: clj.clone(), - }, + calling_convention, } } else { SkipCall::Skip diff --git a/crates/paralegal-flow/src/ann/db.rs b/crates/paralegal-flow/src/ann/db.rs index 19a041843b..a610d5f3c9 100644 --- a/crates/paralegal-flow/src/ann/db.rs +++ b/crates/paralegal-flow/src/ann/db.rs @@ -14,7 +14,7 @@ use crate::{ ann::{Annotation, MarkerAnnotation}, args::{Args, FlowModel}, utils::{ - resolve::expect_resolve_string_to_def_id, ty_of_const, FunctionKind, InstanceExt, + func_of_term, resolve::expect_resolve_string_to_def_id, FunctionKind, InstanceExt, IntoDefId, TyExt, }, Either, HashMap, HashSet, @@ -24,7 +24,7 @@ use flowistry_pdg_construction::{ body_cache::{local_or_remote_paths, BodyCache}, determine_async, encoder::ParalegalDecoder, - utils::{is_virtual, try_monomorphize, try_resolve_function, type_as_fn}, + utils::{is_virtual, try_monomorphize, try_resolve_function}, }; use paralegal_spdg::Identifier; @@ -276,14 +276,7 @@ impl<'tcx> MarkerCtx<'tcx> { " Finding reachable markers for terminator {:?}", terminator.kind ); - let mir::TerminatorKind::Call { func, .. } = &terminator.kind else { - return v.into_iter(); - }; - let Some(const_) = func.constant() else { - return v.into_iter(); - }; - let ty = ty_of_const(const_); - let Some((def_id, gargs)) = type_as_fn(self.tcx(), ty) else { + let Some((def_id, gargs)) = func_of_term(self.tcx(), terminator) else { return v.into_iter(); }; let res = if expect_resolve { diff --git a/crates/paralegal-flow/src/args.rs b/crates/paralegal-flow/src/args.rs index fbb3fa841f..ad4b218128 100644 --- a/crates/paralegal-flow/src/args.rs +++ b/crates/paralegal-flow/src/args.rs @@ -553,11 +553,13 @@ pub struct DepConfig { #[derive(serde::Serialize, serde::Deserialize, Debug, Clone)] #[serde(tag = "mode", rename_all = "kebab-case")] pub enum FlowModel { + #[serde(rename_all = "kebab-case")] /// Replaces the result of a call to a higher-order function with a call to /// the input closure. - SubClosure, + SubClosure { generic_name: String }, + #[serde(rename_all = "kebab-case")] /// Replaces the result of a higher-order future by an input future. - SubFuture, + SubFuture { generic_name: String }, } /// Additional configuration for the build process/rustc diff --git a/crates/paralegal-flow/src/utils/mod.rs b/crates/paralegal-flow/src/utils/mod.rs index 99781edd8b..9e41a7d410 100644 --- a/crates/paralegal-flow/src/utils/mod.rs +++ b/crates/paralegal-flow/src/utils/mod.rs @@ -19,8 +19,8 @@ use rustc_hir::{ BodyId, }; use rustc_middle::{ - mir::{self, Constant, Location, Place, ProjectionElem}, - ty::{self, Instance, Ty}, + mir::{self, Constant, Location, Place, ProjectionElem, Terminator}, + ty::{self, GenericArgsRef, Instance, Ty}, }; use rustc_span::{symbol::Ident, Span as RustSpan, Span}; use rustc_target::spec::abi::Abi; @@ -275,6 +275,18 @@ impl FunctionKind { } } +pub fn func_of_term<'tcx>( + tcx: TyCtxt<'tcx>, + terminator: &Terminator<'tcx>, +) -> Option<(DefId, GenericArgsRef<'tcx>)> { + let mir::TerminatorKind::Call { func, .. } = &terminator.kind else { + return None; + }; + let const_ = func.constant()?; + let ty = ty_of_const(const_); + type_as_fn(tcx, ty) +} + /// A simplified version of the argument list that is stored in a /// `TerminatorKind::Call`. /// diff --git a/crates/paralegal-flow/tests/flow-models.rs b/crates/paralegal-flow/tests/flow-models.rs index a6fceaf3f1..17295e1ae3 100644 --- a/crates/paralegal-flow/tests/flow-models.rs +++ b/crates/paralegal-flow/tests/flow-models.rs @@ -60,6 +60,7 @@ fn simple_source_target_flow(graph: CtrlRef<'_>) { } define_test!(block_fn: graph -> { + simple_source_target_flow(graph) }); define_test!(block_closure: graph -> { diff --git a/crates/paralegal-flow/tests/flow-models/Paralegal.toml b/crates/paralegal-flow/tests/flow-models/Paralegal.toml index f3c53481e0..ddf456bc97 100644 --- a/crates/paralegal-flow/tests/flow-models/Paralegal.toml +++ b/crates/paralegal-flow/tests/flow-models/Paralegal.toml @@ -4,9 +4,12 @@ rust-features = ["saturating_int_impl"] [flow-models."std::thread::spawn"] mode = "sub-closure" +generic-name = "F" [flow-models."tokio::spawn"] mode = "sub-future" +generic-name = "F" [flow-models."actix_web::web::block"] mode = "sub-closure" +generic-name = "F"