From d108a0159d7a4ce7eaf991900f530f60841617e1 Mon Sep 17 00:00:00 2001 From: Justus Adam Date: Fri, 6 Sep 2024 16:36:43 -0700 Subject: [PATCH] Applying stubs in marker discovery --- crates/paralegal-flow/src/ana/mod.rs | 150 +++++++++++++++------------ crates/paralegal-flow/src/ann/db.rs | 46 ++++++-- crates/paralegal-flow/src/lib.rs | 1 + 3 files changed, 121 insertions(+), 76 deletions(-) diff --git a/crates/paralegal-flow/src/ana/mod.rs b/crates/paralegal-flow/src/ana/mod.rs index c623795fde..d6f39db3fb 100644 --- a/crates/paralegal-flow/src/ana/mod.rs +++ b/crates/paralegal-flow/src/ana/mod.rs @@ -413,20 +413,13 @@ struct MyCallback<'tcx> { } impl Stub { - /// Performs the effects of this model on the provided function. - /// - /// `function` is what was to be called but for which a flow model exists, - /// `arguments` are the arguments to that call. - /// - /// Returns a new instance to call instead and how it should be called. - fn apply<'tcx>( + pub fn resolve_alternate_instance<'tcx>( &self, tcx: TyCtxt<'tcx>, function: Instance<'tcx>, param_env: ParamEnv<'tcx>, - arguments: &[Operand<'tcx>], at: RustSpan, - ) -> Result<(Instance<'tcx>, CallingConvention<'tcx>), ErrorGuaranteed> { + ) -> Result, ErrorGuaranteed> { match self { Stub::SubClosure { generic_name } | Stub::SubFuture { generic_name } => { let name = Symbol::intern(generic_name); @@ -443,66 +436,93 @@ impl Stub { let ty = function.args[param_index].expect_ty(); let (def_id, args) = flowistry_pdg_construction::utils::type_as_fn(tcx, ty).unwrap(); - let instance = Instance::resolve(tcx, param_env, def_id, args) + Ok(Instance::resolve(tcx, param_env, def_id, args) .unwrap() - .unwrap(); - - let expect_indirect = match self { - Stub::SubClosure { .. } => { - use rustc_hir::def::DefKind; - match tcx.def_kind(def_id) { - DefKind::Closure => true, - DefKind::Fn => false, - kind => { - return Err(tcx.sess.span_err( - at, - format!("Expected `fn` or `closure` def kind, got {kind:?}"), - )) - } - } - } - Stub::SubFuture { .. } => { - assert!(tcx.generator_is_async(def_id)); - true - } - }; - let poll = tcx.lang_items().poll(); - let calling_convention = if expect_indirect { - let clj = match arguments { - [clj] => clj, - [gen, _] - if tcx.def_kind(function.def_id()) == hir::def::DefKind::AssocFn - && tcx.associated_item(function.def_id()).trait_item_def_id - == poll => - { - gen - } - _ => { - return Err(tcx.sess.span_err( - at, - format!( - "this function ({:?}) should have only one argument but it has {}", - function.def_id(), - arguments.len() - ), - )) - } - }; - CallingConvention::Indirect { - once_shim: false, - closure_arg: clj.clone(), - // This is incorrect, but we only support - // non-argument closures at the moment so this - // will never be used. - tupled_arguments: clj.clone(), - } - } else { - CallingConvention::Direct(arguments.into()) - }; - Ok((instance, calling_convention)) + .unwrap()) } } } + + fn indirect_required( + &self, + tcx: TyCtxt, + def_id: DefId, + at: RustSpan, + ) -> Result { + let bool = match self { + Stub::SubClosure { .. } => { + use rustc_hir::def::DefKind; + match tcx.def_kind(def_id) { + DefKind::Closure => true, + DefKind::Fn => false, + kind => { + return Err(tcx.sess.span_err( + at, + format!("Expected `fn` or `closure` def kind, got {kind:?}"), + )) + } + } + } + Stub::SubFuture { .. } => { + assert!(tcx.generator_is_async(def_id)); + true + } + }; + Ok(bool) + } + + /// Performs the effects of this model on the provided function. + /// + /// `function` is what was to be called but for which a flow model exists, + /// `arguments` are the arguments to that call. + /// + /// Returns a new instance to call instead and how it should be called. + pub fn apply<'tcx>( + &self, + tcx: TyCtxt<'tcx>, + function: Instance<'tcx>, + param_env: ParamEnv<'tcx>, + arguments: &[Operand<'tcx>], + at: RustSpan, + ) -> Result<(Instance<'tcx>, CallingConvention<'tcx>), ErrorGuaranteed> { + let instance = self.resolve_alternate_instance(tcx, function, param_env, at)?; + let def_id = instance.def_id(); + + let expect_indirect = self.indirect_required(tcx, def_id, at)?; + let poll = tcx.lang_items().poll(); + let calling_convention = if expect_indirect { + let clj = match arguments { + [clj] => clj, + [gen, _] + if tcx.def_kind(function.def_id()) == hir::def::DefKind::AssocFn + && tcx.associated_item(function.def_id()).trait_item_def_id == poll => + { + gen + } + _ => { + return Err(tcx.sess.span_err( + at, + format!( + "this function ({:?}) should have only one argument but it has {}", + function.def_id(), + arguments.len() + ), + )) + } + }; + CallingConvention::Indirect { + once_shim: false, + closure_arg: clj.clone(), + // This is incorrect, but we only support + // non-argument closures at the moment so this + // will never be used. + tupled_arguments: clj.clone(), + } + } else { + CallingConvention::Direct(arguments.into()) + }; + Ok((instance, calling_convention)) + } } impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { diff --git a/crates/paralegal-flow/src/ann/db.rs b/crates/paralegal-flow/src/ann/db.rs index 5a9940226e..34d6a235f7 100644 --- a/crates/paralegal-flow/src/ann/db.rs +++ b/crates/paralegal-flow/src/ann/db.rs @@ -28,6 +28,7 @@ use flowistry_pdg_construction::{ }; use paralegal_spdg::Identifier; +use rustc_errors::DiagnosticMessage; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir::def_id::DefId; use rustc_hir::{def::DefKind, def_id::CrateNum}; @@ -37,6 +38,7 @@ use rustc_middle::{ }; use rustc_serialize::Decodable; +use rustc_span::Span; use rustc_utils::cache::Cache; use std::{borrow::Cow, fs::File, io::Read, rc::Rc}; @@ -264,6 +266,14 @@ impl<'tcx> MarkerCtx<'tcx> { .collect() } + fn span_err(&self, span: Span, msg: impl Into) { + if self.0.config.relaxed() { + self.tcx().sess.span_warn(span, msg.into()); + } else { + self.tcx().sess.span_err(span, msg.into()); + } + } + /// Does this terminator carry a marker? fn terminator_reachable_markers( &self, @@ -271,6 +281,7 @@ impl<'tcx> MarkerCtx<'tcx> { terminator: &mir::Terminator<'tcx>, expect_resolve: bool, ) -> impl Iterator + '_ { + let param_env = ty::ParamEnv::reveal_all(); let mut v = vec![]; trace!( " Finding reachable markers for terminator {:?}", @@ -279,21 +290,13 @@ impl<'tcx> MarkerCtx<'tcx> { let Some((def_id, gargs)) = func_of_term(self.tcx(), terminator) else { return v.into_iter(); }; - let res = if expect_resolve { - let Some(instance) = - Instance::resolve(self.tcx(), ty::ParamEnv::reveal_all(), def_id, gargs).unwrap() + let mut res = if expect_resolve { + let Some(instance) = Instance::resolve(self.tcx(), param_env, def_id, gargs).unwrap() else { - if self.0.config.relaxed() { - self.tcx().sess.span_warn( + self.span_err( terminator.source_info.span, format!("cannot determine reachable markers, failed to resolve {def_id:?} with {gargs:?}") ); - } else { - self.tcx().sess.span_err( - terminator.source_info.span, - format!("cannot determine reachable markers, failed to resolve {def_id:?} with {gargs:?}") - ); - } return v.into_iter(); }; MaybeMonomorphized::Monomorphized(instance) @@ -304,6 +307,27 @@ impl<'tcx> MarkerCtx<'tcx> { " Checking function {} for markers", self.tcx().def_path_debug_str(res.def_id()) ); + + if let Some(model) = self.has_flow_model(res.def_id()) { + let MaybeMonomorphized::Monomorphized(instance) = &mut res else { + self.span_err( + terminator.source_info.span, + "Could not apply stub to an partially resolved function", + ); + return v.into_iter(); + }; + if let Ok(new_instance) = model.resolve_alternate_instance( + self.tcx(), + *instance, + param_env, + terminator.source_info.span, + ) { + *instance = new_instance; + } else { + return v.into_iter(); + } + } + v.extend(self.get_reachable_and_self_markers(res)); // We have to proceed differently than graph construction, diff --git a/crates/paralegal-flow/src/lib.rs b/crates/paralegal-flow/src/lib.rs index b44a5b5d4d..498c238dc8 100644 --- a/crates/paralegal-flow/src/lib.rs +++ b/crates/paralegal-flow/src/lib.rs @@ -29,6 +29,7 @@ extern crate rustc_ast; extern crate rustc_borrowck; extern crate rustc_data_structures; extern crate rustc_driver; +extern crate rustc_errors; extern crate rustc_hash; extern crate rustc_hir; extern crate rustc_index;