diff --git a/Cargo.lock b/Cargo.lock index 53acc2a8c3..317068f51d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -492,6 +492,9 @@ dependencies = [ "log", "petgraph", "rustc_utils 0.7.4-nightly-2023-08-25 (registry+https://github.com/rust-lang/crates.io-index)", + "serde", + "strum", + "thiserror", ] [[package]] @@ -1110,12 +1113,12 @@ name = "rustc_plugin" version = "0.7.4-nightly-2023-08-25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1348edfa020dbe4807a4d99272332dadcbbedff6b587accb95faefe20d2c7129" -replace = "rustc_plugin 0.7.4-nightly-2023-08-25 (git+https://github.com/JustusAdam/rustc_plugin?rev=d4fefb5c0344cdf4812b4877d5b03cb19a2c4672)" +replace = "rustc_plugin 0.7.4-nightly-2023-08-25 (git+https://github.com/JustusAdam/rustc_plugin?rev=dd382b79fc12ee86bc774c290a00bda32a0d54db)" [[package]] name = "rustc_plugin" version = "0.7.4-nightly-2023-08-25" -source = "git+https://github.com/JustusAdam/rustc_plugin?rev=d4fefb5c0344cdf4812b4877d5b03cb19a2c4672#d4fefb5c0344cdf4812b4877d5b03cb19a2c4672" +source = "git+https://github.com/JustusAdam/rustc_plugin?rev=dd382b79fc12ee86bc774c290a00bda32a0d54db#dd382b79fc12ee86bc774c290a00bda32a0d54db" dependencies = [ "cargo_metadata", "log", @@ -1136,12 +1139,12 @@ name = "rustc_utils" version = "0.7.4-nightly-2023-08-25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09428c7086894369685cca54a516acc0f0ab6d0e5a628c094ba83bfddaf1aedf" -replace = "rustc_utils 0.7.4-nightly-2023-08-25 (git+https://github.com/JustusAdam/rustc_plugin?rev=d4fefb5c0344cdf4812b4877d5b03cb19a2c4672)" +replace = "rustc_utils 0.7.4-nightly-2023-08-25 (git+https://github.com/JustusAdam/rustc_plugin?rev=dd382b79fc12ee86bc774c290a00bda32a0d54db)" [[package]] name = "rustc_utils" version = "0.7.4-nightly-2023-08-25" -source = "git+https://github.com/JustusAdam/rustc_plugin?rev=d4fefb5c0344cdf4812b4877d5b03cb19a2c4672#d4fefb5c0344cdf4812b4877d5b03cb19a2c4672" +source = 
"git+https://github.com/JustusAdam/rustc_plugin?rev=dd382b79fc12ee86bc774c290a00bda32a0d54db#dd382b79fc12ee86bc774c290a00bda32a0d54db" dependencies = [ "anyhow", "cfg-if", diff --git a/Cargo.toml b/Cargo.toml index 1f8e38ed1e..a02f24432a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ serde = "1.0.188" petgraph = { version = "0.6", features = ["serde-1"] } strum = { version = "0.25", features = ["derive"] } anyhow = { version = "1.0.72", features = ["backtrace"] } +thiserror = "1" rustc_utils = { version = "=0.7.4-nightly-2023-08-25", features = [ "indexical", @@ -29,5 +30,12 @@ debug = true # "rustc_utils:0.7.4-nightly-2023-08-25" = { path = "../rustc_plugin/crates/rustc_utils" } # "rustc_plugin:0.7.4-nightly-2023-08-25" = { path = "../rustc_plugin/crates/rustc_plugin" } -"rustc_utils:0.7.4-nightly-2023-08-25" = { git = "https://github.com/JustusAdam/rustc_plugin", rev = "d4fefb5c0344cdf4812b4877d5b03cb19a2c4672" } -"rustc_plugin:0.7.4-nightly-2023-08-25" = { git = "https://github.com/JustusAdam/rustc_plugin", rev = "d4fefb5c0344cdf4812b4877d5b03cb19a2c4672" } +[replace."rustc_utils:0.7.4-nightly-2023-08-25"] +# path = "../rustc_plugin/crates/rustc_utils" +git = "https://github.com/JustusAdam/rustc_plugin" +rev = "dd382b79fc12ee86bc774c290a00bda32a0d54db" + +[replace."rustc_plugin:0.7.4-nightly-2023-08-25"] +# path = "../rustc_plugin/crates/rustc_plugin" +git = "https://github.com/JustusAdam/rustc_plugin" +rev = "dd382b79fc12ee86bc774c290a00bda32a0d54db" diff --git a/Makefile.toml b/Makefile.toml index 7d201ae6e9..6b00347eab 100644 --- a/Makefile.toml +++ b/Makefile.toml @@ -65,6 +65,8 @@ args = [ "--test", "async_tests", "--no-fail-fast", + "--test", + "cross-crate", ] description = "Low-level tests for the PDG emitted by the analyzer specifically." command = "cargo" @@ -72,7 +74,7 @@ command = "cargo" [tasks.policy-framework-tests] description = "Tests related to the correctness of the policy framework." 
command = "cargo" -args = ["test", "-p", "paralegal-policy", "--lib"] +args = ["test", "-p", "paralegal-policy"] [tasks.guide-project] description = "Build and run the policy from the guide." diff --git a/crates/flowistry_pdg/src/lib.rs b/crates/flowistry_pdg/src/lib.rs index 1e77a45d93..9ff0656de5 100644 --- a/crates/flowistry_pdg/src/lib.rs +++ b/crates/flowistry_pdg/src/lib.rs @@ -11,6 +11,11 @@ pub(crate) mod rustc { pub use middle::mir; } +#[cfg(feature = "rustc")] +extern crate rustc_macros; +#[cfg(feature = "rustc")] +extern crate rustc_serialize; + mod pdg; #[cfg(feature = "rustc")] mod rustc_impls; diff --git a/crates/flowistry_pdg/src/pdg.rs b/crates/flowistry_pdg/src/pdg.rs index 76cf784ffe..92d568a59c 100644 --- a/crates/flowistry_pdg/src/pdg.rs +++ b/crates/flowistry_pdg/src/pdg.rs @@ -8,6 +8,10 @@ use serde::{Deserialize, Serialize}; use crate::rustc_portable::*; #[cfg(feature = "rustc")] use crate::rustc_proxies; +#[cfg(feature = "rustc")] +use rustc_macros::{Decodable, Encodable}; +#[cfg(feature = "rustc")] +use rustc_serialize::{Decodable, Decoder, Encodable, Encoder}; /// Extends a MIR body's `Location` with `Start` (before the first instruction) and `End` (after all returns). 
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug, Serialize, Deserialize)] @@ -26,6 +30,35 @@ pub enum RichLocation { End, } +#[cfg(feature = "rustc")] +impl Encodable for RichLocation { + fn encode(&self, s: &mut E) { + match self { + Self::Location(loc) => s.emit_enum_variant(0, |s| { + s.emit_u32(loc.block.as_u32()); + s.emit_usize(loc.statement_index); + }), + Self::Start => s.emit_enum_variant(1, |_| ()), + Self::End => s.emit_enum_variant(2, |_| ()), + } + } +} + +#[cfg(feature = "rustc")] +impl Decodable for RichLocation { + fn decode(d: &mut D) -> Self { + match d.read_usize() { + 0 => Self::Location(Location { + block: d.read_u32().into(), + statement_index: d.read_usize(), + }), + 1 => Self::Start, + 2 => Self::End, + v => panic!("Unknown variant index: {v}"), + } + } +} + impl RichLocation { /// Returns true if this is a `Start` location. pub fn is_start(self) -> bool { @@ -74,17 +107,18 @@ impl From for RichLocation { /// A [`RichLocation`] within a specific point in a codebase. #[derive(PartialEq, Eq, Hash, Clone, Copy, Debug, Serialize, Deserialize)] +#[cfg_attr(feature = "rustc", derive(Encodable, Decodable))] pub struct GlobalLocation { + // TODO Change to `DefId` /// The function containing the location. - #[cfg_attr(feature = "rustc", serde(with = "rustc_proxies::LocalDefId"))] - pub function: LocalDefId, + #[cfg_attr(feature = "rustc", serde(with = "rustc_proxies::DefId"))] + pub function: DefId, /// The location of an instruction in the function, or the function's start. 
pub location: RichLocation, } #[cfg(not(feature = "rustc"))] - impl fmt::Display for GlobalLocation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:?}::{}", self.function, self.location) @@ -103,6 +137,21 @@ impl fmt::Display for GlobalLocation { #[derive(PartialEq, Eq, Hash, Copy, Clone, Debug, Serialize, Deserialize)] pub struct CallString(Intern); +#[cfg(feature = "rustc")] +impl Encodable for CallString { + fn encode(&self, s: &mut S) { + let inner: &CallStringInner = &self.0; + inner.encode(s); + } +} + +#[cfg(feature = "rustc")] +impl Decodable for CallString { + fn decode(d: &mut D) -> Self { + Self(Intern::new(CallStringInner::decode(d))) + } +} + type CallStringInner = Box<[GlobalLocation]>; impl CallString { @@ -153,6 +202,10 @@ impl CallString { CallString::new(string) } + pub fn push_front(self, loc: GlobalLocation) -> Self { + CallString::new([loc].into_iter().chain(self.0.iter().copied()).collect()) + } + pub fn is_at_root(self) -> bool { self.0.len() == 1 } @@ -199,6 +252,7 @@ impl fmt::Display for CallString { #[derive( PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Debug, Serialize, Deserialize, strum::EnumIs, )] +#[cfg_attr(feature = "rustc", derive(Decodable, Encodable))] pub enum SourceUse { Operand, Argument(u8), @@ -206,6 +260,7 @@ pub enum SourceUse { /// Additional information about this mutation. 
#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug, Serialize, Deserialize, strum::EnumIs)] +#[cfg_attr(feature = "rustc", derive(Decodable, Encodable))] pub enum TargetUse { /// A function returned, assigning to it's return destination Return, diff --git a/crates/flowistry_pdg/src/rustc_impls.rs b/crates/flowistry_pdg/src/rustc_impls.rs index 665f1d75d6..c74d5c4ff1 100644 --- a/crates/flowistry_pdg/src/rustc_impls.rs +++ b/crates/flowistry_pdg/src/rustc_impls.rs @@ -77,7 +77,7 @@ impl From for def_id::DefIndex { impl fmt::Display for GlobalLocation { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { tls::with_opt(|opt_tcx| match opt_tcx { - Some(tcx) => match tcx.opt_item_name(self.function.to_def_id()) { + Some(tcx) => match tcx.opt_item_name(self.function) { Some(name) => name.fmt(f), None => write!(f, ""), }, diff --git a/crates/flowistry_pdg_construction/Cargo.toml b/crates/flowistry_pdg_construction/Cargo.toml index d2bba5ed90..365af8352b 100644 --- a/crates/flowistry_pdg_construction/Cargo.toml +++ b/crates/flowistry_pdg_construction/Cargo.toml @@ -23,6 +23,9 @@ flowistry_pdg = { version = "0.1.0", path = "../flowistry_pdg", features = [ ] } #flowistry = { path = "../../../flowistry/crates/flowistry", default-features = false } flowistry = { workspace = true } +serde = { workspace = true, features = ["derive"] } +strum = { workspace = true } +thiserror = { workspace = true } [dev-dependencies] rustc_utils = { workspace = true, features = ["indexical", "test"] } diff --git a/crates/flowistry_pdg_construction/src/approximation.rs b/crates/flowistry_pdg_construction/src/approximation.rs new file mode 100644 index 0000000000..623a96f792 --- /dev/null +++ b/crates/flowistry_pdg_construction/src/approximation.rs @@ -0,0 +1,75 @@ +use log::trace; + +use rustc_abi::VariantIdx; + +use rustc_hir::def_id::DefId; +use rustc_index::IndexVec; +use rustc_middle::{ + mir::{visit::Visitor, AggregateKind, Location, Operand, Place, Rvalue}, + ty::TyKind, +}; + +use 
crate::local_analysis::LocalAnalysis; + +pub(crate) type ApproximationHandler<'tcx, 'a> = + fn(&LocalAnalysis<'tcx, 'a>, &mut dyn Visitor<'tcx>, &[Operand<'tcx>], Place<'tcx>, Location); + +impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> { + /// Special case behavior for calls to functions used in desugaring async functions. + /// + /// Ensures that functions like `Pin::new_unchecked` are not modularly-approximated. + pub(crate) fn can_approximate_async_functions( + &self, + def_id: DefId, + ) -> Option> { + let lang_items = self.tcx().lang_items(); + if Some(def_id) == lang_items.new_unchecked_fn() { + Some(Self::approximate_new_unchecked) + } else if Some(def_id) == lang_items.into_future_fn() + // FIXME: is there a better way to retrieve this stdlib DefId? + || self.tcx().def_path_str(def_id) == "::into_future" + { + Some(Self::approximate_into_future) + } else { + None + } + } + + fn approximate_into_future( + &self, + vis: &mut dyn Visitor<'tcx>, + args: &[Operand<'tcx>], + destination: Place<'tcx>, + location: Location, + ) { + trace!("Handling into_future as assign for {destination:?}"); + let [op] = args else { + unreachable!(); + }; + vis.visit_assign(&destination, &Rvalue::Use(op.clone()), location); + } + + fn approximate_new_unchecked( + &self, + vis: &mut dyn Visitor<'tcx>, + args: &[Operand<'tcx>], + destination: Place<'tcx>, + location: Location, + ) { + let lang_items = self.tcx().lang_items(); + let [op] = args else { + unreachable!(); + }; + let mut operands = IndexVec::new(); + operands.push(op.clone()); + let TyKind::Adt(adt_id, generics) = destination.ty(&self.body, self.tcx()).ty.kind() else { + unreachable!() + }; + assert_eq!(adt_id.did(), lang_items.pin_type().unwrap()); + let aggregate_kind = + AggregateKind::Adt(adt_id.did(), VariantIdx::from_u32(0), generics, None, None); + let rvalue = Rvalue::Aggregate(Box::new(aggregate_kind), operands); + trace!("Handling new_unchecked as assign for {destination:?}"); + vis.visit_assign(&destination, &rvalue, 
location); + } +} diff --git a/crates/flowistry_pdg_construction/src/async_support.rs b/crates/flowistry_pdg_construction/src/async_support.rs index e7fa19d89e..6ab7da151d 100644 --- a/crates/flowistry_pdg_construction/src/async_support.rs +++ b/crates/flowistry_pdg_construction/src/async_support.rs @@ -1,21 +1,35 @@ -use std::rc::Rc; +use std::{fmt::Display, rc::Rc}; use either::Either; + use itertools::Itertools; use rustc_abi::{FieldIdx, VariantIdx}; use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_macros::{Decodable, Encodable}; use rustc_middle::{ mir::{ AggregateKind, BasicBlock, Body, Location, Operand, Place, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }, - ty::{GenericArgsRef, TyCtxt}, + ty::{GenericArgsRef, Instance, TyCtxt}, }; +use rustc_span::Span; -use crate::construct::{CallKind, PartialGraph}; +use crate::{ + construct::EmittableError, + local_analysis::{CallKind, LocalAnalysis}, + utils, Error, +}; -use super::construct::GraphConstructor; -use super::utils::{self, FnResolution}; +/// Describe in which way a function is `async`. +/// +/// Critically distinguishes between a normal `async fn` and an +/// `#[async_trait]`. +#[derive(Debug, Clone, Copy, Decodable, Encodable)] +pub enum AsyncType { + Fn, + Trait, +} /// Stores ids that are needed to construct projections around async functions. pub(crate) struct AsyncInfo { @@ -143,54 +157,127 @@ fn get_async_generator<'tcx>(body: &Body<'tcx>) -> (LocalDefId, GenericArgsRef<' (def_id.expect_local(), generic_args, location) } +/// Try to interpret this function as an async function. +/// +/// If this is an async function it returns the [`Instance`] of the generator, +/// the location where the generator is bound and the [`AsyncType`], which +/// tells a normal `async fn` apart from an `#[async_trait]` function. 
pub fn determine_async<'tcx>( tcx: TyCtxt<'tcx>, def_id: LocalDefId, body: &Body<'tcx>, -) -> Option<(FnResolution<'tcx>, Location)> { - let (generator_def_id, args, loc) = if tcx.asyncness(def_id).is_async() { - get_async_generator(body) +) -> Option<(Instance<'tcx>, Location, AsyncType)> { + let ((generator_def_id, args, loc), asyncness) = if tcx.asyncness(def_id).is_async() { + (get_async_generator(body), AsyncType::Fn) } else { - try_as_async_trait_function(tcx, def_id.to_def_id(), body)? + ( + try_as_async_trait_function(tcx, def_id.to_def_id(), body)?, + AsyncType::Trait, + ) }; let param_env = tcx.param_env_reveal_all_normalized(def_id); let generator_fn = - utils::try_resolve_function(tcx, generator_def_id.to_def_id(), param_env, args); - Some((generator_fn, loc)) + utils::try_resolve_function(tcx, generator_def_id.to_def_id(), param_env, args)?; + Some((generator_fn, loc, asyncness)) } #[derive(Debug, Clone, PartialEq, Eq)] -pub enum AsyncDeterminationResult { +pub enum AsyncDeterminationResult<'tcx, T> { Resolved(T), - Unresolvable(String), + Unresolvable(Error<'tcx>), NotAsync, } -impl<'tcx> GraphConstructor<'tcx> { - pub(crate) fn try_handle_as_async(&self) -> Option> { - let (generator_fn, location) = determine_async(self.tcx, self.def_id, &self.body)?; +#[derive(Debug, Encodable, Decodable, Clone, Hash, Eq, PartialEq)] +pub enum OperandShapeViolation { + IsNotAPlace, + IsNotLocal, + HasNoAssignments, + WrongNumberOfAssignments(u16), +} - let calling_context = self.calling_context_for(generator_fn.def_id(), location); - let params = self.pdg_params_for_call(generator_fn); - Some( - GraphConstructor::new( - params, - Some(calling_context), - self.async_info.clone(), - &self.pdg_cache, - ) - .construct_partial(), - ) +impl Display for OperandShapeViolation { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use OperandShapeViolation::*; + if let WrongNumberOfAssignments(n) = self { + return write!(f, "wrong number of assignments, 
expected 1, got {n}"); + }; + let str = match self { + IsNotAPlace => "is not a place", + IsNotLocal => "is not local", + HasNoAssignments => "is never assigned", + WrongNumberOfAssignments(..) => unreachable!(), + }; + f.write_str(str) + } +} + +#[derive(Debug, Encodable, Decodable, Clone, Hash, Eq, PartialEq)] +pub enum AsyncResolutionErr { + WrongOperandShape { + span: Span, + reason: OperandShapeViolation, + }, + PinnedAssignmentIsNotACall { + span: Span, + }, + AssignmentToPinNewIsNotAStatement { + span: Span, + }, + AssignmentToAliasOfPinNewInputIsNotACall { + span: Span, + }, + AssignmentToIntoFutureInputIsNotACall { + span: Span, + }, + ChaseTargetIsNotAFunction { + span: Span, + }, +} + +impl<'tcx> EmittableError<'tcx> for AsyncResolutionErr { + fn span(&self, _tcx: TyCtxt<'tcx>) -> Option { + use AsyncResolutionErr::*; + match self { + WrongOperandShape { span, .. } + | PinnedAssignmentIsNotACall { span } + | AssignmentToAliasOfPinNewInputIsNotACall { span } + | AssignmentToIntoFutureInputIsNotACall { span } + | ChaseTargetIsNotAFunction { span } + | AssignmentToPinNewIsNotAStatement { span } => Some(*span), + } } - pub(crate) fn try_poll_call_kind<'a>( - &'a self, + fn msg(&self, _tcx: TyCtxt<'tcx>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use AsyncResolutionErr::*; + if let WrongOperandShape { reason, .. } = self { + return write!(f, "operator has an unexpected shape: {reason}"); + } + f.write_str(match self { + PinnedAssignmentIsNotACall { .. } => "pinned assignment is not a call", + AssignmentToPinNewIsNotAStatement { .. } => "assignment to Pin::new is not a statement", + AssignmentToAliasOfPinNewInputIsNotACall { .. } => { + "assignment to Pin::new input is not a call" + } + AssignmentToIntoFutureInputIsNotACall { .. } => { + "assignment to into_future input is not a call" + } + ChaseTargetIsNotAFunction { .. } => "chase target is not a function", + WrongOperandShape { .. 
} => unreachable!(), + }) + } +} + +impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> { + pub(crate) fn try_poll_call_kind<'b>( + &'b self, def_id: DefId, - original_args: &'a [Operand<'tcx>], - ) -> AsyncDeterminationResult> { - let lang_items = self.tcx.lang_items(); + original_args: &'b [Operand<'tcx>], + span: Span, + ) -> AsyncDeterminationResult<'tcx, CallKind<'tcx>> { + let lang_items = self.tcx().lang_items(); if lang_items.future_poll_fn() == Some(def_id) { - match self.find_async_args(original_args) { + match self.find_async_args(original_args, span) { Ok((fun, loc, args)) => { AsyncDeterminationResult::Resolved(CallKind::AsyncPoll(fun, loc, args)) } @@ -202,29 +289,45 @@ impl<'tcx> GraphConstructor<'tcx> { } /// Given the arguments to a `Future::poll` call, walk back through the /// body to find the original future being polled, and get the arguments to the future. - fn find_async_args<'a>( - &'a self, - args: &'a [Operand<'tcx>], - ) -> Result<(FnResolution<'tcx>, Location, Place<'tcx>), String> { + fn find_async_args<'b>( + &'b self, + args: &'b [Operand<'tcx>], + call_span: Span, + ) -> Result<(Instance<'tcx>, Location, Place<'tcx>), Error<'tcx>> { + macro_rules! async_err { + ($msg:expr) => { + return Err(Error::AsyncResolutionErr($msg)) + }; + } macro_rules! 
let_assert { - ($p:pat = $e:expr, $($arg:tt)*) => { + ($p:pat = $e:expr, $msg:expr) => { let $p = $e else { - let msg = format!($($arg)*); - return Err(format!("Abandoning attempt to handle async because pattern {} could not be matched to {:?}: {}", stringify!($p), $e, msg)); + async_err!($msg); }; - } + }; } - let get_def_for_op = |op: &Operand<'tcx>| -> Result { - let_assert!(Some(place) = op.place(), "Arg is not a place"); + let get_def_for_op = |op: &Operand<'tcx>| -> Result { + let mk_err = |reason| AsyncResolutionErr::WrongOperandShape { + span: call_span, + reason, + }; + let_assert!( + Some(place) = op.place(), + mk_err(OperandShapeViolation::IsNotAPlace) + ); let_assert!( Some(local) = place.as_local(), - "Place {place:?} is not a local" + mk_err(OperandShapeViolation::IsNotLocal) ); let_assert!( Some(locs) = &self.body_assignments.get(&local), - "Local has no assignments" + mk_err(OperandShapeViolation::HasNoAssignments) ); - assert!(locs.len() == 1); + if locs.len() != 1 { + async_err!(mk_err(OperandShapeViolation::WrongNumberOfAssignments( + locs.len() as u16, + ))); + } Ok(locs[0]) }; @@ -236,12 +339,12 @@ impl<'tcx> GraphConstructor<'tcx> { }, .. }) = &self.body.stmt_at(get_def_for_op(&args[0])?), - "Pinned assignment is not a call" + AsyncResolutionErr::PinnedAssignmentIsNotACall { span: call_span } ); debug_assert!(new_pin_args.len() == 1); let future_aliases = self - .aliases(self.tcx.mk_place_deref(new_pin_args[0].place().unwrap())) + .aliases(self.tcx().mk_place_deref(new_pin_args[0].place().unwrap())) .collect_vec(); debug_assert!(future_aliases.len() == 1); let future = *future_aliases.first().unwrap(); @@ -251,7 +354,7 @@ impl<'tcx> GraphConstructor<'tcx> { kind: StatementKind::Assign(box (_, Rvalue::Use(future2))), .. 
}) = &self.body.stmt_at(get_def_for_op(&Operand::Move(future))?), - "Assignment to pin::new input is not a statement" + AsyncResolutionErr::AssignmentToPinNewIsNotAStatement { span: call_span } ); let_assert!( @@ -262,7 +365,7 @@ impl<'tcx> GraphConstructor<'tcx> { }, .. }) = &self.body.stmt_at(get_def_for_op(future2)?), - "Assignment to alias of pin::new input is not a call" + AsyncResolutionErr::AssignmentToAliasOfPinNewInputIsNotACall { span: call_span } ); let mut chase_target = Err(&into_future_args[0]); @@ -290,17 +393,23 @@ impl<'tcx> GraphConstructor<'tcx> { ), )) => Ok((*def_id, *generic_args, *lhs, async_fn_call_loc)), StatementKind::Assign(box (_, Rvalue::Use(target))) => { - let (op, generics) = self - .operand_to_def_id(target) - .ok_or_else(|| "Nope".to_string())?; + let Some((op, generics)) = self.operand_to_def_id(target) else { + async_err!(AsyncResolutionErr::ChaseTargetIsNotAFunction { + span: call_span + }) + }; Ok((op, generics, target.place().unwrap(), async_fn_call_loc)) } _ => { - panic!("Assignment to into_future input is not a call: {stmt:?}"); + async_err!(AsyncResolutionErr::AssignmentToIntoFutureInputIsNotACall { + span: call_span, + }); } }, _ => { - panic!("Assignment to into_future input is not a call: {stmt:?}"); + async_err!(AsyncResolutionErr::AssignmentToIntoFutureInputIsNotACall { + span: call_span, + }); } }; } @@ -308,11 +417,12 @@ impl<'tcx> GraphConstructor<'tcx> { let (op, generics, calling_convention, async_fn_call_loc) = chase_target.unwrap(); let resolution = utils::try_resolve_function( - self.tcx, + self.tcx(), op, - self.tcx.param_env_reveal_all_normalized(self.def_id), + self.tcx().param_env_reveal_all_normalized(self.def_id), generics, - ); + ) + .ok_or_else(|| Error::instance_resolution_failed(op, generics, call_span))?; Ok((resolution, async_fn_call_loc, calling_convention)) } diff --git a/crates/flowistry_pdg_construction/src/callback.rs b/crates/flowistry_pdg_construction/src/callback.rs index 
3008f8f6f6..503eae8372 100644 --- a/crates/flowistry_pdg_construction/src/callback.rs +++ b/crates/flowistry_pdg_construction/src/callback.rs @@ -1,18 +1,18 @@ //! CAllbacks to influence graph construction and their supporting types. use flowistry_pdg::{rustc_portable::Location, CallString}; +use rustc_middle::ty::Instance; -use crate::FnResolution; +use crate::Error; pub trait CallChangeCallback<'tcx> { fn on_inline(&self, info: CallInfo<'tcx>) -> CallChanges; fn on_inline_miss( &self, - _resolution: FnResolution<'tcx>, + _resolution: Instance<'tcx>, _loc: Location, - _under_analysis: FnResolution<'tcx>, - _call_string: Option, + _under_analysis: Instance<'tcx>, _reason: InlineMissReason, ) { } @@ -35,8 +35,8 @@ impl<'tcx> CallChangeCallback<'tcx> for CallChangeCallbackFn<'tcx> { } #[derive(Debug)] -pub enum InlineMissReason { - Async(String), +pub enum InlineMissReason<'tcx> { + Async(Error<'tcx>), } impl Default for CallChanges { @@ -50,11 +50,11 @@ impl Default for CallChanges { /// Information about the function being called. pub struct CallInfo<'tcx> { /// The potentially-monomorphized resolution of the callee. - pub callee: FnResolution<'tcx>, + pub callee: Instance<'tcx>, /// If the callee is an async closure created by an `async fn`, this is the /// `async fn` item. - pub async_parent: Option>, + pub async_parent: Option>, /// The call-stack up to the current call site. 
pub call_string: CallString, diff --git a/crates/flowistry_pdg_construction/src/calling_convention.rs b/crates/flowistry_pdg_construction/src/calling_convention.rs index 76b1ce91a5..f86b4cbc97 100644 --- a/crates/flowistry_pdg_construction/src/calling_convention.rs +++ b/crates/flowistry_pdg_construction/src/calling_convention.rs @@ -3,11 +3,11 @@ use log::trace; use rustc_abi::FieldIdx; use rustc_middle::{ - mir::{Body, HasLocalDecls, Operand, Place, PlaceElem, RETURN_PLACE}, + mir::{tcx::PlaceTy, Body, HasLocalDecls, Operand, Place, PlaceElem, RETURN_PLACE}, ty::TyCtxt, }; -use crate::{async_support::AsyncInfo, construct::CallKind, utils}; +use crate::{async_support::AsyncInfo, local_analysis::CallKind, utils}; pub enum CallingConvention<'tcx, 'a> { Direct(&'a [Operand<'tcx>]), @@ -33,6 +33,7 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> { } } + #[allow(clippy::too_many_arguments)] pub(crate) fn translate_to_parent( &self, child: Place<'tcx>, @@ -41,6 +42,7 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> { parent_body: &Body<'tcx>, parent_def_id: DefId, destination: Place<'tcx>, + target_ty: Option>, ) -> Option> { trace!(" Translating child place: {child:?}"); let (parent_place, child_projection) = @@ -53,6 +55,7 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> { tcx, parent_body, parent_def_id, + target_ty, )) } @@ -66,24 +69,26 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> { ) -> Option<(Place<'tcx>, &[PlaceElem<'tcx>])> { let result = match self { // Async return must be handled special, because it gets wrapped in `Poll::Ready` - Self::Async { .. 
} if child.local == RETURN_PLACE => { - let in_poll = destination.project_deeper( - &[PlaceElem::Downcast(None, async_info.poll_ready_variant_idx)], - tcx, - ); - let field_idx = async_info.poll_ready_field_idx; - let child_inner_return_type = in_poll - .ty(parent_body.local_decls(), tcx) - .field_ty(tcx, field_idx); - ( - in_poll.project_deeper( - &[PlaceElem::Field(field_idx, child_inner_return_type)], + _ if child.local == RETURN_PLACE => match self { + Self::Async { .. } => { + let in_poll = destination.project_deeper( + &[PlaceElem::Downcast(None, async_info.poll_ready_variant_idx)], tcx, - ), - &child.projection[..], - ) - } - _ if child.local == RETURN_PLACE => (destination, &child.projection[..]), + ); + let field_idx = async_info.poll_ready_field_idx; + let child_inner_return_type = in_poll + .ty(parent_body.local_decls(), tcx) + .field_ty(tcx, field_idx); + ( + in_poll.project_deeper( + &[PlaceElem::Field(field_idx, child_inner_return_type)], + tcx, + ), + &child.projection[..], + ) + } + _ => (destination, &child.projection[..]), + }, // Map arguments to the argument array Self::Direct(args) => ( args[child.local.as_usize() - 1].place()?, @@ -107,7 +112,6 @@ impl<'tcx, 'a> CallingConvention<'tcx, 'a> { (closure_arg.place()?, &child.projection[..]) } else { let tuple_arg = tupled_arguments.place()?; - let _projection = child.projection.to_vec(); let field = FieldIdx::from_usize(child.local.as_usize() - 2); let field_ty = tuple_arg.ty(parent_body, tcx).field_ty(tcx, field); ( diff --git a/crates/flowistry_pdg_construction/src/construct.rs b/crates/flowistry_pdg_construction/src/construct.rs index 3c33128134..371aecd8ab 100644 --- a/crates/flowistry_pdg_construction/src/construct.rs +++ b/crates/flowistry_pdg_construction/src/construct.rs @@ -1,158 +1,385 @@ -use std::{borrow::Cow, collections::HashSet, iter, rc::Rc}; - -use df::{fmt::DebugWithContext, Analysis, AnalysisDomain, Results, ResultsVisitor}; +//! Constructing PDGs. +//! +//! 
The construction is split into two steps: a local analysis and a +//! cross-procedure PDG merging. +//! +//! 1. [`LocalAnalysis`] is responsible for the local analysis. It performs a +//! procedure-local fixpoint analysis to determine a pre- and post effect +//! [`InstructionState`] at each instruction in the procedure. +//! 2. [`PartialGraph`] implements [`ResultsVisitor`] over the analysis result. + +use std::{fmt::Display, rc::Rc}; + +use anyhow::anyhow; use either::Either; -use flowistry::mir::placeinfo::PlaceInfo; -use flowistry_pdg::{CallString, GlobalLocation, RichLocation}; -use itertools::Itertools; -use log::{debug, log_enabled, trace, Level}; + +use flowistry_pdg::{CallString, GlobalLocation}; + +use log::trace; use petgraph::graph::DiGraph; -use rustc_abi::VariantIdx; -use rustc_borrowck::consumers::{places_conflict, BodyWithBorrowckFacts, PlaceConflictBias}; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir::def_id::{DefId, LocalDefId}; use rustc_index::IndexVec; +use rustc_macros::{TyDecodable, TyEncodable}; use rustc_middle::{ mir::{ - visit::Visitor, AggregateKind, BasicBlock, Body, Location, Operand, Place, PlaceElem, - Rvalue, Statement, Terminator, TerminatorEdges, TerminatorKind, RETURN_PLACE, + visit::Visitor, AggregateKind, Location, Operand, Place, Rvalue, Terminator, TerminatorKind, }, - ty::{GenericArg, List, ParamEnv, TyCtxt, TyKind}, + ty::{GenericArgsRef, Instance, TyCtxt}, }; -use rustc_mir_dataflow::{self as df}; -use rustc_span::ErrorGuaranteed; +use rustc_mir_dataflow::{AnalysisDomain, Results, ResultsVisitor}; +use rustc_span::Span; use rustc_utils::cache::Cache; -use rustc_utils::{ - mir::{borrowck_facts, control_dependencies::ControlDependencies}, - BodyExt, PlaceExt, -}; -use super::async_support::*; -use super::calling_convention::*; -use super::graph::{DepEdge, DepGraph, DepNode}; -use super::utils::{self, FnResolution}; use crate::{ - graph::{SourceUse, TargetUse}, - mutation::Time, - utils::{is_non_default_trait_method, 
manufacture_substs_for}, - InlineMissReason, SkipCall, -}; -use crate::{ - mutation::{ModularMutationVisitor, Mutation}, - try_resolve_function, CallChangeCallback, CallChanges, CallInfo, + async_support::*, + graph::{ + push_call_string_root, DepEdge, DepGraph, DepNode, PartialGraph, SourceUse, TargetUse, + }, + local_analysis::{CallHandling, InstructionState, LocalAnalysis}, + mutation::{ModularMutationVisitor, Mutation, Time}, + utils::{manufacture_substs_for, try_resolve_function}, + CallChangeCallback, GraphLoader, }; -/// Top-level parameters to PDG construction. -#[derive(Clone)] -pub struct PdgParams<'tcx> { - tcx: TyCtxt<'tcx>, - root: FnResolution<'tcx>, - call_change_callback: Option + 'tcx>>, - dump_mir: bool, +/// A memoizing constructor of PDGs. +/// +/// Each `(LocalDefId, GenericArgs)` pair is guaranteed to be constructed only +/// once. +pub struct MemoPdgConstructor<'tcx> { + pub(crate) tcx: TyCtxt<'tcx>, + pub(crate) call_change_callback: Option + 'tcx>>, + pub(crate) dump_mir: bool, + pub(crate) async_info: Rc, + pub(crate) pdg_cache: PdgCache<'tcx>, + pub(crate) loader: Box + 'tcx>, } -impl<'tcx> PdgParams<'tcx> { - /// Must provide the [`TyCtxt`] and the [`LocalDefId`] of the function that is the root of the PDG. 
- pub fn new(tcx: TyCtxt<'tcx>, root: LocalDefId) -> Result { - let root = try_resolve_function( - tcx, - root.to_def_id(), - tcx.param_env_reveal_all_normalized(root), - manufacture_substs_for(tcx, root)?, - ); - Ok(PdgParams { +#[derive(Debug, TyEncodable, TyDecodable, Clone, Hash, Eq, PartialEq)] +pub enum Error<'tcx> { + InstanceResolutionFailed { + function: DefId, + generics: GenericArgsRef<'tcx>, + span: Span, + }, + Impossible, + FailedLoadingExternalFunction { + function: DefId, + span: Span, + }, + RustcReportedError, + CrateExistsButItemIsNotFound { + function: DefId, + }, + TooManyPredicatesForSynthesizingGenerics { + function: DefId, + number: u32, + }, + BoundVariablesInPredicates { + function: DefId, + }, + TraitRefWithBinder { + function: DefId, + }, + ConstantInGenerics { + function: DefId, + }, + OperandIsNotFunctionType { + op: String, + }, + AsyncResolutionErr(AsyncResolutionErr), + NormalizationError { + instance: Instance<'tcx>, + span: Span, + error: String, + }, +} + +pub trait EmittableError<'tcx> { + fn span(&self, _tcx: TyCtxt<'tcx>) -> Option { + None + } + fn msg(&self, tcx: TyCtxt<'tcx>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result; + + fn emit(&self, tcx: TyCtxt<'tcx>) { + default_emit_error(self, tcx) + } +} + +pub trait UnwrapEmittable<'tcx> { + type Inner; + fn unwrap_emittable(self, tcx: TyCtxt<'tcx>) -> Self::Inner; +} + +impl<'tcx, T, E: EmittableError<'tcx>> UnwrapEmittable<'tcx> for Result { + type Inner = T; + fn unwrap_emittable(self, tcx: TyCtxt<'tcx>) -> Self::Inner { + match self { + Result::Ok(inner) => inner, + Result::Err(e) => { + default_emit_error(&e, tcx); + panic!("unwrap") + } + } + } +} + +pub fn default_emit_error<'tcx>(e: &(impl EmittableError<'tcx> + ?Sized), tcx: TyCtxt<'tcx>) { + struct FmtWithTcx<'tcx, A> { + tcx: TyCtxt<'tcx>, + inner: A, + } + impl<'tcx, A: EmittableError<'tcx> + ?Sized> Display for FmtWithTcx<'tcx, &'_ A> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + 
self.inner.msg(self.tcx, f) + } + } + + let msg = format!("{}", FmtWithTcx { tcx, inner: e }); + if let Some(span) = e.span(tcx) { + tcx.sess.span_err(span, msg); + } else { + tcx.sess.err(msg); + } +} + +impl<'tcx> EmittableError<'tcx> for Error<'tcx> { + fn span(&self, tcx: TyCtxt<'tcx>) -> Option { + use Error::*; + match self { + AsyncResolutionErr(e) => e.span(tcx), + InstanceResolutionFailed { span, .. } + | FailedLoadingExternalFunction { span, .. } + | NormalizationError { span, .. } => Some(*span), + BoundVariablesInPredicates { function } + | TraitRefWithBinder { function } + | ConstantInGenerics { function } => Some(tcx.def_span(*function)), + _ => None, + } + } + + fn msg(&self, tcx: TyCtxt, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Error::*; + match self { + InstanceResolutionFailed { + function, generics, .. + } => write!( + f, + "could not resolve instance for {} with generics {generics:?}", + tcx.def_path_debug_str(*function) + ), + Impossible => f.write_str("internal compiler error, this state should be impossible"), + FailedLoadingExternalFunction { function, .. } => write!( + f, + "failed loading external function {}", + tcx.def_path_debug_str(*function) + ), + RustcReportedError => f.write_str("see previously reported errors"), + CrateExistsButItemIsNotFound { function } => write!( + f, + "found a crate for item {}, but could not find a PDG for it", + tcx.def_path_debug_str(*function) + ), + TooManyPredicatesForSynthesizingGenerics { number, .. } => write!( + f, + "only one predicate can be synthesized to a `dyn`, found {number}" + ), + BoundVariablesInPredicates { .. } => { + f.write_str("bound variables in predicates are not supported") + } + TraitRefWithBinder { .. } => { + f.write_str("trait refs for `dyn` synthesis cannot have binders") + } + ConstantInGenerics { .. 
} => { + f.write_str("constants in generics for are not supported for analysis entrypoints") + } + OperandIsNotFunctionType { op } => { + write!(f, "operand {op} is not of function type") + } + AsyncResolutionErr(e) => e.msg(tcx, f), + NormalizationError { + instance, error, .. + } => write!( + f, + "failed to normalize with instance {instance:?} because {error}" + ), + } + } +} + +impl<'tcx> Error<'tcx> { + pub fn instance_resolution_failed( + function: DefId, + generics: GenericArgsRef<'tcx>, + span: Span, + ) -> Self { + Self::InstanceResolutionFailed { + function, + generics, + span, + } + } + + pub fn operand_is_not_function_type(op: &Operand) -> Self { + Self::OperandIsNotFunctionType { + op: format!("{op:?}"), + } + } +} + +impl<'tcx> MemoPdgConstructor<'tcx> { + /// Initialize the constructor, parameterized over an [`ArtifactLoader`] for + /// retrieving PDGs of functions from dependencies. + pub fn new(tcx: TyCtxt<'tcx>, loader: impl GraphLoader<'tcx> + 'tcx) -> Self { + Self { tcx, - root, call_change_callback: None, dump_mir: false, - }) + async_info: AsyncInfo::make(tcx).expect("Async functions are not defined"), + pdg_cache: Default::default(), + loader: Box::new(loader), + } } - pub fn with_dump_mir(mut self, dump_mir: bool) -> Self { + /// Dump the MIR of any function that is visited. + pub fn with_dump_mir(&mut self, dump_mir: bool) -> &mut Self { self.dump_mir = dump_mir; self } - /// Provide a callback for changing the behavior of how the PDG generator manages function calls. - /// - /// Currently, this callback can either indicate that a function call should be skipped (i.e., not recursed into), - /// or indicate that a set of fake effects should occur at the function call. See [`CallChanges`] for details. 
- /// - /// For example, in this code: - /// - /// ``` - /// fn incr(x: i32) -> i32 { x + 1 } - /// fn main() { - /// let a = 0; - /// let b = incr(a); - /// } - /// ``` - /// - /// When inspecting the call `incr(a)`, the callback will be called with `f({callee: incr, call_string: [main]})`. - /// You could apply a hard limit on call string length like this: - /// - /// ``` - /// # #![feature(rustc_private)] - /// # extern crate rustc_middle; - /// # use flowistry_pdg_construction::{PdgParams, SkipCall, CallChanges, CallChangeCallbackFn}; - /// # use rustc_middle::ty::TyCtxt; - /// # const THRESHOLD: usize = 5; - /// # fn f<'tcx>(tcx: TyCtxt<'tcx>, params: PdgParams<'tcx>) -> PdgParams<'tcx> { - /// params.with_call_change_callback(CallChangeCallbackFn::new(|info| { - /// let skip = if info.call_string.len() > THRESHOLD { - /// SkipCall::Skip - /// } else { - /// SkipCall::NoSkip - /// }; - /// CallChanges::default().with_skip(skip) - /// })) - /// # } - /// ``` - pub fn with_call_change_callback(self, f: impl CallChangeCallback<'tcx> + 'tcx) -> Self { - PdgParams { - call_change_callback: Some(Rc::new(f)), - ..self + /// Register a callback to determine how to deal with function calls seen. + /// Overwrites any previously registered callback with no warning. + pub fn with_call_change_callback( + &mut self, + callback: impl CallChangeCallback<'tcx> + 'tcx, + ) -> &mut Self { + self.call_change_callback.replace(Rc::new(callback)); + self + } + + /// Construct the intermediate PDG for this function. Instantiates any + /// generic arguments as `dyn `. 
+ pub fn construct_root<'a>( + &'a self, + function: LocalDefId, + ) -> Result<&'a PartialGraph<'tcx>, Vec>> { + let generics = + manufacture_substs_for(self.tcx, function.to_def_id()).map_err(|i| vec![i])?; + let resolution = try_resolve_function( + self.tcx, + function.to_def_id(), + self.tcx.param_env_reveal_all_normalized(function), + generics, + ) + .ok_or_else(|| { + vec![Error::instance_resolution_failed( + function.to_def_id(), + generics, + self.tcx.def_span(function), + )] + })?; + self.construct_for(resolution) + .and_then(|f| f.ok_or(vec![Error::Impossible])) + } + + pub(crate) fn construct_for<'a>( + &'a self, + resolution: Instance<'tcx>, + ) -> Result>, Vec>> { + let def_id = resolution.def_id(); + let generics = resolution.args; + if let Some(local) = def_id.as_local() { + let r = self + .pdg_cache + .get_maybe_recursive((local, generics), |_| { + let g = LocalAnalysis::new(self, resolution) + .map_err(|e| vec![e])? + .construct_partial()?; + trace!( + "Computed new for {} {generics:?}", + self.tcx.def_path_str(local) + ); + g.check_invariants(); + Ok(g) + }) + .map(Result::as_ref) + .transpose() + .map_err(Clone::clone)?; + if let Some(g) = r { + trace!( + "Found pdg for {} with {:?}", + self.tcx.def_path_str(local), + g.generics + ) + }; + Ok(r) + } else { + self.loader.load(def_id) } } + + /// Has a PDG been constructed for this instance before? + pub fn is_in_cache(&self, resolution: Instance<'tcx>) -> bool { + if let Some(local) = resolution.def_id().as_local() { + self.pdg_cache.is_in_cache(&(local, resolution.args)) + } else { + matches!(self.loader.load(resolution.def_id()), Ok(Some(_))) + } + } + + /// Construct a final PDG for this function. Same as + /// [`Self::construct_root`] this instantiates all generics as `dyn`. 
+ pub fn construct_graph( + &self, + function: LocalDefId, + ) -> Result, Vec>> { + let _args = manufacture_substs_for(self.tcx, function.to_def_id()) + .map_err(|_| anyhow!("rustc error")); + let g = self.construct_root(function)?.to_petgraph(); + Ok(g) + } } -#[derive(PartialEq, Eq, Default, Clone, Debug)] -pub struct InstructionState<'tcx> { - last_mutation: FxHashMap, FxHashSet>, +pub(crate) struct WithConstructionErrors<'tcx, A> { + pub(crate) inner: A, + pub errors: FxHashSet>, } -impl DebugWithContext for InstructionState<'_> {} +impl<'tcx, A> WithConstructionErrors<'tcx, A> { + pub fn new(inner: A) -> Self { + Self { + inner, + errors: Default::default(), + } + } -impl<'tcx> df::JoinSemiLattice for InstructionState<'tcx> { - fn join(&mut self, other: &Self) -> bool { - utils::hashmap_join( - &mut self.last_mutation, - &other.last_mutation, - utils::hashset_join, - ) + pub fn into_result(self) -> Result>> { + if self.errors.is_empty() { + Ok(self.inner) + } else { + Err(self.errors.into_iter().collect()) + } } } -#[derive(Default, Debug)] -pub struct PartialGraph<'tcx> { - nodes: FxHashSet>, - edges: FxHashSet<(DepNode<'tcx>, DepNode<'tcx>, DepEdge)>, -} +type DfResults<'mir, 'tcx> = Results<'tcx, DfAna<'mir, 'tcx>>; + +type DfAna<'mir, 'tcx> = WithConstructionErrors<'tcx, &'mir LocalAnalysis<'tcx, 'mir>>; -impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, DfAnalysis<'mir, 'tcx>>> - for PartialGraph<'tcx> +impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, DfResults<'mir, 'tcx>> + for WithConstructionErrors<'tcx, PartialGraph<'tcx>> { - type FlowState = as AnalysisDomain<'tcx>>::Domain; + type FlowState = as AnalysisDomain<'tcx>>::Domain; fn visit_statement_before_primary_effect( &mut self, - results: &Results<'tcx, DfAnalysis<'mir, 'tcx>>, + results: &DfResults<'mir, 'tcx>, state: &Self::FlowState, statement: &'mir rustc_middle::mir::Statement<'tcx>, location: Location, ) { - let mut vis = self.modular_mutation_visitor(results, state); + let mut vis = 
self.inner.modular_mutation_visitor(results, state); vis.visit_statement(statement, location) } @@ -175,119 +402,14 @@ impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, DfAnalysis<'mir, 'tcx> /// call site. fn visit_terminator_before_primary_effect( &mut self, - results: &Results<'tcx, DfAnalysis<'mir, 'tcx>>, + results: &DfResults<'mir, 'tcx>, state: &Self::FlowState, terminator: &'mir rustc_middle::mir::Terminator<'tcx>, location: Location, ) { - let mut handle_as_inline = || { - let TerminatorKind::Call { - func, - args, - destination, - .. - } = &terminator.kind - else { - return None; - }; - let constructor = results.analysis.0; - - let (child_constructor, calling_convention) = - match constructor.determine_call_handling(location, func, args)? { - CallHandling::Ready(one, two) => (one, two), - CallHandling::ApproxAsyncFn => { - // Register a synthetic assignment of `future = (arg0, arg1, ...)`. - let rvalue = Rvalue::Aggregate( - Box::new(AggregateKind::Tuple), - IndexVec::from_iter(args.iter().cloned()), - ); - self.modular_mutation_visitor(results, state).visit_assign( - destination, - &rvalue, - location, - ); - return Some(()); - } - CallHandling::ApproxAsyncSM(how) => { - how( - constructor, - &mut self.modular_mutation_visitor(results, state), - args, - *destination, - location, - ); - return Some(()); - } - }; - - let child_graph = child_constructor.construct_partial_cached(); - - let parentable_srcs = - child_graph.parentable_srcs(child_constructor.def_id, &child_constructor.body); - let parentable_dsts = - child_graph.parentable_dsts(child_constructor.def_id, &child_constructor.body); - - // For each source node CHILD that is parentable to PLACE, - // add an edge from PLACE -> CHILD. 
- trace!("PARENT -> CHILD EDGES:"); - for (child_src, _kind) in parentable_srcs { - if let Some(parent_place) = calling_convention.translate_to_parent( - child_src.place, - &constructor.async_info, - constructor.tcx, - &constructor.body, - constructor.def_id.to_def_id(), - *destination, - ) { - self.register_mutation( - results, - state, - Inputs::Unresolved { - places: vec![(parent_place, None)], - }, - Either::Right(child_src), - location, - TargetUse::Assign, - ); - } - } - - // For each destination node CHILD that is parentable to PLACE, - // add an edge from CHILD -> PLACE. - // - // PRECISION TODO: for a given child place, we only want to connect - // the *last* nodes in the child function to the parent, not *all* of them. - trace!("CHILD -> PARENT EDGES:"); - for (child_dst, kind) in parentable_dsts { - if let Some(parent_place) = calling_convention.translate_to_parent( - child_dst.place, - &constructor.async_info, - constructor.tcx, - &constructor.body, - constructor.def_id.to_def_id(), - *destination, - ) { - self.register_mutation( - results, - state, - Inputs::Resolved { - node: child_dst, - node_use: SourceUse::Operand, - }, - Either::Left(parent_place), - location, - kind.map_or(TargetUse::Return, TargetUse::MutArg), - ); - } - } - self.nodes.extend(&child_graph.nodes); - self.edges.extend(&child_graph.edges); - Some(()) - }; - if let TerminatorKind::SwitchInt { discr, .. 
} = &terminator.kind { if let Some(place) = discr.place() { - self.register_mutation( + self.inner.register_mutation( results, state, Inputs::Unresolved { @@ -301,41 +423,52 @@ impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, DfAnalysis<'mir, 'tcx> return; } - if handle_as_inline().is_none() { - trace!("Handling terminator {:?} as not inlined", terminator.kind); - let mut arg_vis = ModularMutationVisitor::new( - &results.analysis.0.place_info, - move |location, mutation| { - self.register_mutation( - results, - state, - Inputs::Unresolved { - places: mutation.inputs, - }, - Either::Left(mutation.mutated), - location, - mutation.mutation_reason, - ) - }, - ); - arg_vis.set_time(Time::Before); - arg_vis.visit_terminator(terminator, location); + match self + .inner + .handle_as_inline(results, state, terminator, location) + { + Ok(false) => (), + Ok(true) => return, + Err(e) => self.errors.extend(e), } + trace!("Handling terminator {:?} as not inlined", terminator.kind); + let mut arg_vis = ModularMutationVisitor::new( + &results.analysis.inner.place_info, + move |location, mutation| { + self.inner.register_mutation( + results, + state, + Inputs::Unresolved { + places: mutation.inputs, + }, + Either::Left(mutation.mutated), + location, + mutation.mutation_reason, + ) + }, + ); + arg_vis.set_time(Time::Before); + arg_vis.visit_terminator(terminator, location); } fn visit_terminator_after_primary_effect( &mut self, - results: &Results<'tcx, DfAnalysis<'mir, 'tcx>>, + results: &DfResults<'mir, 'tcx>, state: &Self::FlowState, terminator: &'mir rustc_middle::mir::Terminator<'tcx>, location: Location, ) { if let TerminatorKind::Call { func, args, .. 
} = &terminator.kind { - let constructor = results.analysis.0; + let constructor = results.analysis.inner; if matches!( - constructor.determine_call_handling(location, func, args), - Some(CallHandling::Ready(_, _)) + constructor.determine_call_handling( + location, + func, + args, + terminator.source_info.span + ), + Ok(Some(CallHandling::Ready { .. })) ) { return; } @@ -343,9 +476,9 @@ impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, DfAnalysis<'mir, 'tcx> trace!("Handling terminator {:?} as not inlined", terminator.kind); let mut arg_vis = ModularMutationVisitor::new( - &results.analysis.0.place_info, + &results.analysis.inner.place_info, move |location, mutation| { - self.register_mutation( + self.inner.register_mutation( results, state, Inputs::Unresolved { @@ -361,65 +494,165 @@ impl<'mir, 'tcx> ResultsVisitor<'mir, 'tcx, Results<'tcx, DfAnalysis<'mir, 'tcx> arg_vis.visit_terminator(terminator, location); } } -fn as_arg<'tcx>(node: &DepNode<'tcx>, def_id: LocalDefId, body: &Body<'tcx>) -> Option> { - if node.at.leaf().function != def_id { - return None; - } - if node.place.local == RETURN_PLACE { - Some(None) - } else if node.place.is_arg(body) { - Some(Some(node.place.local.as_u32() as u8 - 1)) - } else { - None - } -} impl<'tcx> PartialGraph<'tcx> { - fn modular_mutation_visitor<'a>( + fn modular_mutation_visitor<'a, 'mir>( &'a mut self, - results: &'a Results<'tcx, DfAnalysis<'_, 'tcx>>, + results: &'a DfResults<'mir, 'tcx>, state: &'a InstructionState<'tcx>, ) -> ModularMutationVisitor<'a, 'tcx, impl FnMut(Location, Mutation<'tcx>) + 'a> { - ModularMutationVisitor::new(&results.analysis.0.place_info, move |location, mutation| { - self.register_mutation( - results, - state, - Inputs::Unresolved { - places: mutation.inputs, - }, - Either::Left(mutation.mutated), - location, - mutation.mutation_reason, - ) - }) - } - fn parentable_srcs<'a>( - &'a self, - def_id: LocalDefId, - body: &'a Body<'tcx>, - ) -> impl Iterator, Option)> + 'a { - self.edges - 
.iter() - .map(|(src, _, _)| *src) - .filter_map(move |a| Some((a, as_arg(&a, def_id, body)?))) - .filter(|(node, _)| node.at.leaf().location.is_start()) + ModularMutationVisitor::new( + &results.analysis.inner.place_info, + move |location, mutation| { + self.register_mutation( + results, + state, + Inputs::Unresolved { + places: mutation.inputs, + }, + Either::Left(mutation.mutated), + location, + mutation.mutation_reason, + ) + }, + ) } - fn parentable_dsts<'a>( - &'a self, - def_id: LocalDefId, - body: &'a Body<'tcx>, - ) -> impl Iterator, Option)> + 'a { - self.edges - .iter() - .map(|(_, dst, _)| *dst) - .filter_map(move |a| Some((a, as_arg(&a, def_id, body)?))) - .filter(|node| node.0.at.leaf().location.is_end()) + /// returns whether we were able to successfully handle this as inline + fn handle_as_inline<'a>( + &mut self, + results: &DfResults<'a, 'tcx>, + state: &'a InstructionState<'tcx>, + terminator: &Terminator<'tcx>, + location: Location, + ) -> Result>> { + let TerminatorKind::Call { + func, + args, + destination, + .. + } = &terminator.kind + else { + return Ok(false); + }; + let constructor = results.analysis.inner; + let gloc = GlobalLocation { + location: location.into(), + function: constructor.def_id.to_def_id(), + }; + + let Some(handling) = constructor.determine_call_handling( + location, + func, + args, + terminator.source_info.span, + )? + else { + return Ok(false); + }; + + let (child_descriptor, calling_convention) = match handling { + CallHandling::Ready { + calling_convention, + descriptor, + } => (descriptor, calling_convention), + CallHandling::ApproxAsyncFn => { + // Register a synthetic assignment of `future = (arg0, arg1, ...)`. 
+ let rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Tuple), + IndexVec::from_iter(args.iter().cloned()), + ); + self.modular_mutation_visitor(results, state).visit_assign( + destination, + &rvalue, + location, + ); + return Ok(true); + } + CallHandling::ApproxAsyncSM(how) => { + how( + constructor, + &mut self.modular_mutation_visitor(results, state), + args, + *destination, + location, + ); + return Ok(true); + } + }; + + let child_graph = push_call_string_root(child_descriptor, gloc); + + trace!("Child graph has generics {:?}", child_descriptor.generics); + + let is_root = |n: CallString| n.len() == 2; + + // For each source node CHILD that is parentable to PLACE, + // add an edge from PLACE -> CHILD. + trace!("PARENT -> CHILD EDGES:"); + for (child_src, _kind) in child_graph.parentable_srcs(is_root) { + if let Some(parent_place) = calling_convention.translate_to_parent( + child_src.place, + constructor.async_info(), + constructor.tcx(), + &constructor.body, + constructor.def_id.to_def_id(), + *destination, + Some(child_src.place.ty(child_descriptor, constructor.tcx())), + ) { + self.register_mutation( + results, + state, + Inputs::Unresolved { + places: vec![(parent_place, None)], + }, + Either::Right(child_src), + location, + TargetUse::Assign, + ); + } + } + + // For each destination node CHILD that is parentable to PLACE, + // add an edge from CHILD -> PLACE. + // + // PRECISION TODO: for a given child place, we only want to connect + // the *last* nodes in the child function to the parent, not *all* of them. 
+ trace!("CHILD -> PARENT EDGES:"); + for (child_dst, kind) in child_graph.parentable_dsts(is_root) { + if let Some(parent_place) = calling_convention.translate_to_parent( + child_dst.place, + constructor.async_info(), + constructor.tcx(), + &constructor.body, + constructor.def_id.to_def_id(), + *destination, + Some(child_dst.place.ty(child_descriptor, constructor.tcx())), + ) { + self.register_mutation( + results, + state, + Inputs::Resolved { + node: child_dst, + node_use: SourceUse::Operand, + }, + Either::Left(parent_place), + location, + kind.map_or(TargetUse::Return, TargetUse::MutArg), + ); + } + } + self.nodes.extend(child_graph.nodes); + self.edges.extend(child_graph.edges); + self.monos.extend(child_graph.monos); + self.monos + .insert(CallString::single(gloc), child_descriptor.generics); + Ok(true) } - fn register_mutation( + fn register_mutation<'a>( &mut self, - results: &Results<'tcx, DfAnalysis<'_, 'tcx>>, + results: &DfResults<'a, 'tcx>, state: &InstructionState<'tcx>, inputs: Inputs<'tcx>, mutated: Either, DepNode<'tcx>>, @@ -427,7 +660,7 @@ impl<'tcx> PartialGraph<'tcx> { target_use: TargetUse, ) { trace!("Registering mutation to {mutated:?} with inputs {inputs:?} at {location:?}"); - let constructor = results.analysis.0; + let constructor = results.analysis.inner; let ctrl_inputs = constructor.find_control_inputs(location); trace!(" Found control inputs {ctrl_inputs:?}"); @@ -453,10 +686,8 @@ impl<'tcx> PartialGraph<'tcx> { let outputs = match mutated { Either::Right(node) => vec![node], - Either::Left(place) => results - .analysis - .0 - .find_outputs(state, place, location) + Either::Left(place) => constructor + .find_outputs(place, location) .into_iter() .map(|t| t.1) .collect(), @@ -490,33 +721,8 @@ impl<'tcx> PartialGraph<'tcx> { } } -pub(crate) struct CallingContext<'tcx> { - pub(crate) call_string: CallString, - pub(crate) param_env: ParamEnv<'tcx>, - pub(crate) call_stack: Vec, -} - -type PdgCache<'tcx> = Rc>>>; - -pub struct 
GraphConstructor<'tcx> { - pub(crate) tcx: TyCtxt<'tcx>, - pub(crate) params: PdgParams<'tcx>, - body_with_facts: &'tcx BodyWithBorrowckFacts<'tcx>, - pub(crate) body: Cow<'tcx, Body<'tcx>>, - pub(crate) def_id: LocalDefId, - place_info: PlaceInfo<'tcx>, - control_dependencies: ControlDependencies, - pub(crate) body_assignments: utils::BodyAssignments, - pub(crate) calling_context: Option>, - start_loc: FxHashSet, - pub(crate) async_info: Rc, - pub(crate) pdg_cache: PdgCache<'tcx>, -} - -fn other_as_arg<'tcx>(place: Place<'tcx>, body: &Body<'tcx>) -> Option { - (body.local_kind(place.local) == rustc_middle::mir::LocalKind::Arg) - .then(|| place.local.as_u32() as u8 - 1) -} +type PdgCache<'tcx> = + Rc), Result, Vec>>>>; #[derive(Debug)] enum Inputs<'tcx> { @@ -529,705 +735,9 @@ enum Inputs<'tcx> { }, } -impl<'tcx> GraphConstructor<'tcx> { - /// Creates a [`GraphConstructor`] at the root of the PDG. - pub fn root(params: PdgParams<'tcx>) -> Self { - let tcx = params.tcx; - GraphConstructor::new( - params, - None, - AsyncInfo::make(tcx).expect("async functions are not defined"), - &PdgCache::default(), - ) - } - - /// Creates [`GraphConstructor`] for a function resolved as `fn_resolution` in a given `calling_context`. 
- pub(crate) fn new( - params: PdgParams<'tcx>, - calling_context: Option>, - async_info: Rc, - pdg_cache: &PdgCache<'tcx>, - ) -> Self { - let tcx = params.tcx; - let def_id = params.root.def_id().expect_local(); - let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, def_id); - let param_env = match &calling_context { - Some(cx) => cx.param_env, - None => ParamEnv::reveal_all(), - }; - let body = params - .root - .try_monomorphize(tcx, param_env, &body_with_facts.body); - - if params.dump_mir { - use std::io::Write; - let path = tcx.def_path_str(def_id) + ".mir"; - let mut f = std::fs::File::create(path.as_str()).unwrap(); - write!(f, "{}", body.to_string(tcx).unwrap()).unwrap(); - debug!("Dumped debug MIR {path}"); - } - - let place_info = PlaceInfo::build(tcx, def_id.to_def_id(), body_with_facts); - let control_dependencies = body.control_dependencies(); - - let mut start_loc = FxHashSet::default(); - start_loc.insert(RichLocation::Start); - - let body_assignments = utils::find_body_assignments(&body); - let pdg_cache = Rc::clone(pdg_cache); - - GraphConstructor { - tcx, - params, - body_with_facts, - body, - place_info, - control_dependencies, - start_loc, - def_id, - calling_context, - body_assignments, - async_info, - pdg_cache, - } - } - - /// Creates a [`GlobalLocation`] at the current function. 
- fn make_global_loc(&self, location: impl Into) -> GlobalLocation { - GlobalLocation { - function: self.def_id, - location: location.into(), - } - } - - pub(crate) fn calling_context_for( - &self, - call_stack_extension: DefId, - location: Location, - ) -> CallingContext<'tcx> { - CallingContext { - call_string: self.make_call_string(location), - param_env: self.tcx.param_env_reveal_all_normalized(self.def_id), - call_stack: match &self.calling_context { - Some(cx) => { - let mut cx = cx.call_stack.clone(); - cx.push(call_stack_extension); - cx - } - None => vec![], - }, - } - } - - pub(crate) fn pdg_params_for_call(&self, root: FnResolution<'tcx>) -> PdgParams<'tcx> { - PdgParams { - root, - ..self.params.clone() - } - } - - /// Creates a [`CallString`] with the current function at the root, - /// with the rest of the string provided by the [`CallingContext`]. - fn make_call_string(&self, location: impl Into) -> CallString { - let global_loc = self.make_global_loc(location); - match &self.calling_context { - Some(cx) => cx.call_string.push(global_loc), - None => CallString::single(global_loc), - } - } - - fn make_dep_node( - &self, - place: Place<'tcx>, - location: impl Into, - ) -> DepNode<'tcx> { - DepNode::new(place, self.make_call_string(location), self.tcx, &self.body) - } - - /// Returns all pairs of `(src, edge)`` such that the given `location` is control-dependent on `edge` - /// with input `src`. - fn find_control_inputs(&self, location: Location) -> Vec<(DepNode<'tcx>, DepEdge)> { - let mut blocks_seen = HashSet::::from_iter(Some(location.block)); - let mut block_queue = vec![location.block]; - let mut out = vec![]; - while let Some(block) = block_queue.pop() { - if let Some(ctrl_deps) = self.control_dependencies.dependent_on(block) { - for dep in ctrl_deps.iter() { - let ctrl_loc = self.body.terminator_loc(dep); - let Terminator { - kind: TerminatorKind::SwitchInt { discr, .. }, - .. 
- } = self.body.basic_blocks[dep].terminator() - else { - if blocks_seen.insert(dep) { - block_queue.push(dep); - } - continue; - }; - let Some(ctrl_place) = discr.place() else { - continue; - }; - let at = self.make_call_string(ctrl_loc); - let src = DepNode::new(ctrl_place, at, self.tcx, &self.body); - let edge = DepEdge::control(at, SourceUse::Operand, TargetUse::Assign); - out.push((src, edge)); - } - } - } - out - } - - /// Returns the aliases of `place`. See [`PlaceInfo::aliases`] for details. - pub(crate) fn aliases(&self, place: Place<'tcx>) -> impl Iterator> + '_ { - // MASSIVE HACK ALERT: - // The issue is that monomorphization erases regions, due to how it's implemented in rustc. - // However, Flowistry's alias analysis uses regions to figure out aliases. - // To workaround this incompatibility, when we receive a monomorphized place, we try to - // recompute its type in the context of the original region-containing body as far as possible. - // - // For example, say _2: (&'0 impl Foo,) in the original body and _2: (&(i32, i32),) in the monomorphized body. - // Say we ask for aliases (*(_2.0)).0. Then we will retype ((*_2.0).0).0 and receive back (*_2.0: &'0 impl Foo). - // We can ask for the aliases in the context of the original body, receiving e.g. {_1}. - // Then we reproject the aliases with the remaining projection, to create {_1.0}. - // - // This is a massive hack bc it's inefficient and I'm not certain that it's sound. - let place_retyped = utils::retype_place( - place, - self.tcx, - &self.body_with_facts.body, - self.def_id.to_def_id(), - ); - self.place_info.aliases(place_retyped).iter().map(|alias| { - let mut projection = alias.projection.to_vec(); - projection.extend(&place.projection[place_retyped.projection.len()..]); - Place::make(alias.local, &projection, self.tcx) - }) - } - - /// Returns all nodes `src` such that `src` is: - /// 1. Part of the value of `input` - /// 2. 
The most-recently modified location for `src` - fn find_data_inputs( - &self, - state: &InstructionState<'tcx>, - input: Place<'tcx>, - ) -> Vec> { - // Include all sources of indirection (each reference in the chain) as relevant places. - let provenance = input - .refs_in_projection(self.place_info.body, self.place_info.tcx) - .map(|(place_ref, _)| Place::from_ref(place_ref, self.tcx)); - let inputs = iter::once(input).chain(provenance); - - inputs - // **POINTER-SENSITIVITY:** - // If `input` involves indirection via dereferences, then resolve it to the direct places it could point to. - .flat_map(|place| self.aliases(place)) - .flat_map(|alias| { - // **FIELD-SENSITIVITY:** - // Find all places that have been mutated which conflict with `alias.` - let conflicts = state - .last_mutation - .iter() - .map(|(k, locs)| (*k, locs)) - .filter(move |(place, _)| { - if place.is_indirect() && place.is_arg(&self.body) { - // HACK: `places_conflict` seems to consider it a bug is `borrow_place` - // includes a dereference, which should only happen if `borrow_place` - // is an argument. So we special case that condition and just compare for local equality. - // - // TODO: this is not field-sensitive! - place.local == alias.local - } else { - let mut place = *place; - if let Some((PlaceElem::Deref, rest)) = place.projection.split_last() { - let mut new_place = place; - new_place.projection = self.tcx.mk_place_elems(rest); - if new_place.ty(self.body.as_ref(), self.tcx).ty.is_box() { - if new_place.is_indirect() { - // TODO might be unsound: We assume that if - // there are other indirections in here, - // there is an alias that does not have - // indirections in it. - return false; - } - place = new_place; - } - } - places_conflict( - self.tcx, - &self.body, - place, - alias, - PlaceConflictBias::Overlap, - ) - } - }); - - // Special case: if the `alias` is an un-mutated argument, then include it as a conflict - // coming from the special start location. 
- let alias_last_mut = if alias.is_arg(&self.body) { - Some((alias, &self.start_loc)) - } else { - None - }; - - // For each `conflict`` last mutated at the locations `last_mut`: - conflicts - .chain(alias_last_mut) - .flat_map(|(conflict, last_mut_locs)| { - // For each last mutated location: - last_mut_locs.iter().map(move |last_mut_loc| { - // Return @ as an input node. - let at = self.make_call_string(*last_mut_loc); - DepNode::new(conflict, at, self.tcx, &self.body) - }) - }) - }) - .collect() - } - - fn find_outputs( - &self, - _state: &InstructionState<'tcx>, - mutated: Place<'tcx>, - location: Location, - ) -> Vec<(Place<'tcx>, DepNode<'tcx>)> { - // **POINTER-SENSITIVITY:** - // If `mutated` involves indirection via dereferences, then resolve it to the direct places it could point to. - let aliases = self.aliases(mutated); - - // **FIELD-SENSITIVITY:** we do NOT deal with fields on *writes* (in this function), - // only on *reads* (in `add_input_to_op`). - - // For each mutated `dst`: - aliases - .map(|dst| { - // Create a destination node for (DST @ CURRENT_LOC). - ( - dst, - DepNode::new(dst, self.make_call_string(location), self.tcx, &self.body), - ) - }) - .collect() - } - - /// Updates the last-mutated location for `dst` to the given `location`. - fn apply_mutation( - &self, - state: &mut InstructionState<'tcx>, - location: Location, - mutated: Place<'tcx>, - ) { - self.find_outputs(state, mutated, location) - .into_iter() - .for_each(|(dst, _)| { - // Create a destination node for (DST @ CURRENT_LOC). - - // Clear all previous mutations. - let dst_mutations = state.last_mutation.entry(dst).or_default(); - dst_mutations.clear(); - - // Register that `dst` is mutated at the current location. - dst_mutations.insert(RichLocation::Location(location)); - }) - } - - /// Resolve a function [`Operand`] to a specific [`DefId`] and generic arguments if possible. 
- pub(crate) fn operand_to_def_id( - &self, - func: &Operand<'tcx>, - ) -> Option<(DefId, &'tcx List>)> { - let ty = match func { - Operand::Constant(func) => func.literal.ty(), - Operand::Copy(place) | Operand::Move(place) => { - place.ty(&self.body.local_decls, self.tcx).ty - } - }; - let ty = utils::ty_resolve(ty, self.tcx); - match ty.kind() { - TyKind::FnDef(def_id, generic_args) => Some((*def_id, generic_args)), - TyKind::Generator(def_id, generic_args, _) => Some((*def_id, generic_args)), - ty => { - trace!("Bailing from handle_call because func is literal with type: {ty:?}"); - None - } - } - } - - fn fmt_fn(&self, def_id: DefId) -> String { - self.tcx.def_path_str(def_id) - } - - /// Special case behavior for calls to functions used in desugaring async functions. - /// - /// Ensures that functions like `Pin::new_unchecked` are not modularly-approximated. - fn can_approximate_async_functions(&self, def_id: DefId) -> Option> { - let lang_items = self.tcx.lang_items(); - if Some(def_id) == lang_items.new_unchecked_fn() { - Some(Self::approximate_new_unchecked) - } else if Some(def_id) == lang_items.into_future_fn() - // FIXME: better way to get retrieve this stdlib DefId? 
- || self.tcx.def_path_str(def_id) == "::into_future" - { - Some(Self::approximate_into_future) - } else { - None - } - } - - fn approximate_into_future( - &self, - vis: &mut dyn Visitor<'tcx>, - args: &[Operand<'tcx>], - destination: Place<'tcx>, - location: Location, - ) { - trace!("Handling into_future as assign for {destination:?}"); - let [op] = args else { - unreachable!(); - }; - vis.visit_assign(&destination, &Rvalue::Use(op.clone()), location); - } - - fn approximate_new_unchecked( - &self, - vis: &mut dyn Visitor<'tcx>, - args: &[Operand<'tcx>], - destination: Place<'tcx>, - location: Location, - ) { - let lang_items = self.tcx.lang_items(); - let [op] = args else { - unreachable!(); - }; - let mut operands = IndexVec::new(); - operands.push(op.clone()); - let TyKind::Adt(adt_id, generics) = destination.ty(self.body.as_ref(), self.tcx).ty.kind() - else { - unreachable!() - }; - assert_eq!(adt_id.did(), lang_items.pin_type().unwrap()); - let aggregate_kind = - AggregateKind::Adt(adt_id.did(), VariantIdx::from_u32(0), generics, None, None); - let rvalue = Rvalue::Aggregate(Box::new(aggregate_kind), operands); - trace!("Handling new_unchecked as assign for {destination:?}"); - vis.visit_assign(&destination, &rvalue, location); - } - - fn determine_call_handling<'a>( - &self, - location: Location, - func: &Operand<'tcx>, - args: &'a [Operand<'tcx>], - ) -> Option> { - let tcx = self.tcx; - - let (called_def_id, generic_args) = self.operand_to_def_id(func)?; - trace!("Resolved call to function: {}", self.fmt_fn(called_def_id)); - - // Monomorphize the called function with the known generic_args. 
- let param_env = tcx.param_env_reveal_all_normalized(self.def_id); - let resolved_fn = - utils::try_resolve_function(self.tcx, called_def_id, param_env, generic_args); - let resolved_def_id = resolved_fn.def_id(); - if log_enabled!(Level::Trace) && called_def_id != resolved_def_id { - let (called, resolved) = (self.fmt_fn(called_def_id), self.fmt_fn(resolved_def_id)); - trace!(" `{called}` monomorphized to `{resolved}`",); - } - - if is_non_default_trait_method(tcx, resolved_def_id).is_some() { - trace!(" bailing because is unresolvable trait method"); - return None; - } - - // Don't inline recursive calls. - if let Some(cx) = &self.calling_context { - if cx.call_stack.contains(&resolved_def_id) { - trace!(" Bailing due to recursive call"); - return None; - } - } - - if let Some(handler) = self.can_approximate_async_functions(resolved_def_id) { - return Some(CallHandling::ApproxAsyncSM(handler)); - }; - - if !resolved_def_id.is_local() { - trace!( - " Bailing because func is non-local: `{}`", - tcx.def_path_str(resolved_def_id) - ); - return None; - }; - - let call_kind = match self.classify_call_kind(called_def_id, resolved_def_id, args) { - Ok(cc) => cc, - Err(async_err) => { - if let Some(cb) = self.params.call_change_callback.as_ref() { - cb.on_inline_miss( - resolved_fn, - location, - self.params.root, - self.calling_context.as_ref().map(|s| s.call_string), - InlineMissReason::Async(async_err), - ) - } - return None; - } - }; - - let calling_convention = CallingConvention::from_call_kind(&call_kind, args); - - trace!( - " Handling call! with kind {}", - match &call_kind { - CallKind::Direct => "direct", - CallKind::Indirect => "indirect", - CallKind::AsyncPoll { .. } => "async poll", - } - ); - - // Recursively generate the PDG for the child function. 
- let params = self.pdg_params_for_call(resolved_fn); - let calling_context = self.calling_context_for(resolved_def_id, location); - let call_string = calling_context.call_string; - - let cache_key = call_string.push(GlobalLocation { - function: resolved_fn.def_id().expect_local(), - location: RichLocation::Start, - }); - - let is_cached = self.pdg_cache.is_in_cache(&cache_key); - - let call_changes = self.params.call_change_callback.as_ref().map(|callback| { - let info = CallInfo { - callee: resolved_fn, - call_string, - is_cached, - async_parent: if let CallKind::AsyncPoll(resolution, _loc, _) = call_kind { - // Special case for async. We ask for skipping not on the closure, but - // on the "async" function that created it. This is needed for - // consistency in skipping. Normally, when "poll" is inlined, mutations - // introduced by the creator of the future are not recorded and instead - // handled here, on the closure. But if the closure is skipped we need - // those mutations to occur. To ensure this we always ask for the - // "CallChanges" on the creator so that both creator and closure have - // the same view of whether they are inlined or "Skip"ped. - Some(resolution) - } else { - None - }, - }; - callback.on_inline(info) - }); - - // Handle async functions at the time of polling, not when the future is created. - if tcx.asyncness(resolved_def_id).is_async() { - trace!(" Bailing because func is async"); - - // If a skip was requested then "poll" will not be inlined later so we - // bail with "None" here and perform the mutations. Otherwise we bail with - // "Some", knowing that handling "poll" later will handle the mutations. - return (!matches!( - &call_changes, - Some(CallChanges { - skip: SkipCall::Skip, - .. - }) - )) - .then_some(CallHandling::ApproxAsyncFn); - } - - if matches!( - call_changes, - Some(CallChanges { - skip: SkipCall::Skip, - .. 
- }) - ) { - trace!(" Bailing because user callback said to bail"); - return None; - } - - let child_constructor = GraphConstructor::new( - params, - Some(calling_context), - self.async_info.clone(), - &self.pdg_cache, - ); - Some(CallHandling::Ready(child_constructor, calling_convention)) - } - - /// Attempt to inline a call to a function, returning None if call is not inline-able. - fn handle_call( - &self, - state: &mut InstructionState<'tcx>, - location: Location, - func: &Operand<'tcx>, - args: &[Operand<'tcx>], - destination: Place<'tcx>, - ) -> Option<()> { - // Note: my comments here will use "child" to refer to the callee and - // "parent" to refer to the caller, since the words are most visually distinct. - - let preamble = self.determine_call_handling(location, func, args)?; - - let (child_constructor, calling_convention) = match preamble { - CallHandling::Ready(child_constructor, calling_convention) => { - (child_constructor, calling_convention) - } - CallHandling::ApproxAsyncFn => { - // Register a synthetic assignment of `future = (arg0, arg1, ...)`. 
- let rvalue = Rvalue::Aggregate( - Box::new(AggregateKind::Tuple), - IndexVec::from_iter(args.iter().cloned()), - ); - self.modular_mutation_visitor(state) - .visit_assign(&destination, &rvalue, location); - return Some(()); - } - CallHandling::ApproxAsyncSM(handler) => { - handler( - self, - &mut self.modular_mutation_visitor(state), - args, - destination, - location, - ); - return Some(()); - } - }; - - let child_graph = child_constructor.construct_partial_cached(); - - let parentable_dsts = - child_graph.parentable_dsts(child_constructor.def_id, &child_constructor.body); - let parent_body = &self.body; - let translate_to_parent = |child: Place<'tcx>| -> Option> { - calling_convention.translate_to_parent( - child, - &self.async_info, - self.tcx, - parent_body, - self.def_id.to_def_id(), - destination, - ) - }; - - // For each destination node CHILD that is parentable to PLACE, - // add an edge from CHILD -> PLACE. - // - // PRECISION TODO: for a given child place, we only want to connect - // the *last* nodes in the child function to the parent, not *all* of them. - trace!("CHILD -> PARENT EDGES:"); - for (child_dst, _) in parentable_dsts { - if let Some(parent_place) = translate_to_parent(child_dst.place) { - self.apply_mutation(state, location, parent_place); - } - } - trace!( - " Inlined {}", - self.fmt_fn(child_constructor.def_id.to_def_id()) - ); - - Some(()) - } - - fn modular_mutation_visitor<'a>( - &'a self, - state: &'a mut InstructionState<'tcx>, - ) -> ModularMutationVisitor<'a, 'tcx, impl FnMut(Location, Mutation<'tcx>) + 'a> { - ModularMutationVisitor::new( - &self.place_info, - move |location, mutation: Mutation<'tcx>| { - self.apply_mutation(state, location, mutation.mutated) - }, - ) - } - - fn handle_terminator( - &self, - terminator: &Terminator<'tcx>, - state: &mut InstructionState<'tcx>, - location: Location, - time: Time, - ) { - if let TerminatorKind::Call { - func, - args, - destination, - .. 
- } = &terminator.kind - { - if self - .handle_call(state, location, func, args, *destination) - .is_none() - { - trace!("Terminator {:?} failed the preamble", terminator.kind); - self.terminator_visitor(state, time) - .visit_terminator(terminator, location) - } - } else { - // Fallback: call the visitor - self.terminator_visitor(state, time) - .visit_terminator(terminator, location) - } - } - - fn construct_partial_cached(&self) -> Rc> { - let key = self.make_call_string(RichLocation::Start); - let pdg = self - .pdg_cache - .get(key, move |_| Rc::new(self.construct_partial())); - Rc::clone(pdg) - } - - pub(crate) fn construct_partial(&self) -> PartialGraph<'tcx> { - if let Some(g) = self.try_handle_as_async() { - return g; - } - - let mut analysis = DfAnalysis(self) - .into_engine(self.tcx, &self.body) - .iterate_to_fixpoint(); - - let mut final_state = PartialGraph::default(); - - analysis.visit_reachable_with(&self.body, &mut final_state); - - let all_returns = self.body.all_returns().map(|ret| ret.block).collect_vec(); - let has_return = !all_returns.is_empty(); - let mut analysis = analysis.into_results_cursor(&self.body); - if has_return { - for block in all_returns { - analysis.seek_to_block_end(block); - let return_state = analysis.get(); - for (place, locations) in &return_state.last_mutation { - let ret_kind = if place.local == RETURN_PLACE { - TargetUse::Return - } else if let Some(num) = other_as_arg(*place, &self.body) { - TargetUse::MutArg(num) - } else { - continue; - }; - for location in locations { - let src = self.make_dep_node(*place, *location); - let dst = self.make_dep_node(*place, RichLocation::End); - let edge = DepEdge::data( - self.make_call_string(self.body.terminator_loc(block)), - SourceUse::Operand, - ret_kind, - ); - final_state.edges.insert((src, dst, edge)); - } - } - } - } - - final_state - } - - fn domain_to_petgraph(self, domain: &PartialGraph<'tcx>) -> DepGraph<'tcx> { +impl<'tcx> PartialGraph<'tcx> { + pub fn to_petgraph(&self) 
-> DepGraph<'tcx> { + let domain = self; let mut graph: DiGraph, DepEdge> = DiGraph::new(); let mut nodes = FxHashMap::default(); macro_rules! add_node { @@ -1249,113 +759,13 @@ impl<'tcx> GraphConstructor<'tcx> { DepGraph::new(graph) } - pub fn construct(self) -> DepGraph<'tcx> { - let partial = self.construct_partial_cached(); - self.domain_to_petgraph(&partial) - } - - /// Determine the type of call-site. - /// - /// The error case is if we tried to resolve this as async and failed. We - /// know it *is* async but we couldn't determine the information needed to - /// analyze the function, therefore we will have to approximate it. - fn classify_call_kind<'a>( - &'a self, - def_id: DefId, - resolved_def_id: DefId, - original_args: &'a [Operand<'tcx>], - ) -> Result, String> { - match self.try_poll_call_kind(def_id, original_args) { - AsyncDeterminationResult::Resolved(r) => Ok(r), - AsyncDeterminationResult::NotAsync => Ok(self - .try_indirect_call_kind(resolved_def_id) - .unwrap_or(CallKind::Direct)), - AsyncDeterminationResult::Unresolvable(reason) => Err(reason), + fn check_invariants(&self) { + let root_function = self.nodes.iter().next().unwrap().at.root().function; + for n in &self.nodes { + assert_eq!(n.at.root().function, root_function); + } + for (_, _, e) in &self.edges { + assert_eq!(e.at.root().function, root_function); } - } - - fn try_indirect_call_kind(&self, def_id: DefId) -> Option> { - // let lang_items = self.tcx.lang_items(); - // let my_impl = self.tcx.impl_of_method(def_id)?; - // let my_trait = self.tcx.trait_id_of_impl(my_impl)?; - // (Some(my_trait) == lang_items.fn_trait() - // || Some(my_trait) == lang_items.fn_mut_trait() - // || Some(my_trait) == lang_items.fn_once_trait()) - // .then_some(CallKind::Indirect) - self.tcx.is_closure(def_id).then_some(CallKind::Indirect) - } - - fn terminator_visitor<'a>( - &'a self, - state: &'a mut InstructionState<'tcx>, - time: Time, - ) -> ModularMutationVisitor<'a, 'tcx, impl FnMut(Location, 
Mutation<'tcx>) + 'a> { - let mut vis = self.modular_mutation_visitor(state); - vis.set_time(time); - vis - } -} - -pub enum CallKind<'tcx> { - /// A standard function call like `f(x)`. - Direct, - /// A call to a function variable, like `fn foo(f: impl Fn()) { f() }` - Indirect, - /// A poll to an async function, like `f.await`. - AsyncPoll(FnResolution<'tcx>, Location, Place<'tcx>), -} - -type ApproximationHandler<'tcx> = - fn(&GraphConstructor<'tcx>, &mut dyn Visitor<'tcx>, &[Operand<'tcx>], Place<'tcx>, Location); - -enum CallHandling<'tcx, 'a> { - ApproxAsyncFn, - Ready(GraphConstructor<'tcx>, CallingConvention<'tcx, 'a>), - ApproxAsyncSM(ApproximationHandler<'tcx>), -} - -struct DfAnalysis<'a, 'tcx>(&'a GraphConstructor<'tcx>); - -impl<'tcx> df::AnalysisDomain<'tcx> for DfAnalysis<'_, 'tcx> { - type Domain = InstructionState<'tcx>; - - const NAME: &'static str = "GraphConstructor"; - - fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain { - InstructionState::default() - } - - fn initialize_start_block(&self, _body: &Body<'tcx>, _state: &mut Self::Domain) {} -} - -impl<'tcx> df::Analysis<'tcx> for DfAnalysis<'_, 'tcx> { - fn apply_statement_effect( - &mut self, - state: &mut Self::Domain, - statement: &Statement<'tcx>, - location: Location, - ) { - self.0 - .modular_mutation_visitor(state) - .visit_statement(statement, location) - } - - fn apply_terminator_effect<'mir>( - &mut self, - state: &mut Self::Domain, - terminator: &'mir Terminator<'tcx>, - location: Location, - ) -> TerminatorEdges<'mir, 'tcx> { - self.0 - .handle_terminator(terminator, state, location, Time::Unspecified); - terminator.edges() - } - - fn apply_call_return_effect( - &mut self, - _state: &mut Self::Domain, - _block: BasicBlock, - _return_places: rustc_middle::mir::CallReturnPlaces<'_, 'tcx>, - ) { } } diff --git a/crates/flowistry_pdg_construction/src/graph.rs b/crates/flowistry_pdg_construction/src/graph.rs index c64a0ded28..7275eaba95 100644 --- 
a/crates/flowistry_pdg_construction/src/graph.rs +++ b/crates/flowistry_pdg_construction/src/graph.rs @@ -1,23 +1,38 @@ //! The representation of the PDG. -use std::{fmt, hash::Hash, path::Path}; +use std::{ + fmt::{self, Display}, + hash::Hash, + path::Path, + rc::Rc, +}; -use flowistry_pdg::CallString; +use flowistry_pdg::{CallString, GlobalLocation}; use internment::Intern; use petgraph::{dot, graph::DiGraph}; + +use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hir::def_id::{DefId, DefIndex}; +use rustc_index::IndexVec; +use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable}; use rustc_middle::{ - mir::{Body, Place}, - ty::TyCtxt, + mir::{Body, HasLocalDecls, Local, LocalDecl, LocalDecls, Place}, + ty::{GenericArgsRef, TyCtxt}, }; +use rustc_serialize::{Decodable, Decoder, Encodable, Encoder}; +use rustc_span::Span; use rustc_utils::PlaceExt; -pub use flowistry_pdg::{SourceUse, TargetUse}; +pub use flowistry_pdg::{RichLocation, SourceUse, TargetUse}; +use serde::{Deserialize, Serialize}; + +use crate::{construct::Error, utils::Captures}; /// A node in the program dependency graph. /// /// Represents a place at a particular call-string. /// The place is in the body of the root of the call-string. -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, TyEncodable, TyDecodable)] pub struct DepNode<'tcx> { /// A place in memory in a particular body. pub place: Place<'tcx>, @@ -28,7 +43,10 @@ pub struct DepNode<'tcx> { /// Pretty representation of the place. /// This is cached as an interned string on [`DepNode`] because to compute it later, /// we would have to regenerate the entire monomorphized body for a given place. - pub(crate) place_pretty: Option>, + pub(crate) place_pretty: Option, + /// Does the PDG track subplaces of this place? 
+ pub is_split: bool, + pub span: Span, } impl PartialEq for DepNode<'_> { @@ -40,8 +58,15 @@ impl PartialEq for DepNode<'_> { place, at, place_pretty: _, + span, + is_split, } = *self; - (place, at).eq(&(other.place, other.at)) + let eq = (place, at).eq(&(other.place, other.at)); + if eq { + debug_assert_eq!(span, other.span); + debug_assert_eq!(is_split, other.is_split); + } + eq } } @@ -56,6 +81,8 @@ impl Hash for DepNode<'_> { place, at, place_pretty: _, + span: _, + is_split: _, } = self; (place, at).hash(state) } @@ -66,11 +93,29 @@ impl<'tcx> DepNode<'tcx> { /// /// The `tcx` and `body` arguments are used to precompute a pretty string /// representation of the [`DepNode`]. - pub fn new(place: Place<'tcx>, at: CallString, tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> Self { + pub fn new( + place: Place<'tcx>, + at: CallString, + tcx: TyCtxt<'tcx>, + body: &Body<'tcx>, + is_split: bool, + ) -> Self { + let i = at.leaf(); + let span = match i.location { + RichLocation::Location(loc) => { + let expanded_span = body + .stmt_at(loc) + .either(|s| s.source_info.span, |t| t.source_info.span); + tcx.sess.source_map().stmt_span(expanded_span, body.span) + } + RichLocation::Start | RichLocation::End => tcx.def_span(i.function), + }; DepNode { place, at, - place_pretty: place.to_string(tcx, body).map(Intern::new), + place_pretty: place.to_string(tcx, body).map(InternedString::new), + span, + is_split, } } } @@ -78,7 +123,7 @@ impl<'tcx> DepNode<'tcx> { impl DepNode<'_> { /// Returns a pretty string representation of the place, if one exists. pub fn place_pretty(&self) -> Option<&str> { - self.place_pretty.map(|s| s.as_ref().as_str()) + self.place_pretty.as_ref().map(|s| s.as_str()) } } @@ -93,7 +138,7 @@ impl fmt::Display for DepNode<'_> { } /// A kind of edge in the program dependence graph. 
-#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug)] +#[derive(PartialEq, Eq, Hash, Clone, Copy, Debug, Decodable, Encodable)] pub enum DepEdgeKind { /// X is control-dependent on Y if the value of Y influences the execution /// of statements that affect the value of X. @@ -107,7 +152,7 @@ pub enum DepEdgeKind { /// An edge in the program dependence graph. /// /// Represents an operation that induces a dependency between places. -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Decodable, Encodable)] pub struct DepEdge { /// Either data or control. pub kind: DepEdgeKind, @@ -189,3 +234,269 @@ impl<'tcx> DepGraph<'tcx> { rustc_utils::mir::body::run_dot(path.as_ref(), graph_dot.into_bytes()) } } + +#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd, Debug, Serialize, Deserialize)] +pub struct InternedString(Intern); + +impl Display for InternedString { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl InternedString { + pub fn new(s: String) -> Self { + Self(Intern::new(s)) + } + + pub fn as_str(&self) -> &str { + &self.0 + } +} + +impl From for InternedString { + fn from(value: String) -> Self { + Self::new(value) + } +} + +impl From<&'_ str> for InternedString { + fn from(value: &'_ str) -> Self { + Self(Intern::from_ref(value)) + } +} + +impl std::ops::Deref for InternedString { + type Target = String; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl Encodable for InternedString { + fn encode(&self, e: &mut E) { + let s: &String = &self.0; + s.encode(e); + } +} + +impl Decodable for InternedString { + fn decode(d: &mut D) -> Self { + Self(Intern::new(String::decode(d))) + } +} + +/// A PDG that is fit for combining with other PDGs +#[derive(Debug, Clone, TyDecodable, TyEncodable)] +pub struct PartialGraph<'tcx> { + pub(crate) nodes: FxHashSet>, + pub(crate) edges: FxHashSet<(DepNode<'tcx>, DepNode<'tcx>, DepEdge)>, + pub(crate) monos: FxHashMap>, + 
pub(crate) generics: GenericArgsRef<'tcx>, + def_id: DefId, + arg_count: usize, + local_decls: IndexVec>, +} + +impl<'tcx> HasLocalDecls<'tcx> for PartialGraph<'tcx> { + fn local_decls(&self) -> &LocalDecls<'tcx> { + &self.local_decls + } +} + +impl<'tcx> PartialGraph<'tcx> { + pub fn mentioned_call_string<'a>( + &'a self, + ) -> impl Iterator + Captures<'tcx> + 'a { + self.nodes + .iter() + .map(|n| &n.at) + .chain(self.edges.iter().map(|e| &e.2.at)) + .copied() + } + + pub fn get_mono(&self, cs: CallString) -> Option> { + if let Some(caller) = cs.caller() { + self.monos.get(&caller).copied() + } else { + Some(self.generics) + } + } + + pub fn new( + generics: GenericArgsRef<'tcx>, + def_id: DefId, + arg_count: usize, + local_decls: &LocalDecls<'tcx>, + ) -> Self { + Self { + nodes: Default::default(), + edges: Default::default(), + monos: Default::default(), + generics, + def_id, + arg_count, + local_decls: local_decls.to_owned(), + } + } + + /// Returns the set of source places that the parent can access (write to) + /// + /// Parameterized by a `is_at_root` function which returns whether a given + /// call string refers to a location in the outermost function. This is + /// necessary, because consumers of [`PartialGraph`] manipulate the call + /// string and as such we cannot assume that `.len() == 1` necessarily refers + /// to a root location. 
(TODO we probably should maintain that invariant) + pub(crate) fn parentable_srcs<'a>( + &'a self, + is_at_root: impl Fn(CallString) -> bool, + ) -> FxHashSet<(DepNode<'tcx>, Option)> { + self.edges + .iter() + .map(|(src, _, _)| *src) + .filter(|n| is_at_root(n.at) && n.at.leaf().location.is_start()) + .filter_map(move |a| Some((a, as_arg(&a, self.def_id, self.arg_count)?))) + .collect() + } + + /// Returns the set of destination places that the parent can access (read + /// from) + /// + /// Parameterized by a `is_at_root` function which returns whether a given + /// call string refers to a location in the outermost function. This is + /// necessary, because consumers of [`PartialGraph`] manipulate the call + /// string and as such we cannot assume that `.len() == 1` necessarily refers + /// to a root location. (TODO we probably should maintain that invariant) + pub(crate) fn parentable_dsts<'a>( + &'a self, + is_at_root: impl Fn(CallString) -> bool, + ) -> FxHashSet<(DepNode<'tcx>, Option)> { + self.edges + .iter() + .map(|(_, dst, _)| *dst) + .filter(|n| is_at_root(n.at) && n.at.leaf().location.is_end()) + .filter_map(move |a| Some((a, as_arg(&a, self.def_id, self.arg_count)?))) + .collect() + } +} + +fn as_arg(node: &DepNode<'_>, def_id: DefId, arg_count: usize) -> Option> { + if node.at.leaf().function != def_id { + return None; + } + let local = node.place.local.as_usize(); + if node.place.local == rustc_middle::mir::RETURN_PLACE { + Some(None) + } else if local > 0 && (local - 1) < arg_count { + Some(Some(node.place.local.as_u32() as u8 - 1)) + } else { + None + } +} + +impl<'tcx> TransformCallString for PartialGraph<'tcx> { + fn transform_call_string(&self, f: impl Fn(CallString) -> CallString) -> Self { + let recurse_node = |n: &DepNode<'tcx>| n.transform_call_string(&f); + Self { + generics: self.generics, + nodes: self.nodes.iter().map(recurse_node).collect(), + edges: self + .edges + .iter() + .map(|(from, to, e)| { + ( + recurse_node(from), + 
recurse_node(to), + e.transform_call_string(&f), + ) + }) + .collect(), + monos: self + .monos + .iter() + .map(|(cs, args)| (f(*cs), *args)) + .collect(), + def_id: self.def_id, + arg_count: self.arg_count, + local_decls: self.local_decls.to_owned(), + } + } +} + +pub type GraphLoaderError<'tcx> = Vec>; + +/// Abstracts over how previously written [`Artifact`]s are retrieved, allowing +/// the user of this module to chose where to store them. +pub trait GraphLoader<'tcx> { + /// Try loading the graph for this function. + /// + /// This is intended to return `Err` in cases where an expectation is + /// violated. For instance if we request a function from a crate that + /// *should* have been analyzed or if `function` does not refer to a + /// function item. + /// + /// This should return `Ok(None)` in cases where the target is not expected + /// to have it's partial graph present. For instance if `function` refers to + /// an item in a crate that was not selected for analysis. + fn load(&self, function: DefId) -> Result>, GraphLoaderError<'tcx>>; +} + +/// Intermediate data that gets stored for each crate. +pub type Artifact<'tcx> = FxHashMap>; + +/// An [`ArtifactLoader`] that always returns `Ok(None)`. 
+pub struct NoLoader; + +impl<'tcx> GraphLoader<'tcx> for NoLoader { + fn load(&self, _: DefId) -> Result>, GraphLoaderError<'tcx>> { + Ok(None) + } +} + +impl<'tcx, T: GraphLoader<'tcx>> GraphLoader<'tcx> for Rc { + fn load(&self, function: DefId) -> Result>, GraphLoaderError<'tcx>> { + (**self).load(function) + } +} + +impl<'tcx, T: GraphLoader<'tcx>> GraphLoader<'tcx> for Box { + fn load(&self, function: DefId) -> Result>, GraphLoaderError<'tcx>> { + (**self).load(function) + } +} + +pub(crate) trait TransformCallString { + fn transform_call_string(&self, f: impl Fn(CallString) -> CallString) -> Self; +} + +impl TransformCallString for CallString { + fn transform_call_string(&self, f: impl Fn(CallString) -> CallString) -> Self { + f(*self) + } +} + +impl TransformCallString for DepNode<'_> { + fn transform_call_string(&self, f: impl Fn(CallString) -> CallString) -> Self { + Self { + at: f(self.at), + ..*self + } + } +} + +impl TransformCallString for DepEdge { + fn transform_call_string(&self, f: impl Fn(CallString) -> CallString) -> Self { + Self { + at: f(self.at), + ..*self + } + } +} + +pub(crate) fn push_call_string_root( + old: &T, + new_root: GlobalLocation, +) -> T { + old.transform_call_string(|c| c.push_front(new_root)) +} diff --git a/crates/flowistry_pdg_construction/src/lib.rs b/crates/flowistry_pdg_construction/src/lib.rs index 4efc0f8988..c306579d04 100644 --- a/crates/flowistry_pdg_construction/src/lib.rs +++ b/crates/flowistry_pdg_construction/src/lib.rs @@ -7,33 +7,42 @@ extern crate rustc_borrowck; extern crate rustc_hash; extern crate rustc_hir; extern crate rustc_index; +extern crate rustc_macros; extern crate rustc_middle; extern crate rustc_mir_dataflow; +extern crate rustc_serialize; extern crate rustc_span; extern crate rustc_target; extern crate rustc_type_ir; -pub use utils::FnResolution; - -use self::graph::DepGraph; -pub use async_support::{determine_async, is_async_trait_fn, match_async_trait_assign}; -use 
construct::GraphConstructor; +pub use async_support::{determine_async, is_async_trait_fn, AsyncType}; +pub use construct::Error; +pub use graph::{Artifact, DepGraph, GraphLoader, NoLoader, PartialGraph}; pub mod callback; +pub use crate::construct::{ + default_emit_error, EmittableError, MemoPdgConstructor, UnwrapEmittable, +}; pub use callback::{ CallChangeCallback, CallChangeCallbackFn, CallChanges, CallInfo, InlineMissReason, SkipCall, }; -pub use construct::PdgParams; -pub use utils::{is_non_default_trait_method, try_resolve_function}; +use rustc_middle::ty::{Instance, TyCtxt}; +mod approximation; mod async_support; mod calling_convention; mod construct; pub mod graph; +mod local_analysis; mod mutation; -mod utils; +pub mod utils; /// Computes a global program dependence graph (PDG) starting from the root function specified by `def_id`. -pub fn compute_pdg(params: PdgParams<'_>) -> DepGraph<'_> { - let constructor = GraphConstructor::root(params); - constructor.construct() +pub fn compute_pdg<'tcx>(tcx: TyCtxt<'tcx>, params: Instance<'tcx>) -> DepGraph<'tcx> { + let constructor = MemoPdgConstructor::new(tcx, NoLoader); + constructor + .construct_for(params) + .unwrap() + .ok_or(Error::Impossible) + .unwrap() + .to_petgraph() } diff --git a/crates/flowistry_pdg_construction/src/local_analysis.rs b/crates/flowistry_pdg_construction/src/local_analysis.rs new file mode 100644 index 0000000000..b44886d01e --- /dev/null +++ b/crates/flowistry_pdg_construction/src/local_analysis.rs @@ -0,0 +1,776 @@ +use std::{collections::HashSet, iter, rc::Rc}; + +use flowistry::mir::placeinfo::PlaceInfo; +use flowistry_pdg::{CallString, GlobalLocation, RichLocation}; +use itertools::Itertools; +use log::{debug, log_enabled, trace, Level}; + +use rustc_borrowck::consumers::{places_conflict, BodyWithBorrowckFacts, PlaceConflictBias}; +use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_index::IndexVec; +use rustc_middle::{ + mir::{ + 
visit::Visitor, AggregateKind, BasicBlock, Body, HasLocalDecls, Location, Operand, Place, + PlaceElem, Rvalue, Statement, Terminator, TerminatorEdges, TerminatorKind, RETURN_PLACE, + }, + ty::{GenericArg, GenericArgKind, GenericArgsRef, Instance, List, TyCtxt, TyKind}, +}; +use rustc_mir_dataflow::{self as df, fmt::DebugWithContext, Analysis}; + +use rustc_span::Span; +use rustc_utils::{ + mir::{borrowck_facts, control_dependencies::ControlDependencies}, + BodyExt, PlaceExt, +}; + +use crate::{ + approximation::ApproximationHandler, + async_support::*, + calling_convention::*, + construct::{Error, WithConstructionErrors}, + graph::{DepEdge, DepNode, PartialGraph, SourceUse, TargetUse}, + mutation::{ModularMutationVisitor, Mutation, Time}, + utils::{self, is_async, is_non_default_trait_method, try_monomorphize}, + CallChangeCallback, CallChanges, CallInfo, MemoPdgConstructor, SkipCall, +}; + +#[derive(PartialEq, Eq, Default, Clone, Debug)] +pub(crate) struct InstructionState<'tcx> { + last_mutation: FxHashMap, FxHashSet>, +} + +impl DebugWithContext for InstructionState<'_> {} + +impl<'tcx> df::JoinSemiLattice for InstructionState<'tcx> { + fn join(&mut self, other: &Self) -> bool { + utils::hashmap_join( + &mut self.last_mutation, + &other.last_mutation, + utils::hashset_join, + ) + } +} + +pub(crate) struct LocalAnalysis<'tcx, 'a> { + pub(crate) memo: &'a MemoPdgConstructor<'tcx>, + pub(super) root: Instance<'tcx>, + body_with_facts: &'tcx BodyWithBorrowckFacts<'tcx>, + pub(crate) body: Body<'tcx>, + pub(crate) def_id: LocalDefId, + pub(crate) place_info: PlaceInfo<'tcx>, + control_dependencies: ControlDependencies, + pub(crate) body_assignments: utils::BodyAssignments, + start_loc: FxHashSet, +} + +impl<'tcx, 'a> LocalAnalysis<'tcx, 'a> { + /// Creates [`GraphConstructor`] for a function resolved as `fn_resolution` in a given `calling_context`. 
+ pub(crate) fn new( + memo: &'a MemoPdgConstructor<'tcx>, + root: Instance<'tcx>, + ) -> Result, Error<'tcx>> { + let tcx = memo.tcx; + let def_id = root.def_id().expect_local(); + let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, def_id); + let param_env = tcx.param_env_reveal_all_normalized(def_id); + // let param_env = match &calling_context { + // Some(cx) => cx.param_env, + // None => ParamEnv::reveal_all(), + // }; + let body = try_monomorphize( + root, + tcx, + param_env, + &body_with_facts.body, + tcx.def_span(root.def_id()), + )?; + + if memo.dump_mir { + use std::io::Write; + let path = tcx.def_path_str(def_id) + ".mir"; + let mut f = std::fs::File::create(path.as_str()).unwrap(); + write!(f, "{}", body.to_string(tcx).unwrap()).unwrap(); + debug!("Dumped debug MIR {path}"); + } + + let place_info = PlaceInfo::build(tcx, def_id.to_def_id(), body_with_facts); + let control_dependencies = body.control_dependencies(); + + let mut start_loc = FxHashSet::default(); + start_loc.insert(RichLocation::Start); + + let body_assignments = utils::find_body_assignments(&body); + + Ok(LocalAnalysis { + memo, + root, + body_with_facts, + body, + place_info, + control_dependencies, + start_loc, + def_id, + body_assignments, + }) + } + + fn make_dep_node( + &self, + place: Place<'tcx>, + location: impl Into, + ) -> DepNode<'tcx> { + DepNode::new( + place, + self.make_call_string(location), + self.tcx(), + &self.body, + self.place_info.children(place).iter().any(|p| *p != place), + ) + } + + /// Returns all pairs of `(src, edge)`` such that the given `location` is control-dependent on `edge` + /// with input `src`. 
+ pub(crate) fn find_control_inputs(&self, location: Location) -> Vec<(DepNode<'tcx>, DepEdge)> { + let mut blocks_seen = HashSet::::from_iter(Some(location.block)); + let mut block_queue = vec![location.block]; + let mut out = vec![]; + while let Some(block) = block_queue.pop() { + if let Some(ctrl_deps) = self.control_dependencies.dependent_on(block) { + for dep in ctrl_deps.iter() { + let ctrl_loc = self.body.terminator_loc(dep); + let Terminator { + kind: TerminatorKind::SwitchInt { discr, .. }, + .. + } = self.body.basic_blocks[dep].terminator() + else { + if blocks_seen.insert(dep) { + block_queue.push(dep); + } + continue; + }; + let Some(ctrl_place) = discr.place() else { + continue; + }; + let at = self.make_call_string(ctrl_loc); + let src = self.make_dep_node(ctrl_place, ctrl_loc); + let edge = DepEdge::control(at, SourceUse::Operand, TargetUse::Assign); + out.push((src, edge)); + } + } + } + out + } + + fn call_change_callback(&self) -> Option<&dyn CallChangeCallback<'tcx>> { + self.memo.call_change_callback.as_ref().map(Rc::as_ref) + } + + pub(crate) fn async_info(&self) -> &AsyncInfo { + &self.memo.async_info + } + + pub(crate) fn make_call_string(&self, location: impl Into) -> CallString { + CallString::single(GlobalLocation { + function: self.root.def_id(), + location: location.into(), + }) + } + + /// Returns the aliases of `place`. See [`PlaceInfo::aliases`] for details. + pub(crate) fn aliases(&'a self, place: Place<'tcx>) -> impl Iterator> + 'a { + // MASSIVE HACK ALERT: + // The issue is that monomorphization erases regions, due to how it's implemented in rustc. + // However, Flowistry's alias analysis uses regions to figure out aliases. + // To workaround this incompatibility, when we receive a monomorphized place, we try to + // recompute its type in the context of the original region-containing body as far as possible. + // + // For example, say _2: (&'0 impl Foo,) in the original body and _2: (&(i32, i32),) in the monomorphized body. 
+ // Say we ask for aliases (*(_2.0)).0. Then we will retype ((*_2.0).0).0 and receive back (*_2.0: &'0 impl Foo). + // We can ask for the aliases in the context of the original body, receiving e.g. {_1}. + // Then we reproject the aliases with the remaining projection, to create {_1.0}. + // + // This is a massive hack bc it's inefficient and I'm not certain that it's sound. + let place_retyped = utils::retype_place( + place, + self.tcx(), + &self.body_with_facts.body, + self.def_id.to_def_id(), + None, + ); + self.place_info + .aliases(place_retyped) + .iter() + .map(move |alias| { + let mut projection = alias.projection.to_vec(); + projection.extend(&place.projection[place_retyped.projection.len()..]); + let p = Place::make(alias.local, &projection, self.tcx()); + // let t1 = place.ty(&self.body, self.tcx()); + // let t2 = p.ty(&self.body, self.tcx()); + // if !t1.equiv(&t2) { + // let p1_str = format!("{place:?}"); + // let p2_str = format!("{p:?}"); + // let l = p1_str.len().max(p2_str.len()); + // panic!("Retyping in {} failed to produce an equivalent type.\n Src {p1_str:l$} : {t1:?}\n Dst {p2_str:l$} : {t2:?}", self.tcx().def_path_str(self.def_id)) + // } + p + }) + } + + pub(crate) fn tcx(&self) -> TyCtxt<'tcx> { + self.memo.tcx + } + + /// Returns all nodes `src` such that `src` is: + /// 1. Part of the value of `input` + /// 2. The most-recently modified location for `src` + pub(crate) fn find_data_inputs( + &self, + state: &InstructionState<'tcx>, + input: Place<'tcx>, + ) -> Vec> { + // Include all sources of indirection (each reference in the chain) as relevant places. + let provenance = input + .refs_in_projection(&self.body, self.tcx()) + .map(|(place_ref, _)| Place::from_ref(place_ref, self.tcx())); + let inputs = iter::once(input).chain(provenance); + + inputs + // **POINTER-SENSITIVITY:** + // If `input` involves indirection via dereferences, then resolve it to the direct places it could point to. 
+ .flat_map(|place| self.aliases(place)) + .flat_map(|alias| { + // **FIELD-SENSITIVITY:** + // Find all places that have been mutated which conflict with `alias.` + let conflicts = state + .last_mutation + .iter() + .map(|(k, locs)| (*k, locs)) + .filter(move |(place, _)| { + if place.is_indirect() && place.is_arg(&self.body) { + // HACK: `places_conflict` seems to consider it a bug is `borrow_place` + // includes a dereference, which should only happen if `borrow_place` + // is an argument. So we special case that condition and just compare for local equality. + // + // TODO: this is not field-sensitive! + place.local == alias.local + } else { + let mut place = *place; + if let Some((PlaceElem::Deref, rest)) = place.projection.split_last() { + let mut new_place = place; + new_place.projection = self.tcx().mk_place_elems(rest); + if new_place.ty(&self.body, self.tcx()).ty.is_box() { + if new_place.is_indirect() { + // TODO might be unsound: We assume that if + // there are other indirections in here, + // there is an alias that does not have + // indirections in it. + return false; + } + place = new_place; + } + } + places_conflict( + self.tcx(), + &self.body, + place, + alias, + PlaceConflictBias::Overlap, + ) + } + }); + + // Special case: if the `alias` is an un-mutated argument, then include it as a conflict + // coming from the special start location. + let alias_last_mut = if alias.is_arg(&self.body) { + Some((alias, &self.start_loc)) + } else { + None + }; + + // For each `conflict`` last mutated at the locations `last_mut`: + conflicts + .chain(alias_last_mut) + .flat_map(|(conflict, last_mut_locs)| { + // For each last mutated location: + last_mut_locs.iter().map(move |last_mut_loc| { + // Return @ as an input node. 
+ self.make_dep_node(conflict, *last_mut_loc) + }) + }) + }) + .collect() + } + + pub(crate) fn find_outputs( + &self, + mutated: Place<'tcx>, + location: Location, + ) -> Vec<(Place<'tcx>, DepNode<'tcx>)> { + // **POINTER-SENSITIVITY:** + // If `mutated` involves indirection via dereferences, then resolve it to the direct places it could point to. + let aliases = self.aliases(mutated).collect_vec(); + + // **FIELD-SENSITIVITY:** we do NOT deal with fields on *writes* (in this function), + // only on *reads* (in `add_input_to_op`). + + // For each mutated `dst`: + aliases + .iter() + .map(|dst| { + // Create a destination node for (DST @ CURRENT_LOC). + (*dst, self.make_dep_node(*dst, location)) + }) + .collect() + } + + /// Updates the last-mutated location for `dst` to the given `location`. + fn apply_mutation( + &self, + state: &mut InstructionState<'tcx>, + location: Location, + mutated: Place<'tcx>, + ) { + self.find_outputs(mutated, location) + .into_iter() + .for_each(|(dst, _)| { + // Create a destination node for (DST @ CURRENT_LOC). + + // Clear all previous mutations. + let dst_mutations = state.last_mutation.entry(dst).or_default(); + dst_mutations.clear(); + + // Register that `dst` is mutated at the current location. + dst_mutations.insert(RichLocation::Location(location)); + }) + } + + /// Resolve a function [`Operand`] to a specific [`DefId`] and generic arguments if possible. 
+ pub(crate) fn operand_to_def_id( + &self, + func: &Operand<'tcx>, + ) -> Option<(DefId, &'tcx List>)> { + let ty = func.ty(&self.body, self.tcx()); + utils::type_as_fn(self.tcx(), ty) + } + + fn fmt_fn(&self, def_id: DefId) -> String { + self.tcx().def_path_str(def_id) + } + + pub(crate) fn determine_call_handling<'b>( + &'b self, + location: Location, + func: &Operand<'tcx>, + args: &'b [Operand<'tcx>], + span: Span, + ) -> Result>, Vec>> { + let tcx = self.tcx(); + + trace!( + "Considering call at {location:?} in {:?}", + self.tcx().def_path_str(self.def_id) + ); + + let (called_def_id, generic_args) = self + .operand_to_def_id(func) + .ok_or_else(|| vec![Error::operand_is_not_function_type(func)])?; + trace!("Resolved call to function: {}", self.fmt_fn(called_def_id)); + + // Monomorphize the called function with the known generic_args. + let param_env = tcx.param_env_reveal_all_normalized(self.def_id); + let Some(resolved_fn) = + utils::try_resolve_function(self.tcx(), called_def_id, param_env, generic_args) + else { + if let Some(d) = generic_args.iter().find(|arg| matches!(arg.unpack(), GenericArgKind::Type(t) if matches!(t.kind(), TyKind::Dynamic(..)))) { + self.tcx().sess.span_warn(self.tcx().def_span(called_def_id), format!("could not resolve instance due to dynamic argument: {d:?}")); + return Ok(None); + } else { + return Err( + vec![Error::instance_resolution_failed( + called_def_id, + generic_args, + span + )]); + } + }; + let resolved_def_id = resolved_fn.def_id(); + if log_enabled!(Level::Trace) && called_def_id != resolved_def_id { + let (called, resolved) = (self.fmt_fn(called_def_id), self.fmt_fn(resolved_def_id)); + trace!(" `{called}` monomorphized to `{resolved}`",); + } + + if is_non_default_trait_method(tcx, resolved_def_id).is_some() { + trace!(" bailing because is unresolvable trait method"); + return Ok(None); + } + + if let Some(handler) = self.can_approximate_async_functions(resolved_def_id) { + return 
Ok(Some(CallHandling::ApproxAsyncSM(handler))); + }; + + let call_kind = match self.classify_call_kind(called_def_id, resolved_def_id, args, span) { + Ok(cc) => cc, + Err(async_err) => { + return Err(vec![async_err]); + } + }; + + let calling_convention = CallingConvention::from_call_kind(&call_kind, args); + + trace!( + " Handling call! with kind {}", + match &call_kind { + CallKind::Direct => "direct", + CallKind::Indirect => "indirect", + CallKind::AsyncPoll { .. } => "async poll", + } + ); + + // Recursively generate the PDG for the child function. + + let cache_key = resolved_fn; + + let is_cached = self.memo.is_in_cache(cache_key); + + let call_changes = self.call_change_callback().map(|callback| { + let info = CallInfo { + callee: resolved_fn, + call_string: self.make_call_string(location), + is_cached, + async_parent: if let CallKind::AsyncPoll(resolution, _loc, _) = call_kind { + // Special case for async. We ask for skipping not on the closure, but + // on the "async" function that created it. This is needed for + // consistency in skipping. Normally, when "poll" is inlined, mutations + // introduced by the creator of the future are not recorded and instead + // handled here, on the closure. But if the closure is skipped we need + // those mutations to occur. To ensure this we always ask for the + // "CallChanges" on the creator so that both creator and closure have + // the same view of whether they are inlined or "Skip"ped. + Some(resolution) + } else { + None + }, + }; + callback.on_inline(info) + }); + + // Handle async functions at the time of polling, not when the future is created. + if is_async(tcx, resolved_def_id) { + trace!(" Bailing because func is async"); + + // If a skip was requested then "poll" will not be inlined later so we + // bail with "None" here and perform the mutations. Otherwise we bail with + // "Some", knowing that handling "poll" later will handle the mutations. 
+ return Ok((!matches!( + &call_changes, + Some(CallChanges { + skip: SkipCall::Skip, + .. + }) + )) + .then_some(CallHandling::ApproxAsyncFn)); + } + + if matches!( + call_changes, + Some(CallChanges { + skip: SkipCall::Skip, + .. + }) + ) { + trace!(" Bailing because user callback said to bail"); + return Ok(None); + } + let Some(descriptor) = self.memo.construct_for(cache_key)? else { + trace!(" Bailing because cache lookup {cache_key} failed"); + return Ok(None); + }; + Ok(Some(CallHandling::Ready { + descriptor, + calling_convention, + })) + } + + /// Attempt to inline a call to a function. + /// + /// The return indicates whether we were successfully able to perform the inlining. + fn handle_call( + &self, + state: &mut InstructionState<'tcx>, + location: Location, + func: &Operand<'tcx>, + args: &[Operand<'tcx>], + destination: Place<'tcx>, + span: Span, + ) -> Result>> { + // Note: my comments here will use "child" to refer to the callee and + // "parent" to refer to the caller, since the words are most visually distinct. + + let Some(preamble) = self.determine_call_handling(location, func, args, span)? else { + return Ok(false); + }; + + trace!("Call handling is {}", preamble.as_ref()); + + let (child_constructor, calling_convention) = match preamble { + CallHandling::Ready { + descriptor, + calling_convention, + } => (descriptor, calling_convention), + CallHandling::ApproxAsyncFn => { + // Register a synthetic assignment of `future = (arg0, arg1, ...)`. 
+ let rvalue = Rvalue::Aggregate( + Box::new(AggregateKind::Tuple), + IndexVec::from_iter(args.iter().cloned()), + ); + self.modular_mutation_visitor(state) + .visit_assign(&destination, &rvalue, location); + return Ok(true); + } + CallHandling::ApproxAsyncSM(handler) => { + handler( + self, + &mut self.modular_mutation_visitor(state), + args, + destination, + location, + ); + return Ok(true); + } + }; + + let parentable_dsts = child_constructor.parentable_dsts(|n| n.len() == 1); + let parent_body = &self.body; + + // For each destination node CHILD that is parentable to PLACE, + // add an edge from CHILD -> PLACE. + // + // PRECISION TODO: for a given child place, we only want to connect + // the *last* nodes in the child function to the parent, not *all* of them. + trace!("CHILD -> PARENT EDGES:"); + for (child_dst, _) in parentable_dsts { + if let Some(parent_place) = calling_convention.translate_to_parent( + child_dst.place, + self.async_info(), + self.tcx(), + parent_body, + self.def_id.to_def_id(), + destination, + Some(child_dst.place.ty(child_constructor, self.tcx())), + ) { + self.apply_mutation(state, location, parent_place); + } + } + + Ok(true) + } + + fn modular_mutation_visitor<'b: 'a>( + &'b self, + state: &'a mut InstructionState<'tcx>, + ) -> ModularMutationVisitor<'b, 'tcx, impl FnMut(Location, Mutation<'tcx>) + 'b> { + ModularMutationVisitor::new( + &self.place_info, + move |location, mutation: Mutation<'tcx>| { + self.apply_mutation(state, location, mutation.mutated) + }, + ) + } + + pub(super) fn generic_args(&self) -> GenericArgsRef<'tcx> { + self.root.args + } + + pub(crate) fn construct_partial(&'a self) -> Result, Vec>> { + let mut analysis = WithConstructionErrors::new(self) + .into_engine(self.tcx(), &self.body) + .iterate_to_fixpoint(); + + if !analysis.analysis.errors.is_empty() { + return Err(analysis.analysis.errors.into_iter().collect()); + } + + let mut final_state = WithConstructionErrors::new(PartialGraph::new( + 
self.generic_args(), + self.def_id.to_def_id(), + self.body.arg_count, + self.body.local_decls(), + )); + + analysis.visit_reachable_with(&self.body, &mut final_state); + + let mut final_state = final_state.into_result()?; + + let all_returns = self.body.all_returns().map(|ret| ret.block).collect_vec(); + let mut analysis = analysis.into_results_cursor(&self.body); + for block in all_returns { + analysis.seek_to_block_end(block); + let return_state = analysis.get(); + for (place, locations) in &return_state.last_mutation { + let ret_kind = if place.local == RETURN_PLACE { + TargetUse::Return + } else if let Some(num) = other_as_arg(*place, &self.body) { + TargetUse::MutArg(num) + } else { + continue; + }; + for location in locations { + let src = self.make_dep_node(*place, *location); + let dst = self.make_dep_node(*place, RichLocation::End); + let edge = DepEdge::data( + self.make_call_string(self.body.terminator_loc(block)), + SourceUse::Operand, + ret_kind, + ); + final_state.edges.insert((src, dst, edge)); + } + } + } + + Ok(final_state) + } + + /// Determine the type of call-site. + /// + /// The error case is if we tried to resolve this as async and failed. We + /// know it *is* async but we couldn't determine the information needed to + /// analyze the function, therefore we will have to approximate it. 
+ fn classify_call_kind<'b>( + &'b self, + def_id: DefId, + resolved_def_id: DefId, + original_args: &'b [Operand<'tcx>], + span: Span, + ) -> Result, Error<'tcx>> { + match self.try_poll_call_kind(def_id, original_args, span) { + AsyncDeterminationResult::Resolved(r) => Ok(r), + AsyncDeterminationResult::NotAsync => Ok(self + .try_indirect_call_kind(resolved_def_id) + .unwrap_or(CallKind::Direct)), + AsyncDeterminationResult::Unresolvable(reason) => Err(reason), + } + } + + fn try_indirect_call_kind(&self, def_id: DefId) -> Option> { + self.tcx().is_closure(def_id).then_some(CallKind::Indirect) + } + + fn terminator_visitor<'b: 'a>( + &'b self, + state: &'b mut InstructionState<'tcx>, + time: Time, + ) -> ModularMutationVisitor<'b, 'tcx, impl FnMut(Location, Mutation<'tcx>) + 'b> { + let mut vis = self.modular_mutation_visitor(state); + vis.set_time(time); + vis + } +} + +impl<'tcx, 'a> WithConstructionErrors<'tcx, &'_ LocalAnalysis<'tcx, 'a>> { + fn handle_terminator( + &mut self, + terminator: &Terminator<'tcx>, + state: &mut InstructionState<'tcx>, + location: Location, + time: Time, + ) { + if let TerminatorKind::Call { + func, + args, + destination, + .. 
+ } = &terminator.kind + { + match self.inner.handle_call( + state, + location, + func, + args, + *destination, + terminator.source_info.span, + ) { + Err(e) => { + self.errors.extend(e); + } + Ok(false) => { + trace!("Terminator {:?} failed the preamble", terminator.kind); + } + Ok(true) => return, + } + } + // Fallback: call the visitor + self.inner + .terminator_visitor(state, time) + .visit_terminator(terminator, location) + } +} + +impl<'tcx, 'a> df::AnalysisDomain<'tcx> + for WithConstructionErrors<'tcx, &'a LocalAnalysis<'tcx, 'a>> +{ + type Domain = InstructionState<'tcx>; + + const NAME: &'static str = "LocalPdgConstruction"; + + fn bottom_value(&self, _body: &Body<'tcx>) -> Self::Domain { + InstructionState::default() + } + + fn initialize_start_block(&self, _body: &Body<'tcx>, _state: &mut Self::Domain) {} +} + +impl<'a, 'tcx> df::Analysis<'tcx> for WithConstructionErrors<'tcx, &'a LocalAnalysis<'tcx, 'a>> { + fn apply_statement_effect( + &mut self, + state: &mut Self::Domain, + statement: &Statement<'tcx>, + location: Location, + ) { + self.inner + .modular_mutation_visitor(state) + .visit_statement(statement, location) + } + + fn apply_terminator_effect<'mir>( + &mut self, + state: &mut Self::Domain, + terminator: &'mir Terminator<'tcx>, + location: Location, + ) -> TerminatorEdges<'mir, 'tcx> { + self.handle_terminator(terminator, state, location, Time::Unspecified); + terminator.edges() + } + + fn apply_call_return_effect( + &mut self, + _state: &mut Self::Domain, + _block: BasicBlock, + _return_places: rustc_middle::mir::CallReturnPlaces<'_, 'tcx>, + ) { + } +} + +pub enum CallKind<'tcx> { + /// A standard function call like `f(x)`. + Direct, + /// A call to a function variable, like `fn foo(f: impl Fn()) { f() }` + Indirect, + /// A poll to an async function, like `f.await`. 
+ AsyncPoll(Instance<'tcx>, Location, Place<'tcx>), +} + +#[derive(strum::AsRefStr)] +pub(crate) enum CallHandling<'tcx, 'a> { + ApproxAsyncFn, + Ready { + calling_convention: CallingConvention<'tcx, 'a>, + descriptor: &'a PartialGraph<'tcx>, + }, + ApproxAsyncSM(ApproximationHandler<'tcx, 'a>), +} + +fn other_as_arg<'tcx>(place: Place<'tcx>, body: &Body<'tcx>) -> Option { + (body.local_kind(place.local) == rustc_middle::mir::LocalKind::Arg) + .then(|| place.local.as_u32() as u8 - 1) +} diff --git a/crates/flowistry_pdg_construction/src/utils.rs b/crates/flowistry_pdg_construction/src/utils.rs index 492fb7f2ad..97631b8a9d 100644 --- a/crates/flowistry_pdg_construction/src/utils.rs +++ b/crates/flowistry_pdg_construction/src/utils.rs @@ -1,107 +1,49 @@ -use std::{borrow::Cow, collections::hash_map::Entry, hash::Hash}; +use std::{collections::hash_map::Entry, fmt::Debug, hash::Hash}; use either::Either; -use flowistry_pdg::rustc_portable::LocalDefId; + use itertools::Itertools; -use log::{debug, trace}; +use log::trace; use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hir::def_id::DefId; + use rustc_middle::{ mir::{ tcx::PlaceTy, Body, HasLocalDecls, Local, Location, Place, ProjectionElem, Statement, StatementKind, Terminator, TerminatorKind, }, ty::{ - self, EarlyBinder, GenericArg, GenericArgsRef, Instance, List, ParamEnv, Ty, TyCtxt, TyKind, + self, Binder, BoundVariableKind, EarlyBinder, GenericArg, GenericArgKind, GenericArgsRef, + Instance, List, ParamEnv, Region, Ty, TyCtxt, TyKind, }, }; -use rustc_span::ErrorGuaranteed; + +use rustc_span::Span; use rustc_type_ir::{fold::TypeFoldable, AliasKind}; use rustc_utils::{BodyExt, PlaceExt}; -#[derive(Clone, Copy, Eq, PartialEq, Hash, Debug)] -pub enum FnResolution<'tcx> { - Final(ty::Instance<'tcx>), - Partial(DefId), -} - -impl<'tcx> PartialOrd for FnResolution<'tcx> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} +use crate::construct::Error; -impl<'tcx> Ord for 
FnResolution<'tcx> { - fn cmp(&self, other: &Self) -> std::cmp::Ordering { - use FnResolution::*; - match (self, other) { - (Final(_), Partial(_)) => std::cmp::Ordering::Greater, - (Partial(_), Final(_)) => std::cmp::Ordering::Less, - (Partial(slf), Partial(otr)) => slf.cmp(otr), - (Final(slf), Final(otr)) => match slf.def.cmp(&otr.def) { - std::cmp::Ordering::Equal => slf.args.cmp(otr.args), - result => result, - }, - } - } -} +pub trait Captures<'a> {} +impl<'a, T: ?Sized> Captures<'a> for T {} -impl<'tcx> FnResolution<'tcx> { - pub fn def_id(self) -> DefId { - match self { - FnResolution::Final(f) => f.def_id(), - FnResolution::Partial(p) => p, - } - } -} - -impl<'tcx> std::fmt::Display for FnResolution<'tcx> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - FnResolution::Final(sub) => std::fmt::Debug::fmt(sub, f), - FnResolution::Partial(p) => std::fmt::Debug::fmt(p, f), - } - } -} - -/// Try and normalize the provided generics. -/// -/// The purpose of this function is to test whether resolving these generics -/// will return an error. We need this because [`ty::Instance::resolve`] fails -/// with a hard error when this normalization fails (even though it returns -/// [`Result`]). However legitimate situations can arise in the code where this -/// normalization fails for which we want to report warnings but carry on with -/// the analysis which a hard error doesn't allow us to do. -fn test_generics_normalization<'tcx>( - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - args: &'tcx ty::List>, -) -> Result<(), ty::normalize_erasing_regions::NormalizationError<'tcx>> { - tcx.try_normalize_erasing_regions(param_env, args) - .map(|_| ()) +/// An async check that does not crash if called on closures. +pub fn is_async(tcx: TyCtxt<'_>, def_id: DefId) -> bool { + !tcx.is_closure(def_id) && tcx.asyncness(def_id).is_async() } +/// Resolve the `def_id` item to an instance. 
pub fn try_resolve_function<'tcx>( tcx: TyCtxt<'tcx>, def_id: DefId, param_env: ParamEnv<'tcx>, args: GenericArgsRef<'tcx>, -) -> FnResolution<'tcx> { +) -> Option> { let param_env = param_env.with_reveal_all_normalized(tcx); - let make_opt = || { - if let Err(e) = test_generics_normalization(tcx, param_env, args) { - debug!("Normalization failed: {e:?}"); - return None; - } - Instance::resolve(tcx, param_env, def_id, args).unwrap() - }; - - match make_opt() { - Some(inst) => FnResolution::Final(inst), - None => FnResolution::Partial(def_id), - } + Instance::resolve(tcx, param_env, def_id, args).unwrap() } +/// Returns the default implementation of this method if it is a trait method. pub fn is_non_default_trait_method(tcx: TyCtxt, function: DefId) -> Option { let assoc_item = tcx.opt_associated_item(function)?; if assoc_item.container != ty::AssocItemContainer::TraitContainer @@ -112,32 +54,50 @@ pub fn is_non_default_trait_method(tcx: TyCtxt, function: DefId) -> Option FnResolution<'tcx> { - pub fn try_monomorphize<'a, T>( - self, - tcx: TyCtxt<'tcx>, - param_env: ParamEnv<'tcx>, - t: &'a T, - ) -> Cow<'a, T> - where - T: TypeFoldable> + Clone, - { - match self { - FnResolution::Partial(_) => Cow::Borrowed(t), - FnResolution::Final(inst) => Cow::Owned(inst.subst_mir_and_normalize_erasing_regions( - tcx, - param_env, - EarlyBinder::bind(tcx.erase_regions(t.clone())), - )), +/// The "canonical" way we monomorphize +pub fn try_monomorphize<'tcx, 'a, T>( + inst: Instance<'tcx>, + tcx: TyCtxt<'tcx>, + param_env: ParamEnv<'tcx>, + t: &'a T, + span: Span, +) -> Result> +where + T: TypeFoldable> + Clone + Debug, +{ + inst.try_subst_mir_and_normalize_erasing_regions( + tcx, + param_env, + EarlyBinder::bind(tcx.erase_regions(t.clone())), + ) + .map_err(|e| Error::NormalizationError { + instance: inst, + span, + error: format!("{e:?}"), + }) +} + +/// Attempt to interpret this type as a statically determinable function and its +/// generic arguments. 
+pub fn type_as_fn<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Option<(DefId, GenericArgsRef<'tcx>)> { + let ty = ty_resolve(ty, tcx); + match ty.kind() { + TyKind::FnDef(def_id, generic_args) => Some((*def_id, generic_args)), + TyKind::Generator(def_id, generic_args, _) => Some((*def_id, generic_args)), + ty => { + trace!("Bailing from handle_call because func is literal with type: {ty:?}"); + None } } } -pub fn retype_place<'tcx>( +/// If `target_ty` is supplied checks that the final type is the same as `target_ty`. +pub(crate) fn retype_place<'tcx>( orig: Place<'tcx>, tcx: TyCtxt<'tcx>, body: &Body<'tcx>, def_id: DefId, + target_ty: Option>, ) -> Place<'tcx> { trace!("Retyping {orig:?} in context of {def_id:?}"); @@ -149,23 +109,24 @@ pub fn retype_place<'tcx>( ty.ty.kind(), TyKind::Alias(..) | TyKind::Param(..) | TyKind::Bound(..) | TyKind::Placeholder(..) ) { + trace!("Breaking on param-like type {:?}", ty.ty); break; } - // Don't continue if we reach a private field - if let ProjectionElem::Field(field, _) = elem { - if let Some(adt_def) = ty.ty.ty_adt_def() { - let field = adt_def - .all_fields() - .nth(field.as_usize()) - .unwrap_or_else(|| { - panic!("ADT for {:?} does not have field {field:?}", ty.ty); - }); - if !field.vis.is_accessible_from(def_id, tcx) { - break; - } - } - } + // // Don't continue if we reach a private field + // if let ProjectionElem::Field(field, _) = elem { + // if let Some(adt_def) = ty.ty.ty_adt_def() { + // let field = adt_def + // .all_fields() + // .nth(field.as_usize()) + // .unwrap_or_else(|| { + // panic!("ADT for {:?} does not have field {field:?}", ty.ty); + // }); + // if !field.vis.is_accessible_from(def_id, tcx) { + // break; + // } + // } + // } trace!( " Projecting {:?}.{new_projection:?} : {:?} with {elem:?}", @@ -195,13 +156,179 @@ pub fn retype_place<'tcx>( }; new_projection.push(elem); } - let p = Place::make(orig.local, &new_projection, tcx); + + if let Some(target_ty) = target_ty { + if !ty.equiv(&target_ty) { + 
let p1_str = format!("{orig:?}"); + let p2_str = format!("{p:?}"); + let l = p1_str.len().max(p2_str.len()); + panic!("Retyping in {} failed to produce an equivalent type.\n Src {p1_str:l$} : {target_ty:?}\n Dst {p2_str:l$} : {ty:?}", tcx.def_path_str(def_id)) + } + } + trace!(" Final translation: {p:?}"); p } -pub fn hashset_join( +pub trait SimpleTyEquiv { + fn equiv(&self, other: &Self) -> bool; +} + +impl<'tcx> SimpleTyEquiv for Ty<'tcx> { + fn equiv(&self, other: &Self) -> bool { + self.kind().equiv(other.kind()) + } +} + +impl SimpleTyEquiv for [T] { + fn equiv(&self, other: &Self) -> bool { + self.iter().zip(other.iter()).all(|(a, b)| a.equiv(b)) + } +} + +impl SimpleTyEquiv for ty::List { + fn equiv(&self, other: &Self) -> bool { + self.as_slice().equiv(other.as_slice()) + } +} + +impl<'tcx> SimpleTyEquiv for GenericArg<'tcx> { + fn equiv(&self, other: &Self) -> bool { + match (&self.unpack(), &other.unpack()) { + (GenericArgKind::Const(a), GenericArgKind::Const(b)) => a == b, + (GenericArgKind::Lifetime(a), GenericArgKind::Lifetime(b)) => a.equiv(b), + (GenericArgKind::Type(a), GenericArgKind::Type(b)) => a.equiv(b), + _ => false, + } + } +} + +impl<'tcx> SimpleTyEquiv for Region<'tcx> { + fn equiv(&self, _other: &Self) -> bool { + true + } +} + +impl<'tcx, T: SimpleTyEquiv> SimpleTyEquiv for ty::Binder<'tcx, T> { + fn equiv(&self, other: &Self) -> bool { + self.bound_vars().equiv(other.bound_vars()) + && self + .as_ref() + .skip_binder() + .equiv(other.as_ref().skip_binder()) + } +} + +impl SimpleTyEquiv for BoundVariableKind { + fn equiv(&self, other: &Self) -> bool { + self == other + } +} + +impl<'tcx> SimpleTyEquiv for ty::TypeAndMut<'tcx> { + fn equiv(&self, other: &Self) -> bool { + self.mutbl == other.mutbl && self.ty.equiv(&other.ty) + } +} + +impl<'tcx> SimpleTyEquiv for ty::FnSig<'tcx> { + fn equiv(&self, other: &Self) -> bool { + let Self { + inputs_and_output, + c_variadic, + unsafety, + abi, + } = *self; + 
inputs_and_output.equiv(other.inputs_and_output) + && c_variadic == other.c_variadic + && unsafety == other.unsafety + && abi == other.abi + } +} + +impl SimpleTyEquiv for &T { + fn equiv(&self, other: &Self) -> bool { + (*self).equiv(*other) + } +} + +impl<'tcx> SimpleTyEquiv for ty::AliasTy<'tcx> { + fn equiv(&self, other: &Self) -> bool { + self.def_id == other.def_id && self.args.equiv(other.args) + } +} + +impl<'tcx> SimpleTyEquiv for ty::ExistentialPredicate<'tcx> { + fn equiv(&self, other: &Self) -> bool { + self == other + } +} + +fn is_wildcard(t: &TyKind<'_>) -> bool { + matches!( + t, + TyKind::Param(..) | TyKind::Alias(..) | TyKind::Bound(..) | TyKind::Placeholder(..) + ) || matches!(t, + TyKind::Dynamic(pred, _, _) if matches!( + pred.first().copied().and_then(Binder::no_bound_vars), + Some(ty::ExistentialPredicate::Trait(tref)) + if tref.def_id == ty::tls::with(|tcx| tcx + .get_diagnostic_item(rustc_span::sym::Any) + .expect("The `Any` item is not defined.")) + ) + ) +} + +impl<'tcx> SimpleTyEquiv for TyKind<'tcx> { + fn equiv(&self, other: &Self) -> bool { + use rustc_type_ir::TyKind::*; + match (self, other) { + _ if is_wildcard(self) || is_wildcard(other) => true, + (Int(a_i), Int(b_i)) => a_i == b_i, + (Uint(a_u), Uint(b_u)) => a_u == b_u, + (Float(a_f), Float(b_f)) => a_f == b_f, + (Adt(a_d, a_s), Adt(b_d, b_s)) => a_d == b_d && a_s.equiv(b_s), + (Foreign(a_d), Foreign(b_d)) => a_d == b_d, + (Array(a_t, a_c), Array(b_t, b_c)) => a_t.equiv(b_t) && a_c == b_c, + (Slice(a_t), Slice(b_t)) => a_t.equiv(b_t), + (RawPtr(a_t), RawPtr(b_t)) => a_t.equiv(b_t), + (Ref(a_r, a_t, a_m), Ref(b_r, b_t, b_m)) => { + a_r.equiv(b_r) && a_t.equiv(b_t) && a_m == b_m + } + (FnDef(a_d, a_s), FnDef(b_d, b_s)) => a_d == b_d && a_s.equiv(b_s), + (FnPtr(a_s), FnPtr(b_s)) => a_s.equiv(b_s), + (Dynamic(a_p, a_r, a_repr), Dynamic(b_p, b_r, b_repr)) => { + a_p.equiv(b_p) && a_r.equiv(b_r) && a_repr == b_repr + } + (Closure(a_d, a_s), Closure(b_d, b_s)) => a_d == b_d && 
a_s.equiv(b_s), + (Generator(a_d, a_s, a_m), Generator(b_d, b_s, b_m)) => { + a_d == b_d && a_s.equiv(b_s) && a_m == b_m + } + (GeneratorWitness(a_g), GeneratorWitness(b_g)) => a_g.equiv(b_g), + (GeneratorWitnessMIR(a_d, a_s), GeneratorWitnessMIR(b_d, b_s)) => { + a_d == b_d && a_s.equiv(b_s) + } + (Tuple(a_t), Tuple(b_t)) => a_t.equiv(b_t), + (Alias(a_i, a_p), Alias(b_i, b_p)) => a_i == b_i && a_p.equiv(b_p), + (Param(a_p), Param(b_p)) => a_p == b_p, + (Bound(a_d, a_b), Bound(b_d, b_b)) => a_d == b_d && a_b == b_b, + (Placeholder(a_p), Placeholder(b_p)) => a_p == b_p, + (Infer(_a_t), Infer(_b_t)) => unreachable!(), + (Error(a_e), Error(b_e)) => a_e == b_e, + (Bool, Bool) | (Char, Char) | (Str, Str) | (Never, Never) => true, + _ => false, + } + } +} + +impl<'tcx> SimpleTyEquiv for PlaceTy<'tcx> { + fn equiv(&self, other: &Self) -> bool { + self.variant_index == other.variant_index && self.ty.equiv(&other.ty) + } +} + +pub(crate) fn hashset_join( hs1: &mut FxHashSet, hs2: &FxHashSet, ) -> bool { @@ -210,7 +337,7 @@ pub fn hashset_join( hs1.len() != orig_len } -pub fn hashmap_join( +pub(crate) fn hashmap_join( hm1: &mut FxHashMap, hm2: &FxHashMap, join: impl Fn(&mut V, &V) -> bool, @@ -230,9 +357,9 @@ pub fn hashmap_join( changed } -pub type BodyAssignments = FxHashMap>; +pub(crate) type BodyAssignments = FxHashMap>; -pub fn find_body_assignments(body: &Body<'_>) -> BodyAssignments { +pub(crate) fn find_body_assignments(body: &Body<'_>) -> BodyAssignments { body.all_locations() .filter_map(|location| match body.stmt_at(location) { Either::Left(Statement { @@ -250,6 +377,7 @@ pub fn find_body_assignments(body: &Body<'_>) -> BodyAssignments { .collect() } +/// Resolve through type aliases pub fn ty_resolve<'tcx>(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Ty<'tcx> { match ty.kind() { TyKind::Alias(AliasKind::Opaque, alias_ty) => tcx.type_of(alias_ty.def_id).skip_binder(), @@ -257,19 +385,32 @@ pub fn ty_resolve<'tcx>(ty: Ty<'tcx>, tcx: TyCtxt<'tcx>) -> Ty<'tcx> { } } +/// This 
function creates dynamic types that satisfy the constraints on the +/// given function. It returns a list of generic arguments that are suitable for +/// calling `Instance::resolve` for this function, guaranteeing that the resolve +/// call does not fail. +/// +/// This is achieved by constructing `dyn` types which assume the constraints of +/// the `where` clause for this function (and any parents). pub fn manufacture_substs_for( tcx: TyCtxt<'_>, - function: LocalDefId, -) -> Result<&List>, ErrorGuaranteed> { + function: DefId, +) -> Result<&List>, Error> { use rustc_middle::ty::{ - Binder, BoundRegionKind, DynKind, ExistentialPredicate, ExistentialProjection, - ExistentialTraitRef, GenericParamDefKind, ImplPolarity, ParamTy, Region, TraitPredicate, + BoundRegionKind, DynKind, ExistentialPredicate, ExistentialProjection, ExistentialTraitRef, + GenericParamDefKind, ImplPolarity, ParamTy, TraitPredicate, }; + trace!("Manufacturing for {function:?}"); + let generics = tcx.generics_of(function); + trace!("Found generics {generics:?}"); let predicates = tcx.predicates_of(function).instantiate_identity(tcx); + trace!("Found predicates {predicates:?}"); + let lang_items = tcx.lang_items(); let types = (0..generics.count()).map(|gidx| { let param = generics.param_at(gidx, tcx); + trace!("Trying param {param:?}"); if let Some(default_val) = param.default_value(tcx) { return Ok(default_val.instantiate_identity()); } @@ -278,58 +419,94 @@ pub fn manufacture_substs_for( GenericParamDefKind::Lifetime => { return Ok(GenericArg::from(Region::new_free( tcx, - function.to_def_id(), + function, BoundRegionKind::BrAnon(None), ))) } GenericParamDefKind::Const { .. } => { - return Err(tcx.sess.span_err( - tcx.def_span(param.def_id), - "Cannot use constants as generic parameters in controllers", - )) + return Err(Error::ConstantInGenerics { function }); } GenericParamDefKind::Type { .. 
} => (), }; let param_as_ty = ParamTy::for_def(param); - let constraints = predicates.predicates.iter().filter_map(|clause| { - let pred = if let Some(trait_ref) = clause.as_trait_clause() { - if trait_ref.polarity() != ImplPolarity::Positive { - return None; - }; - let Some(TraitPredicate { trait_ref, .. }) = trait_ref.no_bound_vars() else { - return Some(Err(tcx.sess.span_err( - tcx.def_span(param.def_id), - format!("Trait ref had binder {trait_ref:?}"), - ))); - }; - if !matches!(trait_ref.self_ty().kind(), TyKind::Param(p) if *p == param_as_ty) { - return None; - }; - Some(ExistentialPredicate::Trait( - ExistentialTraitRef::erase_self_ty(tcx, trait_ref), - )) - } else if let Some(pred) = clause.as_projection_clause() { - let pred = pred.no_bound_vars()?; - if !matches!(pred.self_ty().kind(), TyKind::Param(p) if *p == param_as_ty) { + let constraints = predicates.predicates.iter().enumerate().rev().filter_map( + |(_pidx, clause)| { + trace!(" Trying clause {clause:?}"); + let pred = if let Some(trait_ref) = clause.as_trait_clause() { + trace!(" is trait clause"); + if trait_ref.polarity() != ImplPolarity::Positive { + trace!(" Bailing because it is negative"); + return None; + }; + let Some(TraitPredicate { trait_ref, .. 
}) = trait_ref.no_bound_vars() else { + return Some(Err(Error::TraitRefWithBinder { function })); + }; + if !matches!(trait_ref.self_ty().kind(), TyKind::Param(p) if *p == param_as_ty) + { + trace!(" Bailing because self type is not param type"); + return None; + }; + if Some(trait_ref.def_id) == lang_items.sized_trait() + || tcx.trait_is_auto(trait_ref.def_id) + { + trace!(" bailing because trait is auto trait"); + return None; + } + ExistentialPredicate::Trait(ExistentialTraitRef::erase_self_ty(tcx, trait_ref)) + } else if let Some(pred) = clause.as_projection_clause() { + trace!(" is projection clause"); + let Some(pred) = pred.no_bound_vars() else { + return Some(Err(Error::BoundVariablesInPredicates { function })); + }; + if !matches!(pred.self_ty().kind(), TyKind::Param(p) if *p == param_as_ty) { + trace!(" Bailing because self type is not param type"); + return None; + }; + ExistentialPredicate::Projection(ExistentialProjection::erase_self_ty( + tcx, pred, + )) + } else { + trace!(" is other clause: ignoring"); return None; }; - Some(ExistentialPredicate::Projection( - ExistentialProjection::erase_self_ty(tcx, pred), - )) - } else { - None - }?; - - Some(Ok(Binder::dummy(pred))) - }); + + trace!(" Created predicate {pred:?}"); + + Some(Ok(Binder::dummy(pred))) + }, + ); + let mut predicates = constraints.collect::, _>>()?; + trace!(" collected predicates {predicates:?}"); + match predicates.len() { + 0 => predicates.push(Binder::dummy(ExistentialPredicate::Trait( + ExistentialTraitRef { + def_id: tcx + .get_diagnostic_item(rustc_span::sym::Any) + .expect("The `Any` item is not defined."), + args: List::empty(), + }, + ))), + 1 => (), + _ => { + return Err(Error::TooManyPredicatesForSynthesizingGenerics { + function, + number: predicates.len() as u32, + }) + } + }; + let poly_predicate = tcx.mk_poly_existential_predicates_from_iter(predicates.into_iter()); + trace!(" poly predicate {poly_predicate:?}"); let ty = Ty::new_dynamic( tcx, - 
tcx.mk_poly_existential_predicates_from_iter(constraints)?, - Region::new_free(tcx, function.to_def_id(), BoundRegionKind::BrAnon(None)), + poly_predicate, + Region::new_free(tcx, function, BoundRegionKind::BrAnon(None)), DynKind::Dyn, ); + trace!(" Created a dyn {ty:?}"); Ok(GenericArg::from(ty)) }); - tcx.mk_args_from_iter(types) + let args = tcx.mk_args_from_iter(types)?; + trace!("Created args {args:?}"); + Ok(args) } diff --git a/crates/flowistry_pdg_construction/tests/pdg.rs b/crates/flowistry_pdg_construction/tests/pdg.rs index f9b0bf6dd1..13e32a4b1a 100644 --- a/crates/flowistry_pdg_construction/tests/pdg.rs +++ b/crates/flowistry_pdg_construction/tests/pdg.rs @@ -3,13 +3,14 @@ extern crate either; extern crate rustc_hir; extern crate rustc_middle; +extern crate rustc_span; use std::collections::HashSet; use either::Either; use flowistry_pdg_construction::{ graph::{DepEdge, DepGraph}, - CallChangeCallbackFn, CallChanges, PdgParams, SkipCall, + CallChangeCallbackFn, CallChanges, MemoPdgConstructor, NoLoader, SkipCall, }; use itertools::Itertools; use rustc_hir::def_id::LocalDefId; @@ -17,6 +18,7 @@ use rustc_middle::{ mir::{Terminator, TerminatorKind}, ty::TyCtxt, }; +use rustc_span::Symbol; use rustc_utils::{ mir::borrowck_facts, source_map::find_bodies::find_bodies, test_utils::CompileResult, }; @@ -34,14 +36,15 @@ fn get_main(tcx: TyCtxt<'_>) -> LocalDefId { fn pdg( input: impl Into, - configure: impl for<'tcx> FnOnce(TyCtxt<'tcx>, PdgParams<'tcx>) -> PdgParams<'tcx> + Send, + configure: impl for<'tcx> FnOnce(TyCtxt<'tcx>, &mut MemoPdgConstructor<'tcx>) + Send, tests: impl for<'tcx> FnOnce(TyCtxt<'tcx>, DepGraph<'tcx>) + Send, ) { let _ = env_logger::try_init(); rustc_utils::test_utils::CompileBuilder::new(input).compile(move |CompileResult { tcx }| { let def_id = get_main(tcx); - let params = configure(tcx, PdgParams::new(tcx, def_id).unwrap()); - let pdg = flowistry_pdg_construction::compute_pdg(params); + let mut memo = MemoPdgConstructor::new(tcx, 
NoLoader); + configure(tcx, &mut memo); + let pdg = memo.construct_graph(def_id).unwrap(); tests(tcx, pdg) }) } @@ -92,8 +95,10 @@ fn connects<'tcx>( .edge_indices() .filter_map(|edge| { let DepEdge { at, .. } = g.graph[edge]; - let body_with_facts = - borrowck_facts::get_body_with_borrowck_facts(tcx, at.leaf().function); + let body_with_facts = borrowck_facts::get_body_with_borrowck_facts( + tcx, + at.leaf().function.expect_local(), + ); let Either::Right(Terminator { kind: TerminatorKind::Call { func, .. }, .. @@ -168,7 +173,7 @@ macro_rules! pdg_constraint { macro_rules! pdg_test { ($name:ident, { $($i:item)* }, $($cs:tt),*) => { - pdg_test!($name, { $($i)* }, |_, params| params, $($cs),*); + pdg_test!($name, { $($i)* }, |_, _| (), $($cs),*); }; ($name:ident, { $($i:item)* }, $e:expr, $($cs:tt),*) => { #[test] @@ -612,7 +617,7 @@ pdg_test! { |_, params| { params.with_call_change_callback(CallChangeCallbackFn::new( move |_| { CallChanges::default().with_skip(SkipCall::Skip) - })) + })); }, (recipients -/> sender) } @@ -638,19 +643,23 @@ pdg_test! { nested_layer_one(&mut w, z); } }, - |tcx, params| params.with_call_change_callback(CallChangeCallbackFn::new(move |info| { - let name = tcx.opt_item_name(info.callee.def_id()); - let skip = if !matches!(name.as_ref().map(|sym| sym.as_str()), Some("no_inline")) - && info.call_string.len() < 2 - { - SkipCall::NoSkip - } else { - SkipCall::Skip - }; - CallChanges::default().with_skip(skip) - })), - (y -> x), - (z -> w) + |tcx, params| { + params.with_call_change_callback(CallChangeCallbackFn::new(move |info| { + let name = tcx.opt_item_name(info.callee.def_id()); + let skip = if !matches!(name.as_ref().map(|sym| sym.as_str()), Some("no_inline")) + && info.call_string.len() < 2 + { + SkipCall::NoSkip + } else { + SkipCall::Skip + }; + CallChanges::default().with_skip(skip) + })); + }, + (y -> x) + // TODO the way that graphs are constructed currently doesn't allow limiting + // call string depth + // (z -> w) } pdg_test! 
{ @@ -776,3 +785,60 @@ pdg_test! { } }, } + +pdg_test! { + spawn_and_loop_await, + { + use std::future::Future; + use std::task::{Poll, Context}; + use std::pin::Pin; + + struct JoinHandle(Box>); + + impl Future for JoinHandle { + type Output = T; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.map_unchecked_mut(|p| p.0.as_mut()).poll(cx) + } + } + + pub fn spawn(future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + JoinHandle(Box::new(future)) + } + + pub async fn main() { + let mut tasks = vec![]; + for i in [0,1] { + let task: JoinHandle<_> = spawn(async move { + println!("{i}"); + Ok::<_, String>(0) + }); + tasks.push(task); + } + + for h in tasks { + if let Err(e) = h.await { + panic!("{e}") + } + } + } + }, + |tcx, params| { + params.with_call_change_callback(CallChangeCallbackFn::new(move |info| { + let name = tcx.opt_item_name(info.callee.def_id()); + let name2 = tcx.opt_parent(info.callee.def_id()).and_then(|c| tcx.opt_item_name(c)); + let is_spawn = |name: Option<&Symbol>| name.map_or(false, |n| n.as_str().contains("spawn")); + let mut changes = CallChanges::default(); + if is_spawn(name.as_ref()) || is_spawn(name2.as_ref()) + { + changes = changes.with_skip(SkipCall::Skip); + }; + changes + })); + }, + (i -> h) +} diff --git a/crates/paralegal-flow/Cargo.toml b/crates/paralegal-flow/Cargo.toml index 255f6079f9..c057ae416b 100644 --- a/crates/paralegal-flow/Cargo.toml +++ b/crates/paralegal-flow/Cargo.toml @@ -43,7 +43,7 @@ enum-map = "2.7" serial_test = "2.0.0" itertools = "0.12" anyhow = "1.0.72" -thiserror = "1" +thiserror = { workspace = true } serde_bare = "0.5.0" toml = "0.7" diff --git a/crates/paralegal-flow/build.rs b/crates/paralegal-flow/build.rs index 7fb9ee8ce6..d4b1f305bb 100644 --- a/crates/paralegal-flow/build.rs +++ b/crates/paralegal-flow/build.rs @@ -1,3 +1,5 @@ +#![feature(string_remove_matches)] + use std::path::PathBuf; use std::process::Command; extern crate chrono; 
@@ -71,8 +73,7 @@ fn main() { .arg("--version") .output() .unwrap(); - println!( - "cargo:rustc-env=RUSTC_VERSION=\"{}\"", - String::from_utf8(rustc_version.stdout).unwrap() - ); + let mut version_str = String::from_utf8(rustc_version.stdout).unwrap(); + version_str.remove_matches('\n'); + println!("cargo:rustc-env=RUSTC_VERSION={}", version_str,); } diff --git a/crates/paralegal-flow/src/ana/encoder.rs b/crates/paralegal-flow/src/ana/encoder.rs new file mode 100644 index 0000000000..18e65a4c00 --- /dev/null +++ b/crates/paralegal-flow/src/ana/encoder.rs @@ -0,0 +1,212 @@ +//! Readers and writers for the intermediate artifacts we store per crate. +//! +//! Most of this code is adapted/copied from `EncodeContext` and `DecodeContext` in +//! `rustc_metadata`. +//! +//! Note that the methods pertaining to allocations of `AllocId`'s are +//! unimplemented and will cause a crash if you try to stick an `AllocId` into +//! the Paralegal artifact. + +use std::path::Path; + +use rustc_hash::FxHashMap; +use rustc_hir::def_id::{DefId, DefIndex}; +use rustc_middle::ty::{self, Ty, TyCtxt}; +use rustc_serialize::{ + opaque::{FileEncoder, MemDecoder}, + Decodable, Decoder, Encodable, Encoder, +}; +use rustc_type_ir::{TyDecoder, TyEncoder}; + +macro_rules! 
encoder_methods { + ($($name:ident($ty:ty);)*) => { + $(fn $name(&mut self, value: $ty) { + self.file_encoder.$name(value) + })* + } +} + +pub struct ParalegalEncoder<'tcx> { + tcx: TyCtxt<'tcx>, + file_encoder: FileEncoder, + type_shorthands: FxHashMap, usize>, + predicate_shorthands: FxHashMap, usize>, +} + +impl<'tcx> ParalegalEncoder<'tcx> { + pub fn new(path: impl AsRef, tcx: TyCtxt<'tcx>) -> Self { + Self { + tcx, + file_encoder: FileEncoder::new(path).unwrap(), + type_shorthands: Default::default(), + predicate_shorthands: Default::default(), + } + } + + pub fn finish(self) { + self.file_encoder.finish().unwrap(); + } +} + +const CLEAR_CROSS_CRATE: bool = false; + +impl<'tcx> Encoder for ParalegalEncoder<'tcx> { + encoder_methods! { + emit_usize(usize); + emit_u128(u128); + emit_u64(u64); + emit_u32(u32); + emit_u16(u16); + emit_u8(u8); + + emit_isize(isize); + emit_i128(i128); + emit_i64(i64); + emit_i32(i32); + emit_i16(i16); + + emit_raw_bytes(&[u8]); + } +} + +impl<'tcx> TyEncoder for ParalegalEncoder<'tcx> { + type I = TyCtxt<'tcx>; + const CLEAR_CROSS_CRATE: bool = CLEAR_CROSS_CRATE; + + fn position(&self) -> usize { + self.file_encoder.position() + } + + fn type_shorthands( + &mut self, + ) -> &mut FxHashMap<::Ty, usize> { + &mut self.type_shorthands + } + + fn predicate_shorthands( + &mut self, + ) -> &mut FxHashMap<::PredicateKind, usize> { + &mut self.predicate_shorthands + } + + fn encode_alloc_id(&mut self, _alloc_id: &::AllocId) { + unimplemented!() + } +} + +impl<'tcx> Encodable> for DefId { + fn encode(&self, s: &mut ParalegalEncoder<'tcx>) { + s.tcx.def_path_hash(*self).encode(s) + } +} + +pub struct ParalegalDecoder<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + mem_decoder: MemDecoder<'a>, + shorthand_map: FxHashMap>, +} + +impl<'tcx, 'a> ParalegalDecoder<'tcx, 'a> { + pub fn new(tcx: TyCtxt<'tcx>, buf: &'a [u8]) -> Self { + Self { + tcx, + mem_decoder: MemDecoder::new(buf, 0), + shorthand_map: Default::default(), + } + } +} + +impl<'tcx, 'a> TyDecoder 
for ParalegalDecoder<'tcx, 'a> { + const CLEAR_CROSS_CRATE: bool = CLEAR_CROSS_CRATE; + + type I = TyCtxt<'tcx>; + + fn interner(&self) -> Self::I { + self.tcx + } + + fn cached_ty_for_shorthand( + &mut self, + shorthand: usize, + or_insert_with: F, + ) -> ::Ty + where + F: FnOnce(&mut Self) -> ::Ty, + { + if let Some(ty) = self.shorthand_map.get(&shorthand) { + return *ty; + } + let ty = or_insert_with(self); + self.shorthand_map.insert(shorthand, ty); + ty + } + + fn decode_alloc_id(&mut self) -> ::AllocId { + unimplemented!() + } + + fn with_position(&mut self, pos: usize, f: F) -> R + where + F: FnOnce(&mut Self) -> R, + { + let new_opaque = MemDecoder::new(self.mem_decoder.data(), pos); + let old_opaque = std::mem::replace(&mut self.mem_decoder, new_opaque); + let r = f(self); + self.mem_decoder = old_opaque; + r + } +} + +macro_rules! decoder_methods { + ($($name:ident($ty:ty);)*) => { + $(fn $name(&mut self) -> $ty { + self.mem_decoder.$name() + })* + } +} + +impl<'tcx, 'a> Decoder for ParalegalDecoder<'tcx, 'a> { + decoder_methods! 
{ + read_usize(usize); + read_u128(u128); + read_u64(u64); + read_u32(u32); + read_u16(u16); + read_u8(u8); + read_isize(isize); + read_i128(i128); + read_i64(i64); + read_i32(i32); + read_i16(i16); + } + fn position(&self) -> usize { + self.mem_decoder.position() + } + fn peek_byte(&self) -> u8 { + self.mem_decoder.peek_byte() + } + fn read_raw_bytes(&mut self, len: usize) -> &[u8] { + self.mem_decoder.read_raw_bytes(len) + } +} + +impl<'tcx, 'a> Decodable> for DefId { + fn decode(d: &mut ParalegalDecoder<'tcx, 'a>) -> Self { + d.tcx + .def_path_hash_to_def_id(Decodable::decode(d), &mut || { + panic!("Could not translate hash") + }) + } +} + +impl<'tcx> Encodable> for DefIndex { + fn encode(&self, s: &mut ParalegalEncoder<'tcx>) { + self.as_u32().encode(s) + } +} + +impl<'tcx, 'a> Decodable> for DefIndex { + fn decode(d: &mut ParalegalDecoder<'tcx, 'a>) -> Self { + Self::from_u32(u32::decode(d)) + } +} diff --git a/crates/paralegal-flow/src/ana/graph_converter.rs b/crates/paralegal-flow/src/ana/graph_converter.rs index c062dbef50..60952deb9d 100644 --- a/crates/paralegal-flow/src/ana/graph_converter.rs +++ b/crates/paralegal-flow/src/ana/graph_converter.rs @@ -1,36 +1,30 @@ use crate::{ - ana::inline_judge::InlineJudge, - ann::MarkerAnnotation, - desc::*, - discover::FnToAnalyze, - rust::{hir::def, *}, - stats::TimedStat, - utils::*, - DefId, HashMap, HashSet, MarkerCtx, + ana::inline_judge::InlineJudge, ann::MarkerAnnotation, desc::*, discover::FnToAnalyze, + utils::*, DefId, HashMap, HashSet, MarkerCtx, }; -use flowistry::mir::placeinfo::PlaceInfo; use flowistry_pdg::SourceUse; use paralegal_spdg::{Node, SPDGStats}; -use rustc_utils::cache::Cache; - -use std::{ - cell::RefCell, - fmt::Display, - rc::Rc, - time::{Duration, Instant}, +use rustc_hir::{def, def_id::LocalDefId}; +use rustc_middle::{ + mir::{self, Location}, + ty::{self, Instance, ParamEnv, TyCtxt}, }; -use self::call_string_resolver::CallStringResolver; +use std::{cell::RefCell, fmt::Display, 
rc::Rc}; -use super::{default_index, path_for_item, src_loc_for_span, SPDGGenerator}; +use super::{ + default_index, + metadata::{AsyncStatus, BodyInfo, Error}, + path_for_item, src_loc_for_span, RustcInstructionKind, SPDGGenerator, +}; use anyhow::{anyhow, Result}; use either::Either; use flowistry_pdg_construction::{ - determine_async, graph::{DepEdge, DepEdgeKind, DepGraph, DepNode}, - is_async_trait_fn, match_async_trait_assign, CallChangeCallback, CallChanges, CallInfo, - InlineMissReason, PdgParams, + utils::try_monomorphize, + CallChangeCallback, CallChanges, CallInfo, EmittableError, InlineMissReason, SkipCall::Skip, + UnwrapEmittable, }; use petgraph::{ visit::{IntoNodeReferences, NodeIndexable, NodeRef}, @@ -64,27 +58,20 @@ pub struct GraphConverter<'tcx, 'a, C> { /// The converted graph we are creating spdg: SPDGImpl, marker_assignments: HashMap>, - call_string_resolver: call_string_resolver::CallStringResolver<'tcx>, - stats: SPDGStats, - place_info_cache: PlaceInfoCache<'tcx>, } -pub type PlaceInfoCache<'tcx> = Rc>>; - impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { /// Initialize a new converter by creating an initial PDG using flowistry. 
pub fn new_with_flowistry( generator: &'a SPDGGenerator<'tcx>, known_def_ids: &'a mut C, target: &'a FnToAnalyze, - place_info_cache: PlaceInfoCache<'tcx>, ) -> Result { - let local_def_id = target.def_id.expect_local(); - let start = Instant::now(); - let (dep_graph, stats) = Self::create_flowistry_graph(generator, local_def_id)?; - generator - .stats - .record_timed(TimedStat::Flowistry, start.elapsed()); + let local_def_id = target.def_id; + let dep_graph = Self::create_flowistry_graph(generator, local_def_id).map_err(|e| { + e.emit(generator.tcx); + anyhow!("construction error") + })?; if generator.opts.dbg().dump_flowistry_pdg() { dep_graph.generate_graphviz(format!( @@ -103,9 +90,6 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { types: Default::default(), spdg: Default::default(), marker_assignments: Default::default(), - call_string_resolver: CallStringResolver::new(generator.tcx, local_def_id), - stats, - place_info_cache, }) } @@ -119,7 +103,10 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { /// Is the top-level function (entrypoint) an `async fn` fn entrypoint_is_async(&self) -> bool { - entrypoint_is_async(self.tcx(), self.local_def_id) + self.generator + .metadata_loader + .get_asyncness(self.local_def_id.to_def_id()) + .is_async() } /// Insert this node into the converted graph, return it's auto-assigned id @@ -151,30 +138,25 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { } } - fn place_info(&self, def_id: LocalDefId) -> &PlaceInfo<'tcx> { - self.place_info_cache.get(def_id, |_| { - PlaceInfo::build( - self.tcx(), - def_id.to_def_id(), - self.tcx().body_for_def_id(def_id).unwrap(), - ) - }) - } - /// Find direct annotations on this node and register them in the marker map. 
fn node_annotations(&mut self, old_node: Node, weight: &DepNode<'tcx>) { let leaf_loc = weight.at.leaf(); let node = self.new_node_for(old_node); - let body = &self.tcx().body_for_def_id(leaf_loc.function).unwrap().body; - let graph = self.dep_graph.clone(); + let body = self + .generator + .metadata_loader + .get_body_info(leaf_loc.function) + .unwrap(); + let monos = self.generator.metadata_loader.get_mono(weight.at).unwrap(); + match leaf_loc.location { RichLocation::Start if matches!(body.local_kind(weight.place.local), mir::LocalKind::Arg) => { - let function_id = leaf_loc.function.to_def_id(); + let function_id = leaf_loc.function; let arg_num = weight.place.local.as_u32() - 1; self.known_def_ids.extend(Some(function_id)); @@ -183,27 +165,24 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { }); } RichLocation::End if weight.place.local == mir::RETURN_PLACE => { - let function_id = leaf_loc.function.to_def_id(); + let function_id = leaf_loc.function; self.known_def_ids.extend(Some(function_id)); self.register_annotations_for_function(node, function_id, |ann| { ann.refinement.on_return() }); } RichLocation::Location(loc) => { - let stmt_at_loc = body.stmt_at(loc); - if let crate::Either::Right( - term @ mir::Terminator { - kind: mir::TerminatorKind::Call { destination, .. }, - .. - }, - ) = stmt_at_loc - { - let res = self.call_string_resolver.resolve(weight.at); - let (fun, ..) 
= res - .try_monomorphize(self.tcx(), self.tcx().param_env(res.def_id()), term) - .as_instance_and_args(self.tcx()) - .unwrap(); - self.known_def_ids.extend(Some(fun.def_id())); + let instruction = body.instruction_at(loc); + if let RustcInstructionKind::FunctionCall(f) = instruction.kind { + let (def_id, args) = flowistry_pdg_construction::utils::type_as_fn( + self.tcx(), + f.instantiate(self.tcx(), monos), + ) + .unwrap(); + let f = Instance::resolve(self.tcx(), ParamEnv::reveal_all(), def_id, args) + .unwrap() + .map_or(def_id, |i| i.def_id()); + self.known_def_ids.extend(Some(f)); // Question: Could a function with no input produce an // output that has aliases? E.g. could some place, where the @@ -211,38 +190,23 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { // this function call be affected/modified by this call? If // so, that location would also need to have this marker // attached - let needs_return_markers = weight.place.local == destination.local - || graph - .graph - .edges_directed(old_node, Direction::Incoming) - .any(|e| { - if weight.at != e.weight().at { - // Incoming edges are either from our operation or from control flow - let at = e.weight().at; - debug_assert!( - at.leaf().function == leaf_loc.function - && if let RichLocation::Location(loc) = - at.leaf().location - { - matches!( - body.stmt_at(loc), - Either::Right(mir::Terminator { - kind: mir::TerminatorKind::SwitchInt { .. }, - .. - }) - ) - } else { - false - } - ); - false - } else { - e.weight().target_use.is_return() - } - }); + // + // Also yikes. 
This should have better detection of whether + // a place is (part of) a function return + let mut in_edges = graph + .graph + .edges_directed(old_node, Direction::Incoming) + .filter(|e| e.weight().kind == DepEdgeKind::Data); + let needs_return_markers = in_edges.clone().next().is_none() + || in_edges.any(|e| { + let at = e.weight().at; + #[cfg(debug_assertions)] + assert_edge_location_invariant(self.tcx(), at, body, weight.at); + weight.at == at && e.weight().target_use.is_return() + }); if needs_return_markers { - self.register_annotations_for_function(node, fun.def_id(), |ann| { + self.register_annotations_for_function(node, f, |ann| { ann.refinement.on_return() }); } @@ -251,7 +215,7 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { let SourceUse::Argument(arg) = e.weight().source_use else { continue; }; - self.register_annotations_for_function(node, fun.def_id(), |ann| { + self.register_annotations_for_function(node, f, |ann| { ann.refinement.on_argument().contains(arg as u32).unwrap() }); } @@ -266,26 +230,21 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { &self, at: CallString, place: mir::PlaceRef<'tcx>, - ) -> Option> { + span: rustc_span::Span, + ) -> Result>, Error<'tcx>> { let tcx = self.tcx(); - let locations = at.iter_from_root().collect::>(); - let (last, mut rest) = locations.split_last().unwrap(); - - if self.entrypoint_is_async() { - let (first, tail) = rest.split_first().unwrap(); - // The body of a top-level `async` function binds a closure to the - // return place `_0`. Here we expect are looking at the statement - // that does this binding. - assert!(expect_stmt_at(self.tcx(), *first).is_left()); - rest = tail; - } + let body = self + .generator + .metadata_loader + .get_body_info(at.leaf().function) + .unwrap(); // So actually we're going to check the base place only, because // Flowistry sometimes tracks subplaces instead but we want the marker // from the base place. 
- let place = if self.entrypoint_is_async() && place.local.as_u32() == 1 && rest.len() == 1 { + let place = if self.entrypoint_is_async() && place.local.as_u32() == 1 && at.len() == 2 { if place.projection.is_empty() { - return None; + return Ok(None); } // in the case of targeting the top-level async closure (e.g. async args) // we'll keep the first projection. @@ -297,12 +256,25 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { place.local.into() }; - let resolution = self.call_string_resolver.resolve(at); - - // Thread through each caller to recover generic arguments - let body = tcx.body_for_def_id(last.function).unwrap(); - let raw_ty = place.ty(&body.body, tcx); - Some(*resolution.try_monomorphize(tcx, ty::ParamEnv::reveal_all(), &raw_ty)) + let raw_ty = place.ty(body, tcx); + let function = at.leaf().function; + // println!( + // "Resolving {raw_ty:?} for place {place:?} with generics {generics:?} in {function:?}", + // ); + let generics = self.generator.metadata_loader.get_mono(at).unwrap(); + trace!("Determining type for place {place:?} at {at} with raw type {raw_ty:?} and generics {generics:?}"); + let instance = Instance::resolve( + tcx, + tcx.param_env_reveal_all_normalized(function), + function, + generics, + ) + .unwrap() + .unwrap(); + let resolution = try_monomorphize(instance, tcx, ty::ParamEnv::reveal_all(), &raw_ty, span) + .map_err(|e| Error::ConstructionErrors(vec![e]))?; + //println!("Resolved to {resolution:?}"); + Ok(Some(resolution)) } /// Fetch annotations item identified by this `id`. 
@@ -318,20 +290,21 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { function: DefId, mut filter: impl FnMut(&MarkerAnnotation) -> bool, ) { + trace!("Checking annotations for {node:?} on function {function:?}"); let parent = get_parent(self.tcx(), function); let marker_ctx = self.marker_ctx().clone(); - self.register_markers( - node, - marker_ctx - .combined_markers(function) - .chain( - parent - .into_iter() - .flat_map(|parent| marker_ctx.combined_markers(parent)), - ) - .filter(|ann| filter(ann)) - .map(|ann| ann.marker), - ); + let markers = marker_ctx + .combined_markers(function) + .chain( + parent + .into_iter() + .flat_map(|parent| marker_ctx.combined_markers(parent)), + ) + .filter(|ann| filter(ann)) + .map(|ann| Identifier::new_intern(ann.marker.as_str())) + .collect::>(); + trace!("Found markers {markers:?}"); + self.register_markers(node, markers); self.known_def_ids.extend(parent); } @@ -339,14 +312,21 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { fn handle_node_types(&mut self, old_node: Node, weight: &DepNode<'tcx>) { let i = self.new_node_for(old_node); - let Some(place_ty) = self.determine_place_type(weight.at, weight.place.as_ref()) else { + let Some(place_ty) = self + .determine_place_type(weight.at, weight.place.as_ref(), weight.span) + .unwrap_emittable(self.tcx()) + else { return; }; - let place_info = self.place_info(weight.at.leaf().function); - let deep = !place_info.children(weight.place).is_empty(); + // Restore after fixing https://github.com/brownsys/paralegal/issues/138 + //let deep = !weight.is_split; + let deep = true; let mut node_types = self.type_is_marked(place_ty, deep).collect::>(); for (p, _) in weight.place.iter_projections() { - if let Some(place_ty) = self.determine_place_type(weight.at, p) { + if let Some(place_ty) = self + .determine_place_type(weight.at, p, weight.span) + .unwrap_emittable(self.tcx()) + { node_types.extend(self.type_is_marked(place_ty, false)); } } @@ -364,77 +344,33 @@ impl<'a, 
'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { } /// Create an initial flowistry graph for the function identified by - /// `local_def_id`. + /// `def_id`. fn create_flowistry_graph( generator: &SPDGGenerator<'tcx>, - local_def_id: LocalDefId, - ) -> Result<(DepGraph<'tcx>, SPDGStats)> { - let tcx = generator.tcx; - let opts = generator.opts; - let stat_wrap = Rc::new(RefCell::new(( - SPDGStats { - unique_functions: 0, - unique_locs: 0, - analyzed_functions: 0, - analyzed_locs: 0, - inlinings_performed: 0, - construction_time: Duration::ZERO, - conversion_time: Duration::ZERO, - }, - Default::default(), - ))); - // TODO: I don't like that I have to do that here. Clean this up - let target = determine_async(tcx, local_def_id, &tcx.body_for_def_id(local_def_id)?.body) - .map_or(local_def_id, |res| res.0.def_id().expect_local()); - // Make sure we count outselves - record_inlining(&stat_wrap, tcx, target, false); - let stat_wrap_copy = stat_wrap.clone(); - let judge = generator.inline_judge.clone(); - let params = PdgParams::new(tcx, local_def_id) - .map_err(|_| anyhow!("unable to contruct PDG for {local_def_id:?}"))? - .with_call_change_callback(MyCallback { - judge, - stat_wrap, - tcx, - }) - .with_dump_mir(generator.opts.dbg().dump_mir()); - - if opts.dbg().dump_mir() { - let mut file = std::fs::File::create(format!( - "{}.mir", - tcx.def_path_str(local_def_id.to_def_id()) - ))?; - mir::pretty::write_mir_fn( - tcx, - &tcx.body_for_def_id_default_policy(local_def_id) - .ok_or_else(|| anyhow!("Body not found"))? - .body, - &mut |_, _| Ok(()), - &mut file, - )? - } - let flowistry_time = Instant::now(); - let pdg = flowistry_pdg_construction::compute_pdg(params); - let (mut stats, _) = Rc::into_inner(stat_wrap_copy).unwrap().into_inner(); - stats.construction_time = flowistry_time.elapsed(); - - Ok((pdg, stats)) + def_id: LocalDefId, + ) -> Result, Error<'tcx>> { + // We only demand a local def id to ensure that this is always called in + // the same crate. 
+ let def_id = def_id.to_def_id(); + Ok(match generator.metadata_loader.get_asyncness(def_id) { + AsyncStatus::NotAsync => generator.metadata_loader.get_partial_graph(def_id), + AsyncStatus::Async { + generator_id, + asyncness: _, + } => generator.metadata_loader.get_partial_graph(generator_id), + }? + .to_petgraph()) } /// Consume the generator and compile the [`SPDG`]. pub fn make_spdg(mut self) -> SPDG { - let start = Instant::now(); self.make_spdg_impl(); let arguments = self.determine_arguments(); let return_ = self.determine_return(); - self.generator - .stats - .record_timed(TimedStat::Conversion, start.elapsed()); - self.stats.conversion_time = start.elapsed(); SPDG { path: path_for_item(self.local_def_id.to_def_id(), self.tcx()), graph: self.spdg, - id: self.local_def_id, + id: self.local_def_id.to_def_id(), name: Identifier::new(self.target.name()), arguments, markers: self @@ -448,7 +384,7 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { .into_iter() .map(|(k, v)| (k, Types(v.into()))) .collect(), - statistics: self.stats, + statistics: Default::default(), } } @@ -460,16 +396,12 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { let tcx = self.tcx(); for (i, weight) in input.node_references() { - let at = weight.at.leaf(); - let body = &tcx.body_for_def_id(at.function).unwrap().body; - - let node_span = body.local_decls[weight.place.local].source_info.span; self.register_node( i, NodeInfo { at: weight.at, description: format!("{:?}", weight.place), - span: src_loc_for_span(node_span, tcx), + span: src_loc_for_span(weight.span, tcx), }, ); self.node_annotations(i, weight); @@ -568,10 +500,45 @@ impl<'a, 'tcx, C: Extend> GraphConverter<'tcx, 'a, C> { } } -struct MyCallback<'tcx> { - judge: InlineJudge<'tcx>, - stat_wrap: StatStracker, +#[cfg(debug_assertions)] +fn assert_edge_location_invariant<'tcx>( tcx: TyCtxt<'tcx>, + at: CallString, + body: &BodyInfo<'tcx>, + location: CallString, +) { + // Normal case. 
The edge is introduced where the operation happens + if location == at { + return; + } + // Control flow case. The edge is introduced at the `switchInt` + if let RichLocation::Location(loc) = at.leaf().location { + if at.leaf().function == location.leaf().function + && matches!( + body.instruction_at(loc).kind, + RustcInstructionKind::SwitchInt + ) + { + return; + } + } + let mut msg = tcx.sess.struct_span_fatal( + body.span_of(at.leaf().location), + format!( + "This operation is performed in a different location: {}", + at + ), + ); + msg.span_note( + body.span_of(location.leaf().location), + format!("Expected to originate here: {}", at), + ); + msg.emit() +} + +pub(super) struct MyCallback<'tcx> { + pub(super) judge: InlineJudge<'tcx>, + pub(super) tcx: TyCtxt<'tcx>, } impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { @@ -591,23 +558,15 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { if skip { changes = changes.with_skip(Skip); - } else { - record_inlining( - &self.stat_wrap, - self.tcx, - info.callee.def_id().expect_local(), - info.is_cached, - ) } changes } fn on_inline_miss( &self, - resolution: FnResolution<'tcx>, + resolution: Instance<'tcx>, loc: Location, - parent: FnResolution<'tcx>, - call_string: Option, + parent: Instance<'tcx>, reason: InlineMissReason, ) { let body = self @@ -622,9 +581,8 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { self.tcx.sess.span_err( span, format!( - "Could not inline this function call in {:?}, at {} because {reason:?}. {}", + "Could not inline this function call in {:?}, because {reason:?}. 
{}", parent.def_id(), - call_string.map_or("root".to_owned(), |c| c.to_string()), Print(|f| if markers_reachable.is_empty() { f.write_str("No markers are reachable") } else { @@ -637,8 +595,10 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> { } } +#[allow(dead_code)] type StatStracker = Rc)>>; +#[allow(dead_code)] fn record_inlining(tracker: &StatStracker, tcx: TyCtxt<'_>, def_id: LocalDefId, is_in_cache: bool) { let mut borrow = tracker.borrow_mut(); let (stats, loc_set) = &mut *borrow; @@ -663,15 +623,6 @@ fn record_inlining(tracker: &StatStracker, tcx: TyCtxt<'_>, def_id: LocalDefId, } } -/// Find the statement at this location or fail. -fn expect_stmt_at(tcx: TyCtxt, loc: GlobalLocation) -> Either<&mir::Statement, &mir::Terminator> { - let body = &tcx.body_for_def_id(loc.function).unwrap().body; - let RichLocation::Location(loc) = loc.location else { - unreachable!(); - }; - body.stmt_at(loc) -} - /// If `did` is a method of an `impl` of a trait, then return the `DefId` that /// refers to the method on the trait definition. fn get_parent(tcx: TyCtxt, did: DefId) -> Option { @@ -689,97 +640,3 @@ fn get_parent(tcx: TyCtxt, did: DefId) -> Option { .def_id; Some(id) } - -fn entrypoint_is_async(tcx: TyCtxt, local_def_id: LocalDefId) -> bool { - tcx.asyncness(local_def_id).is_async() - || is_async_trait_fn( - tcx, - local_def_id.to_def_id(), - &tcx.body_for_def_id(local_def_id).unwrap().body, - ) -} - -mod call_string_resolver { - //! Resolution of [`CallString`]s to [`FnResolution`]s. - //! - //! This is a separate mod so that we can use encapsulation to preserve the - //! internal invariants of the resolver. - - use flowistry_pdg::{rustc_portable::LocalDefId, CallString}; - use flowistry_pdg_construction::{try_resolve_function, FnResolution}; - use rustc_utils::cache::Cache; - - use crate::{Either, TyCtxt}; - - use super::{map_either, match_async_trait_assign, AsFnAndArgs}; - - /// Cached resolution of [`CallString`]s to [`FnResolution`]s. 
- /// - /// Only valid for a single controller. Each controller should initialize a - /// new resolver. - pub struct CallStringResolver<'tcx> { - cache: Cache>, - tcx: TyCtxt<'tcx>, - entrypoint_is_async: bool, - } - - impl<'tcx> CallStringResolver<'tcx> { - /// Tries to resolve to the monomophized function in which this call - /// site exists. That is to say that `return.def_id() == - /// cs.leaf().function`. - /// - /// Unlike `Self::resolve_internal` this can be called on any valid - /// [`CallString`]. - pub fn resolve(&self, cs: CallString) -> FnResolution<'tcx> { - let (this, opt_prior_loc) = cs.pop(); - if let Some(prior_loc) = opt_prior_loc { - if prior_loc.len() == 1 && self.entrypoint_is_async { - FnResolution::Partial(this.function.to_def_id()) - } else { - self.resolve_internal(prior_loc) - } - } else { - FnResolution::Partial(this.function.to_def_id()) - } - } - - pub fn new(tcx: TyCtxt<'tcx>, entrypoint: LocalDefId) -> Self { - Self { - cache: Default::default(), - tcx, - entrypoint_is_async: super::entrypoint_is_async(tcx, entrypoint), - } - } - - /// This resolves the monomorphized function *being called at* this call - /// site. - /// - /// This function is internal because it panics if `cs.leaf().location` - /// is not either a function call or a statement where an async closure - /// is created and assigned. 
- fn resolve_internal(&self, cs: CallString) -> FnResolution<'tcx> { - *self.cache.get(cs, |_| { - let this = cs.leaf(); - let prior = self.resolve(cs); - - let tcx = self.tcx; - - let base_stmt = super::expect_stmt_at(tcx, this); - let param_env = tcx.param_env_reveal_all_normalized(prior.def_id()); - let normalized = map_either( - base_stmt, - |stmt| prior.try_monomorphize(tcx, param_env, stmt), - |term| prior.try_monomorphize(tcx, param_env, term), - ); - let res = match normalized { - Either::Right(term) => term.as_ref().as_instance_and_args(tcx).unwrap().0, - Either::Left(stmt) => { - let (def_id, generics) = match_async_trait_assign(stmt.as_ref()).unwrap(); - try_resolve_function(tcx, def_id, param_env, generics) - } - }; - res - }) - } - } -} diff --git a/crates/paralegal-flow/src/ana/inline_judge.rs b/crates/paralegal-flow/src/ana/inline_judge.rs index 564914f01e..0d4414dfef 100644 --- a/crates/paralegal-flow/src/ana/inline_judge.rs +++ b/crates/paralegal-flow/src/ana/inline_judge.rs @@ -34,11 +34,7 @@ impl<'tcx> InlineJudge<'tcx> { let marker_target = info.async_parent.unwrap_or(info.callee); let marker_target_def_id = marker_target.def_id(); match self.analysis_control.inlining_depth() { - _ if self.marker_ctx.is_marked(marker_target_def_id) - || !marker_target_def_id.is_local() => - { - false - } + _ if self.marker_ctx.is_marked(marker_target_def_id) => false, InliningDepth::Adaptive => self .marker_ctx .has_transitive_reachable_markers(marker_target), diff --git a/crates/paralegal-flow/src/ana/metadata.rs b/crates/paralegal-flow/src/ana/metadata.rs new file mode 100644 index 0000000000..85df051779 --- /dev/null +++ b/crates/paralegal-flow/src/ana/metadata.rs @@ -0,0 +1,517 @@ +//! Per-crate intermediate data we store. +//! +//! [`Metadata`] is what gets stored, whereas a [`MetadataLoader`] is +//! responsible for reading/writing this data. 
+ +use crate::{ + ann::{db::MarkerDatabase, Annotation}, + consts::INTERMEDIATE_ARTIFACT_EXT, + desc::*, + discover::{CollectingVisitor, FnToAnalyze}, + Args, DefId, HashMap, MarkerCtx, +}; + +use std::path::Path; +use std::{fs::File, io::Read, rc::Rc}; + +use construct::determine_async; +use flowistry_pdg_construction::{ + self as construct, default_emit_error, graph::InternedString, AsyncType, EmittableError, + GraphLoader, MemoPdgConstructor, PartialGraph, +}; + +use rustc_hash::FxHashMap; +use rustc_hir::def_id::{CrateNum, DefIndex, LocalDefId, LOCAL_CRATE}; +use rustc_index::IndexVec; +use rustc_macros::{Decodable, Encodable, TyDecodable, TyEncodable}; +use rustc_middle::{ + mir::{ + BasicBlock, BasicBlockData, HasLocalDecls, Local, LocalDecl, LocalDecls, LocalKind, + Location, Statement, Terminator, TerminatorKind, + }, + ty::{EarlyBinder, GenericArgsRef, Ty, TyCtxt}, +}; +use rustc_serialize::{Decodable, Encodable}; + +use rustc_utils::{cache::Cache, mir::borrowck_facts}; + +use super::{ + encoder::{ParalegalDecoder, ParalegalEncoder}, + graph_converter::MyCallback, + inline_judge::InlineJudge, +}; + +/// Manager responsible for reading and writing [`Metadata`] artifacts. +pub struct MetadataLoader<'tcx> { + tcx: TyCtxt<'tcx>, + cache: Cache>>, +} + +/// The types of errors that can arise from interacting with the [`MetadataLoader`]. 
+#[derive(Debug)] +pub enum Error<'tcx> { + PdgForItemMissing(DefId), + MetadataForCrateMissing(CrateNum), + NoGenericsKnownForCallSite(CallString), + NoSuchItemInCate(DefId), + ConstructionErrors(Vec>), +} + +impl<'tcx> EmittableError<'tcx> for Error<'tcx> { + fn msg(&self, tcx: TyCtxt<'tcx>, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Error::*; + match self { + PdgForItemMissing(def) => { + write!(f, "found no pdg for item {}", tcx.def_path_str(*def)) + } + MetadataForCrateMissing(krate) => { + write!(f, "no metadata found for crate {}", tcx.crate_name(*krate)) + } + NoGenericsKnownForCallSite(cs) => { + write!(f, "no generics known for call site {cs}") + } + NoSuchItemInCate(it) => write!( + f, + "no such item {} found in crate {}", + tcx.def_path_debug_str(*it), + tcx.crate_name(it.krate) + ), + ConstructionErrors(_e) => f.write_str("construction errors"), + } + } + + fn emit(&self, tcx: TyCtxt<'tcx>) { + if let Error::ConstructionErrors(e) = self { + for e in e { + e.emit(tcx); + } + return; + } + default_emit_error(self, tcx) + } +} + +use Error::*; + +impl<'tcx> GraphLoader<'tcx> for MetadataLoader<'tcx> { + fn load( + &self, + function: DefId, + ) -> Result>, Vec>> { + let Ok(meta) = self.get_metadata(function.krate) else { + return Ok(None); + }; + let res = meta + .pdgs + .get(&function.index) + .ok_or_else(|| vec![construct::Error::CrateExistsButItemIsNotFound { function }])? + .as_ref() + .map_err(Clone::clone)?; + + Ok(Some(&res.graph)) + } +} + +impl<'tcx> MetadataLoader<'tcx> { + /// Traverse the items of this crate, create PDGs and collect other relevant + /// information about them. Write the metadata to disk, but also register + /// them with the loader itself for downstream analyses. + /// + /// Returns which functions should be emitted for policy enforcement (e.g. + /// analysis targets) and a context of discovered markers suitable for query + /// during that analysis. 
+ pub fn collect_and_emit_metadata( + self: Rc, + args: &'static Args, + path: Option>, + ) -> (Vec, MarkerCtx<'tcx>) { + let tcx = self.tcx; + let mut collector = CollectingVisitor::new(tcx, args, self.clone()); + collector.run(); + let emit_targets = collector.emit_target_collector; + let marker_ctx: MarkerCtx = collector.marker_ctx.into(); + let mut constructor = MemoPdgConstructor::new(tcx, self.clone()); + constructor + .with_call_change_callback(MyCallback { + tcx, + judge: InlineJudge::new(marker_ctx.clone(), tcx, args.anactrl()), + }) + .with_dump_mir(args.dbg().dump_mir()); + let pdgs = emit_targets + .into_iter() + .map(|t| { + println!("Constructing for {:?}", tcx.def_path_str(t)); + let graph = constructor.construct_root(t).map(|graph| { + let body = borrowck_facts::get_body_with_borrowck_facts(tcx, t); + // MONOMORPHIZATION: normally we need to monomorphize the + // body, but here we don't because generics can't change + // whether a function has async structure. + let async_status = determine_async(tcx, t, &body.body) + .map(|(inst, _loc, asyncness)| AsyncStatus::Async { + generator_id: inst.def_id().index, + asyncness, + }) + .unwrap_or(AsyncStatus::NotAsync); + PdgInfo { + graph: graph.clone(), + async_status, + } + }); + (t.local_def_index, graph) + }) + .collect::>(); + let meta = Metadata::from_pdgs(tcx, pdgs, marker_ctx.db()); + if let Some(path) = path { + let path = path.as_ref(); + debug!("Writing metadata to {}", path.display()); + meta.write(path, tcx); + } + self.cache.get(LOCAL_CRATE, |_| Some(meta)); + (collector.functions_to_analyze, marker_ctx) + } + + pub fn get_annotations(&self, key: DefId) -> &[Annotation] { + (|| { + Some( + self.get_metadata(key.krate) + .ok()? + .local_annotations + .get(&key.index)? 
+ .as_slice(), + ) + })() + .unwrap_or(&[]) + } + + pub fn all_annotations<'a>(&'a self) -> impl Iterator { + let b = self.cache.borrow(); + + // Safety: While we're keeping references to the borrow above, we only + // keep references to values behind `Pin>` which are guaranteed + // not to move. So even if the borrow is modified, these references are + // still valid. + // + // In terms of race conditions: this is a cache which never overwrites values. + let metadatas = unsafe { + std::mem::transmute::< + Vec<(CrateNum, &_)>, + Vec<(CrateNum, &'a HashMap>)>, + >( + b.iter() + .filter_map(|(k, v)| Some((*k, &(**(v.as_ref()?)).as_ref()?.local_annotations))) + .collect::>(), + ) + }; + metadatas.into_iter().flat_map(|(krate, m)| { + m.iter() + .flat_map(move |(&index, v)| v.iter().map(move |v| (DefId { krate, index }, v))) + }) + } +} + +#[derive(Clone, Debug, TyEncodable, TyDecodable)] +pub struct PdgInfo<'tcx> { + pub graph: PartialGraph<'tcx>, + pub async_status: AsyncStatus, +} + +#[derive(Clone, Copy, Debug, Encodable, Decodable)] +pub enum AsyncStatus { + NotAsync, + Async { + generator_id: Def, + asyncness: AsyncType, + }, +} + +impl AsyncStatus { + pub fn is_async(&self) -> bool { + matches!(self, Self::Async { .. }) + } + + fn map_index(&self, f: impl FnOnce(&Def) -> D) -> AsyncStatus { + match self { + Self::NotAsync => AsyncStatus::NotAsync, + Self::Async { + generator_id, + asyncness, + } => AsyncStatus::Async { + generator_id: f(generator_id), + asyncness: *asyncness, + }, + } + } +} + +pub type PdgMap<'tcx> = FxHashMap, Vec>>>; + +/// Intermediate artifacts stored on disc for every crate. +/// +/// Contains PDGs and reduced information about the source code that is needed +/// downstream. 
+#[derive(Clone, Debug, TyEncodable, TyDecodable)] +pub struct Metadata<'tcx> { + pub pdgs: PdgMap<'tcx>, + pub bodies: FxHashMap>, + pub local_annotations: HashMap>, + pub reachable_markers: HashMap<(DefIndex, GenericArgsRef<'tcx>), Box<[InternedString]>>, +} + +impl<'tcx> Metadata<'tcx> { + fn write(&self, path: impl AsRef, tcx: TyCtxt<'tcx>) { + let mut encoder = ParalegalEncoder::new(path, tcx); + self.encode(&mut encoder); + encoder.finish() + } +} + +impl<'tcx> Metadata<'tcx> { + /// Given a set of PDGs created, query additional information we need to + /// record from rustc and return a serializable metadata artifact. + pub fn from_pdgs( + tcx: TyCtxt<'tcx>, + pdgs: PdgMap<'tcx>, + markers: &MarkerDatabase<'tcx>, + ) -> Self { + let mut bodies: FxHashMap = Default::default(); + for call_string in pdgs + .values() + .filter_map(|e| e.as_ref().ok()) + .flat_map(|subgraph| subgraph.graph.mentioned_call_string()) + { + for location in call_string.iter() { + if let Some(local) = location.function.as_local() { + bodies.entry(local.local_def_index).or_insert_with(|| { + let info = BodyInfo::from_body(tcx, local); + trace!("Created info for body {local:?}\n{info:?}"); + info + }); + } + } + } + let cache_borrow = markers.reachable_markers.borrow(); + Self { + pdgs, + bodies, + local_annotations: markers + .local_annotations + .iter() + .map(|(k, v)| (k.local_def_index, v.clone())) + .collect(), + reachable_markers: (*cache_borrow) + .iter() + .filter_map(|(inst, v)| { + let id = inst.def_id(); + let args = inst.args; + Some(( + (id.as_local()?.local_def_index, args), + (**(v.as_ref()?)).clone(), + )) + }) + .collect(), + } + } +} + +impl<'tcx> MetadataLoader<'tcx> { + pub fn new(tcx: TyCtxt<'tcx>) -> Rc { + Rc::new(Self { + tcx, + cache: Default::default(), + }) + } + + pub fn get_metadata(&self, key: CrateNum) -> Result<&Metadata<'tcx>, Error<'tcx>> { + let meta = self + .cache + .get(key, |_| { + let paths = self.tcx.crate_extern_paths(key); + for path in paths { 
+ let path = path.with_extension(INTERMEDIATE_ARTIFACT_EXT); + let Ok(mut file) = File::open(path) else { + continue; + }; + let mut buf = Vec::new(); + file.read_to_end(&mut buf).unwrap(); + let mut decoder = ParalegalDecoder::new(self.tcx, buf.as_slice()); + let meta = Metadata::decode(&mut decoder); + return Some(meta); + } + None + }) + .as_ref() + .ok_or(MetadataForCrateMissing(key))?; + Ok(meta) + } + + pub fn get_partial_graph(&self, key: DefId) -> Result<&PartialGraph<'tcx>, Error<'tcx>> { + let meta = self.get_metadata(key.krate)?; + let result = meta.pdgs.get(&key.index).ok_or(PdgForItemMissing(key))?; + result + .as_ref() + .map_err(|e| Error::ConstructionErrors(e.clone())) + .map(|e| &e.graph) + } + + pub fn get_body_info(&self, key: DefId) -> Result<&BodyInfo<'tcx>, Error<'tcx>> { + let meta = self.get_metadata(key.krate)?; + let res = meta.bodies.get(&key.index).ok_or(NoSuchItemInCate(key)); + res + } + + pub fn get_mono(&self, cs: CallString) -> Result, Error<'tcx>> { + let key = cs.root().function; + self.get_partial_graph(key)? + .get_mono(cs) + .ok_or(NoGenericsKnownForCallSite(cs)) + } + + pub fn get_asyncness(&self, key: DefId) -> AsyncStatus { + (|| { + Some( + self.get_metadata(key.krate) + .ok()? + .pdgs + .get(&key.index)? + .as_ref() + .ok()? + .async_status + .map_index(|i| DefId { + krate: key.krate, + index: *i, + }), + ) + })() + .unwrap_or(AsyncStatus::NotAsync) + } +} + +/// Effectively a reduced MIR [`Body`](rustc_middle::mir::Body). 
+#[derive(Clone, Debug, TyEncodable, TyDecodable)] +pub struct BodyInfo<'tcx> { + pub arg_count: usize, + pub decls: IndexVec>, + pub instructions: IndexVec>>, + pub def_span: rustc_span::Span, +} + +impl<'tcx> BodyInfo<'tcx> { + pub fn from_body(tcx: TyCtxt<'tcx>, function_id: LocalDefId) -> Self { + let body_with_facts = borrowck_facts::get_body_with_borrowck_facts(tcx, function_id); + let body = &body_with_facts.body; + Self { + arg_count: body.arg_count, + decls: body.local_decls().to_owned(), + instructions: body + .basic_blocks + .iter() + .map(|bb| RustcInstructionInfo::from_basic_block(tcx, body, bb)) + .collect(), + def_span: tcx.def_span(function_id), + } + } +} + +#[derive(Clone, Copy, Debug, TyEncodable, TyDecodable)] +pub struct RustcInstructionInfo<'tcx> { + /// Classification of the instruction + pub kind: RustcInstructionKind<'tcx>, + /// The source code span + pub span: rustc_span::Span, + /// Textual rendering of the MIR + pub description: InternedString, +} + +impl<'tcx> RustcInstructionInfo<'tcx> { + pub fn from_statement(stmt: &Statement) -> Self { + Self { + kind: RustcInstructionKind::Statement, + span: stmt.source_info.span, + description: format!("{:?}", stmt.kind).into(), + } + } + + pub fn from_terminator( + tcx: TyCtxt<'tcx>, + local_decls: &impl HasLocalDecls<'tcx>, + term: &Terminator<'tcx>, + ) -> Self { + Self { + kind: match &term.kind { + TerminatorKind::Call { + func, + args: _, + destination: _, + target: _, + unwind: _, + call_source: _, + fn_span: _, + } => { + let op_ty = tcx.erase_regions(func.ty(local_decls, tcx)); + RustcInstructionKind::FunctionCall(EarlyBinder::bind(op_ty)) + } + TerminatorKind::SwitchInt { .. 
} => RustcInstructionKind::SwitchInt, + _ => RustcInstructionKind::Terminator, + }, + span: term.source_info.span, + description: format!("{:?}", term.kind).into(), + } + } + + pub fn from_basic_block( + tcx: TyCtxt<'tcx>, + local_decls: &impl HasLocalDecls<'tcx>, + bb: &BasicBlockData<'tcx>, + ) -> Vec { + let t = bb.terminator(); + bb.statements + .iter() + .map(Self::from_statement) + .chain([Self::from_terminator(tcx, local_decls, t)]) + .collect() + } +} + +/// The type of instructions we may encounter +#[derive(Debug, Clone, Copy, Eq, Ord, PartialOrd, PartialEq, TyEncodable, TyDecodable)] +pub enum RustcInstructionKind<'tcx> { + /// Some type of statement + Statement, + /// A function call. The type is guaranteed to be of function type + FunctionCall(EarlyBinder>), + /// A basic block terminator + Terminator, + /// The switch int terminator + SwitchInt, +} + +impl<'tcx> BodyInfo<'tcx> { + pub fn local_kind(&self, local: Local) -> LocalKind { + let local = local.as_usize(); + assert!(local < self.decls.len()); + if local == 0 { + LocalKind::ReturnPointer + } else if local < self.arg_count + 1 { + LocalKind::Arg + } else { + LocalKind::Temp + } + } + + pub fn instruction_at(&self, location: Location) -> RustcInstructionInfo<'tcx> { + self.instructions[location.block][location.statement_index] + } + + pub fn span_of(&self, loc: RichLocation) -> rustc_span::Span { + match loc { + RichLocation::Location(loc) => self.instruction_at(loc).span, + _ => self.def_span, + } + } +} + +impl<'tcx> HasLocalDecls<'tcx> for BodyInfo<'tcx> { + fn local_decls(&self) -> &LocalDecls<'tcx> { + &self.decls + } +} diff --git a/crates/paralegal-flow/src/ana/mod.rs b/crates/paralegal-flow/src/ana/mod.rs index 9b87d47829..65006accaf 100644 --- a/crates/paralegal-flow/src/ana/mod.rs +++ b/crates/paralegal-flow/src/ana/mod.rs @@ -4,40 +4,43 @@ //! [`CollectingVisitor`](crate::discover::CollectingVisitor) and then calling //! [`analyze`](SPDGGenerator::analyze). 
+use std::{rc::Rc, time::Duration}; + use crate::{ ann::{Annotation, MarkerAnnotation}, desc::*, discover::FnToAnalyze, - rust::{hir::def, *}, - stats::{Stats, TimedStat}, utils::*, DefId, HashMap, HashSet, LogLevelConfig, MarkerCtx, Symbol, }; -use std::time::Instant; - use anyhow::Result; -use either::Either; + use itertools::Itertools; use petgraph::visit::GraphBase; + +use rustc_hir::def; +use rustc_middle::ty::{Instance, ParamEnv, TyCtxt}; use rustc_span::{FileNameDisplayPreference, Span as RustSpan}; +mod encoder; mod graph_converter; mod inline_judge; +mod metadata; use graph_converter::GraphConverter; +use metadata::RustcInstructionKind; -use self::{graph_converter::PlaceInfoCache, inline_judge::InlineJudge}; +pub use metadata::MetadataLoader; /// Read-only database of information the analysis needs. /// /// [`Self::analyze`] serves as the main entrypoint to SPDG generation. pub struct SPDGGenerator<'tcx> { - pub inline_judge: InlineJudge<'tcx>, pub opts: &'static crate::Args, pub tcx: TyCtxt<'tcx>, - stats: Stats, - place_info_cache: PlaceInfoCache<'tcx>, + marker_ctx: MarkerCtx<'tcx>, + metadata_loader: Rc>, } impl<'tcx> SPDGGenerator<'tcx> { @@ -45,19 +48,18 @@ impl<'tcx> SPDGGenerator<'tcx> { marker_ctx: MarkerCtx<'tcx>, opts: &'static crate::Args, tcx: TyCtxt<'tcx>, - stats: Stats, + metadata_loader: Rc>, ) -> Self { Self { - inline_judge: InlineJudge::new(marker_ctx, tcx, opts.anactrl()), + marker_ctx, opts, tcx, - stats, - place_info_cache: Default::default(), + metadata_loader, } } pub fn marker_ctx(&self) -> &MarkerCtx<'tcx> { - self.inline_judge.marker_ctx() + &self.marker_ctx } /// Perform the analysis for one `#[paralegal_flow::analyze]` annotated function and @@ -71,17 +73,12 @@ impl<'tcx> SPDGGenerator<'tcx> { known_def_ids: &mut impl Extend, ) -> Result<(Endpoint, SPDG)> { info!("Handling target {}", self.tcx.def_path_str(target.def_id)); - let local_def_id = target.def_id.expect_local(); + let local_def_id = target.def_id; - let converter = 
GraphConverter::new_with_flowistry( - self, - known_def_ids, - target, - self.place_info_cache.clone(), - )?; + let converter = GraphConverter::new_with_flowistry(self, known_def_ids, target)?; let spdg = converter.make_spdg(); - Ok((local_def_id, spdg)) + Ok((local_def_id.to_def_id(), spdg)) } /// Main analysis driver. Essentially just calls [`Self::handle_target`] @@ -117,13 +114,7 @@ impl<'tcx> SPDGGenerator<'tcx> { }) }) .collect::>>() - .map(|controllers| { - let start = Instant::now(); - let desc = self.make_program_description(controllers, known_def_ids, &targets); - self.stats - .record_timed(TimedStat::Conversion, start.elapsed()); - desc - }) + .map(|controllers| self.make_program_description(controllers, known_def_ids, &targets)) } /// Given the PDGs and a record of all [`DefId`]s we've seen, compile @@ -137,33 +128,7 @@ impl<'tcx> SPDGGenerator<'tcx> { ) -> ProgramDescription { let tcx = self.tcx; - let instruction_info = self.collect_instruction_info(&controllers); - - let inlined_functions = instruction_info - .keys() - .filter_map(|l| l.function.to_def_id().as_local()) - .collect::>(); - let analyzed_spans = inlined_functions - .iter() - .copied() - // Because we now take the functions seen from the marker context - // this includes functions where the body is not present (e.g. `dyn`) - // so if we fail to retrieve the body in that case it is allowed. - // - // Prefereably in future we would filter what we get from the marker - // context better. 
- .filter_map(|f| { - let body = match tcx.body_for_def_id(f) { - Ok(b) => Some(b), - Err(BodyResolutionError::IsTraitAssocFn(_)) => None, - Err(e) => panic!("{e:?}"), - }?; - let span = body_span(&body.body); - Some((f, src_loc_for_span(span, tcx))) - }) - .collect::>(); - - known_def_ids.extend(inlined_functions.iter().map(|f| f.to_def_id())); + let instruction_info = self.collect_instruction_info(&controllers, &mut known_def_ids); let type_info = self.collect_type_info(); known_def_ids.extend(type_info.keys()); @@ -172,52 +137,19 @@ impl<'tcx> SPDGGenerator<'tcx> { .map(|id| (*id, def_info_for_item(*id, self.marker_ctx(), tcx))) .collect(); - let dedup_locs = analyzed_spans.values().map(Span::line_len).sum(); - let dedup_functions = analyzed_spans.len() as u32; - - let (seen_locs, seen_functions) = if self.opts.anactrl().inlining_depth().is_adaptive() { - let mut total_functions = inlined_functions; - let mctx = self.marker_ctx(); - total_functions.extend( - mctx.functions_seen() - .into_iter() - .map(|f| f.def_id()) - .filter(|f| !mctx.is_marked(f)) - .filter_map(|f| f.as_local()), - ); - let mut seen_functions = 0; - let locs = total_functions - .into_iter() - .filter_map(|f| Some(body_span(&tcx.body_for_def_id(f).ok()?.body))) - .map(|span| { - seen_functions += 1; - let (_, start_line, _, end_line, _) = - tcx.sess.source_map().span_to_location_info(span); - end_line - start_line + 1 - }) - .sum::() as u32; - (locs, seen_functions) - } else { - (dedup_locs, dedup_functions) - }; - type_info_sanity_check(&controllers, &type_info); ProgramDescription { type_info, instruction_info, controllers, def_info, - marker_annotation_count: self - .marker_ctx() - .all_annotations() - .filter_map(|m| m.1.either(Annotation::as_marker, Some)) - .count() as u32, - rustc_time: self.stats.get_timed(TimedStat::Rustc), - dedup_locs, - dedup_functions, - seen_functions, - seen_locs, - analyzed_spans, + rustc_time: Duration::ZERO, + marker_annotation_count: 0, + dedup_locs: 0, + 
dedup_functions: 0, + seen_locs: 0, + seen_functions: 0, + analyzed_spans: Default::default(), } } @@ -226,59 +158,71 @@ impl<'tcx> SPDGGenerator<'tcx> { fn collect_instruction_info( &self, controllers: &HashMap, - ) -> HashMap { + known_def_ids: &mut impl Extend, + ) -> HashMap { let all_instructions = controllers .values() .flat_map(|v| { v.graph .node_weights() - .flat_map(|n| n.at.iter()) - .chain(v.graph.edge_weights().flat_map(|e| e.at.iter())) + .map(|n| &n.at) + .chain(v.graph.edge_weights().map(|e| &e.at)) }) .collect::>(); all_instructions .into_iter() - .map(|i| { - let body = &self.tcx.body_for_def_id(i.function).unwrap().body; - - let (kind, description) = match i.location { - RichLocation::End => (InstructionKind::Return, "start".to_owned()), - RichLocation::Start => (InstructionKind::Start, "end".to_owned()), - RichLocation::Location(loc) => match body.stmt_at(loc) { - crate::Either::Right(term) => { - let kind = if let Ok((id, ..)) = term.as_fn_and_args(self.tcx) { - InstructionKind::FunctionCall(FunctionCallInfo { - id, - is_inlined: id.is_local(), - }) - } else { - InstructionKind::Terminator - }; - (kind, format!("{:?}", term.kind)) - } - crate::Either::Left(stmt) => { - (InstructionKind::Statement, format!("{:?}", stmt.kind)) - } - }, - }; - let rust_span = match i.location { + .map(|&n| { + let monos = self.metadata_loader.get_mono(n).unwrap(); + let body = self + .metadata_loader + .get_body_info(n.leaf().function) + .unwrap(); + let (kind, description, span) = match n.leaf().location { + RichLocation::End => { + (InstructionKind::Return, "start".to_owned(), body.def_span) + } + RichLocation::Start => { + (InstructionKind::Start, "end".to_owned(), body.def_span) + } RichLocation::Location(loc) => { - let expanded_span = match body.stmt_at(loc) { - crate::Either::Right(term) => term.source_info.span, - crate::Either::Left(stmt) => stmt.source_info.span, - }; - self.tcx - .sess - .source_map() - .stmt_span(expanded_span, body.span) + let 
instruction = body.instruction_at(loc); + ( + match instruction.kind { + RustcInstructionKind::SwitchInt => InstructionKind::SwitchInt, + RustcInstructionKind::FunctionCall(c) => { + InstructionKind::FunctionCall({ + let (id, generics) = + flowistry_pdg_construction::utils::type_as_fn( + self.tcx, + c.instantiate(self.tcx, monos), + ) + .unwrap(); + let instance_id = Instance::resolve( + self.tcx, + ParamEnv::reveal_all(), + id, + generics, + ) + .unwrap() + .map(|i| i.def_id()) + .unwrap_or(id); + known_def_ids.extend(Some(instance_id)); + FunctionCallInfo { id: instance_id } + }) + } + RustcInstructionKind::Statement => InstructionKind::Statement, + RustcInstructionKind::Terminator => InstructionKind::Terminator, + }, + (*instruction.description).clone(), + instruction.span, + ) } - RichLocation::Start | RichLocation::End => self.tcx.def_span(i.function), }; ( - i, + n, InstructionInfo { kind, - span: src_loc_for_span(rust_span, self.tcx), + span: src_loc_for_span(span, self.tcx), description: Identifier::new_intern(&description), }, ) @@ -296,16 +240,12 @@ impl<'tcx> SPDGGenerator<'tcx> { .fold_with( |id, _| (format!("{id:?}"), vec![], vec![]), |mut desc, _, ann| { - match ann { - Either::Right(MarkerAnnotation { refinement, marker }) - | Either::Left(Annotation::Marker(MarkerAnnotation { - refinement, - marker, - })) => { + match ann.as_ref() { + Annotation::Marker(MarkerAnnotation { refinement, marker }) => { assert!(refinement.on_self()); desc.2.push(*marker) } - Either::Left(Annotation::OType(id)) => desc.1.push(*id), + Annotation::OType(id) => desc.1.push(*id), _ => panic!("Unexpected type of annotation {ann:?}"), } desc @@ -318,7 +258,10 @@ impl<'tcx> SPDGGenerator<'tcx> { TypeDescription { rendering, otypes: otypes.into(), - markers, + markers: markers + .into_iter() + .map(|i| Identifier::new_intern(i.as_str())) + .collect(), }, ) }) @@ -394,7 +337,7 @@ fn path_for_item(id: DefId, tcx: TyCtxt) -> Box<[Identifier]> { let def_path = tcx.def_path(id); 
std::iter::once(Identifier::new(tcx.crate_name(def_path.krate))) .chain(def_path.data.iter().filter_map(|segment| { - use hir::definitions::DefPathDataName::*; + use rustc_hir::definitions::DefPathDataName::*; match segment.data.name() { Named(sym) => Some(Identifier::new(sym)), Anon { .. } => None, @@ -415,7 +358,7 @@ fn def_info_for_item(id: DefId, markers: &MarkerCtx, tcx: TyCtxt) -> DefInfo { .combined_markers(id) .cloned() .map(|ann| paralegal_spdg::MarkerAnnotation { - marker: ann.marker, + marker: Identifier::new_intern(ann.marker.as_str()), on_return: ann.refinement.on_return(), on_argument: ann.refinement.on_argument(), }) diff --git a/crates/paralegal-flow/src/ann/db.rs b/crates/paralegal-flow/src/ann/db.rs index 1b37e4e73c..6d0d5e6b9f 100644 --- a/crates/paralegal-flow/src/ann/db.rs +++ b/crates/paralegal-flow/src/ann/db.rs @@ -11,23 +11,30 @@ //! All interactions happen through the central database object: [`MarkerCtx`]. use crate::{ + ana::MetadataLoader, ann::{Annotation, MarkerAnnotation}, args::{Args, MarkerControl}, - ast::Attribute, consts, - hir::def::DefKind, - mir, ty, utils::{ - resolve::expect_resolve_string_to_def_id, AsFnAndArgs, FnResolution, FnResolutionExt, - IntoDefId, IntoHirId, MetaItemMatch, TyCtxtExt, TyExt, + resolve::expect_resolve_string_to_def_id, AsFnAndArgs, InstanceExt, IntoDefId, IntoHirId, + MetaItemMatch, TyCtxtExt, TyExt, }, DefId, Either, HashMap, HashSet, LocalDefId, TyCtxt, }; -use flowistry_pdg_construction::determine_async; -use paralegal_spdg::Identifier; +use flowistry_pdg_construction::{ + determine_async, + graph::InternedString, + utils::{is_async, try_monomorphize}, +}; +use rustc_ast::Attribute; +use rustc_hir::def::DefKind; +use rustc_middle::{ + mir, + ty::{self, Instance}, +}; use rustc_utils::cache::Cache; -use std::rc::Rc; +use std::{borrow::Cow, rc::Rc}; type ExternalMarkers = HashMap>; @@ -57,7 +64,7 @@ impl<'tcx> MarkerCtx<'tcx> { } #[inline] - fn db(&self) -> &MarkerDatabase<'tcx> { + pub fn db(&self) 
-> &MarkerDatabase<'tcx> { &self.0 } @@ -65,18 +72,23 @@ impl<'tcx> MarkerCtx<'tcx> { /// are present an empty slice is returned. /// /// Query is cached. - pub fn local_annotations(&self, def_id: LocalDefId) -> &[Annotation] { - self.db() - .local_annotations - .get(&self.defid_rewrite(def_id.to_def_id()).expect_local()) - .map_or(&[], |o| o.as_slice()) + fn attribute_annotations(&self, key: DefId) -> &[Annotation] { + let key = self.defid_rewrite(key); + if let Some(local) = key.as_local() { + self.db() + .local_annotations + .get(&local) + .map_or(&[], Vec::as_slice) + } else { + self.0.loader.get_annotations(key) + } } /// Retrieves any external markers on this item. If there are not such /// markers an empty slice is returned. /// /// THe external marker database is populated at construction. - pub fn external_markers(&self, did: D) -> &[MarkerAnnotation] { + fn external_markers(&self, did: D) -> &[MarkerAnnotation] { self.db() .external_annotations .get(&self.defid_rewrite(did.into_def_id(self.tcx()))) @@ -87,11 +99,9 @@ impl<'tcx> MarkerCtx<'tcx> { /// /// Queries are cached/precomputed so calling this repeatedly is cheap. pub fn combined_markers(&self, def_id: DefId) -> impl Iterator { - def_id - .as_local() - .map(|ldid| self.local_annotations(ldid)) - .into_iter() - .flat_map(|anns| anns.iter().flat_map(Annotation::as_marker)) + self.attribute_annotations(def_id) + .iter() + .filter_map(Annotation::as_marker) .chain(self.external_markers(def_id).iter()) } @@ -102,7 +112,7 @@ impl<'tcx> MarkerCtx<'tcx> { if matches!(def_kind, DefKind::Generator) { if let Some(parent) = self.tcx().opt_parent(def_id) { if matches!(self.tcx().def_kind(parent), DefKind::AssocFn | DefKind::Fn) - && self.tcx().asyncness(parent).is_async() + && is_async(self.tcx(), parent) { return parent; } @@ -112,13 +122,13 @@ impl<'tcx> MarkerCtx<'tcx> { } /// Are there any external markers on this item? 
- pub fn is_externally_marked(&self, did: D) -> bool { + fn is_externally_marked(&self, did: D) -> bool { !self.external_markers(did).is_empty() } /// Are there any local markers on this item? - pub fn is_locally_marked(&self, def_id: LocalDefId) -> bool { - self.local_annotations(def_id) + fn is_attribute_marked(&self, def_id: DefId) -> bool { + self.attribute_annotations(def_id) .iter() .any(Annotation::is_marker) } @@ -128,44 +138,17 @@ impl<'tcx> MarkerCtx<'tcx> { /// This is in contrast to [`Self::marker_is_reachable`] which also reports /// if markers are reachable from the body of this function (if it is one). pub fn is_marked(&self, did: D) -> bool { - matches!(did.into_def_id(self.tcx()).as_local(), Some(ldid) if self.is_locally_marked(ldid)) - || self.is_externally_marked(did) - } + let did = did.into_def_id(self.tcx()); - /// Return a complete set of local annotations that were discovered. - /// - /// Crucially this is a "readout" from the marker cache, which means only - /// items reachable from the `paralegal_flow::analyze` will end up in this collection. - pub fn local_annotations_found(&self) -> Vec<(LocalDefId, &[Annotation])> { - self.db() - .local_annotations - .iter() - .map(|(k, v)| (*k, (v.as_slice()))) - .collect() - } - - /// Direct access to the loaded database of external markers. - #[inline] - pub fn external_annotations(&self) -> &ExternalMarkers { - &self.db().external_annotations - } - - /// Are there markers reachable from this (function)? - /// - /// Returns true if the item itself carries a marker *or* if one of the - /// functions called in its body are marked. - /// - /// XXX Does not take into account reachable type markers - pub fn marker_is_reachable(&self, res: FnResolution<'tcx>) -> bool { - self.is_marked(res.def_id()) || self.has_transitive_reachable_markers(res) + self.is_attribute_marked(did) || self.is_externally_marked(did) } /// Queries the transitive marker cache. 
- pub fn has_transitive_reachable_markers(&self, res: FnResolution<'tcx>) -> bool { + pub fn has_transitive_reachable_markers(&self, res: Instance<'tcx>) -> bool { !self.get_reachable_markers(res).is_empty() } - pub fn get_reachable_markers(&self, res: FnResolution<'tcx>) -> &[Identifier] { + pub fn get_reachable_markers(&self, res: Instance<'tcx>) -> &[InternedString] { self.db() .reachable_markers .get_maybe_recursive(res, |_| self.compute_reachable_markers(res)) @@ -174,8 +157,8 @@ impl<'tcx> MarkerCtx<'tcx> { fn get_reachable_and_self_markers( &self, - res: FnResolution<'tcx>, - ) -> impl Iterator + '_ { + res: Instance<'tcx>, + ) -> impl Iterator + '_ { if res.def_id().is_local() { let mut direct_markers = self .combined_markers(res.def_id()) @@ -199,7 +182,8 @@ impl<'tcx> MarkerCtx<'tcx> { /// If the transitive marker cache did not contain the answer, this is what /// computes it. - fn compute_reachable_markers(&self, res: FnResolution<'tcx>) -> Box<[Identifier]> { + fn compute_reachable_markers(&self, res: Instance<'tcx>) -> Box<[InternedString]> { + let tcx = self.tcx(); trace!("Computing reachable markers for {res:?}"); let Some(local) = res.def_id().as_local() else { trace!(" Is not local"); @@ -209,16 +193,21 @@ impl<'tcx> MarkerCtx<'tcx> { trace!(" Is marked"); return Box::new([]); } - let Some(body) = self.tcx().body_for_def_id_default_policy(local) else { + let Some(body) = tcx.body_for_def_id_default_policy(local) else { trace!(" Cannot find body"); return Box::new([]); }; - let mono_body = res.try_monomorphize( - self.tcx(), - self.tcx().param_env_reveal_all_normalized(local), + let Ok(mono_body) = try_monomorphize( + res, + tcx, + tcx.param_env_reveal_all_normalized(local), &body.body, - ); - if let Some((async_fn, _)) = determine_async(self.tcx(), local, &mono_body) { + tcx.def_span(res.def_id()), + ) else { + trace!(" monomorphization error"); + return Box::new([]); + }; + if let Some((async_fn, ..)) = determine_async(tcx, local, &mono_body) { 
return self.get_reachable_markers(async_fn).into(); } mono_body @@ -237,7 +226,7 @@ impl<'tcx> MarkerCtx<'tcx> { &self, local_decls: &mir::LocalDecls, terminator: &mir::Terminator<'tcx>, - ) -> impl Iterator + '_ { + ) -> impl Iterator + '_ { trace!( " Finding reachable markers for terminator {:?}", terminator.kind @@ -260,7 +249,7 @@ impl<'tcx> MarkerCtx<'tcx> { && let ty::TyKind::Generator(closure_fn, substs, _) = self.tcx().type_of(alias.def_id).skip_binder().kind() { trace!(" fits opaque type"); Either::Left(self.get_reachable_and_self_markers( - FnResolution::Final(ty::Instance::expect_resolve(self.tcx(), ty::ParamEnv::reveal_all(), *closure_fn, substs)) + ty::Instance::expect_resolve(self.tcx(), ty::ParamEnv::reveal_all(), *closure_fn, substs) )) } else { Either::Right(std::iter::empty()) @@ -379,19 +368,11 @@ impl<'tcx> MarkerCtx<'tcx> { .into_iter() } - pub fn type_has_surface_markers(&self, ty: ty::Ty) -> Option { - let def_id = ty.defid()?; - self.combined_markers(def_id) - .next() - .is_some() - .then_some(def_id) - } - /// All markers placed on this function, directly or through the type plus /// the type that was marked (if any). 
pub fn all_function_markers<'a>( &'a self, - function: FnResolution<'tcx>, + function: Instance<'tcx>, ) -> impl Iterator, DefId)>)> { // Markers not coming from types, hence the "None" let direct_markers = self @@ -416,31 +397,28 @@ impl<'tcx> MarkerCtx<'tcx> { } /// Iterate over all discovered annotations, whether local or external - pub fn all_annotations( - &self, - ) -> impl Iterator)> { + pub fn all_annotations(&self) -> impl Iterator)> { self.0 .local_annotations .iter() .flat_map(|(&id, anns)| { anns.iter() - .map(move |ann| (id.to_def_id(), Either::Left(ann))) + .map(move |ann| (id.to_def_id(), Cow::Borrowed(ann))) }) .chain( self.0 - .external_annotations - .iter() - .flat_map(|(&id, anns)| anns.iter().map(move |ann| (id, Either::Right(ann)))), + .loader + .all_annotations() + .map(|(it, ann)| (it, Cow::Borrowed(ann))), ) - } - - pub fn functions_seen(&self) -> Vec> { - let cache = self.0.reachable_markers.borrow(); - cache.keys().copied().collect::>() + .chain(self.0.external_annotations.iter().flat_map(|(&id, anns)| { + anns.iter() + .map(move |ann| (id, Cow::Owned(Annotation::Marker(ann.clone())))) + })) } } -pub type TypeMarkerElem = (DefId, Identifier); +pub type TypeMarkerElem = (DefId, InternedString); pub type TypeMarkers = [TypeMarkerElem]; /// The structure inside of [`MarkerCtx`]. @@ -448,18 +426,19 @@ pub struct MarkerDatabase<'tcx> { tcx: TyCtxt<'tcx>, /// Cache for parsed local annotations. They are created with /// [`MarkerCtx::retrieve_local_annotations_for`]. - local_annotations: HashMap>, + pub(crate) local_annotations: HashMap>, external_annotations: ExternalMarkers, /// Cache whether markers are reachable transitively. - reachable_markers: Cache, Box<[Identifier]>>, + pub(crate) reachable_markers: Cache, Box<[InternedString]>>, /// Configuration options _config: &'static MarkerControl, type_markers: Cache, Box>, + loader: Rc>, } impl<'tcx> MarkerDatabase<'tcx> { /// Construct a new database, loading external markers. 
- pub fn init(tcx: TyCtxt<'tcx>, args: &'static Args) -> Self { + pub fn init(tcx: TyCtxt<'tcx>, args: &'static Args, loader: Rc>) -> Self { Self { tcx, local_annotations: HashMap::default(), @@ -467,6 +446,7 @@ impl<'tcx> MarkerDatabase<'tcx> { reachable_markers: Default::default(), _config: args.marker_control(), type_markers: Default::default(), + loader, } } diff --git a/crates/paralegal-flow/src/ann/mod.rs b/crates/paralegal-flow/src/ann/mod.rs index 05b772240d..26ed90bab8 100644 --- a/crates/paralegal-flow/src/ann/mod.rs +++ b/crates/paralegal-flow/src/ann/mod.rs @@ -1,6 +1,8 @@ +use flowistry_pdg_construction::graph::InternedString; +use rustc_macros::{Decodable, Encodable}; use serde::{Deserialize, Serialize}; -use paralegal_spdg::{rustc_proxies, tiny_bitset_pretty, Identifier, TinyBitSet, TypeId}; +use paralegal_spdg::{rustc_proxies, tiny_bitset_pretty, TinyBitSet, TypeId}; pub mod db; pub mod parse; @@ -11,7 +13,19 @@ pub mod parse; /// For convenience the match methods [`Self::as_marker`], [`Self::as_otype`] /// and [`Self::as_exception`] are provided. These are particularly useful in /// conjunction with e.g. [`Iterator::filter_map`] -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Deserialize, Serialize, strum::EnumIs)] +#[derive( + PartialEq, + Eq, + PartialOrd, + Ord, + Debug, + Clone, + Deserialize, + Serialize, + strum::EnumIs, + Encodable, + Decodable, +)] pub enum Annotation { Marker(MarkerAnnotation), OType(#[serde(with = "rustc_proxies::DefId")] TypeId), @@ -27,6 +41,7 @@ impl Annotation { } } + #[allow(dead_code)] /// If this is an [`Annotation::OType`], returns the underlying [`TypeId`]. pub fn as_otype(&self) -> Option { match self { @@ -35,6 +50,7 @@ impl Annotation { } } + #[allow(dead_code)] /// If this is an [`Annotation::Exception`], returns the underlying [`ExceptionAnnotation`]. 
pub fn as_exception(&self) -> Option<&ExceptionAnnotation> { match self { @@ -46,7 +62,9 @@ impl Annotation { pub type VerificationHash = u128; -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Serialize, Deserialize)] +#[derive( + PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Serialize, Deserialize, Encodable, Decodable, +)] pub struct ExceptionAnnotation { /// The value of the verification hash we found in the annotation. Is `None` /// if there was no verification hash in the annotation. @@ -54,10 +72,12 @@ pub struct ExceptionAnnotation { } /// A marker annotation and its refinements. -#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Serialize, Deserialize)] +#[derive( + PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Serialize, Deserialize, Encodable, Decodable, +)] pub struct MarkerAnnotation { /// The (unchanged) name of the marker as provided by the user - pub marker: Identifier, + pub marker: InternedString, #[serde(flatten)] pub refinement: MarkerRefinement, } @@ -69,7 +89,9 @@ fn const_false() -> bool { /// Refinements in the marker targeting. The default (no refinement provided) is /// `on_argument == vec![]` and `on_return == false`, which is also what is /// returned from [`Self::empty`]. 
-#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Deserialize, Serialize)] +#[derive( + PartialEq, Eq, PartialOrd, Ord, Debug, Clone, Deserialize, Serialize, Encodable, Decodable, +)] pub struct MarkerRefinement { #[serde(default, with = "tiny_bitset_pretty")] on_argument: TinyBitSet, diff --git a/crates/paralegal-flow/src/ann/parse.rs b/crates/paralegal-flow/src/ann/parse.rs index 83c04f8d58..55532d7f5f 100644 --- a/crates/paralegal-flow/src/ann/parse.rs +++ b/crates/paralegal-flow/src/ann/parse.rs @@ -12,14 +12,13 @@ use super::{ ExceptionAnnotation, MarkerAnnotation, MarkerRefinement, MarkerRefinementKind, VerificationHash, }; use crate::{ - consts, - rust::*, - utils, + consts, utils, utils::{write_sep, Print, TinyBitSet}, Symbol, }; -use ast::{token, tokenstream}; -use paralegal_spdg::Identifier; +use rustc_ast::{token, tokenstream, AttrArgs}; +use rustc_hir::def_id::DefId; +use rustc_middle::ty::TyCtxt; use token::*; use tokenstream::*; @@ -177,6 +176,7 @@ pub fn assert_token<'a>(k: TokenKind) -> impl FnMut(I<'a>) -> R<'a, ()> { ) } +#[allow(dead_code)] /// Parse something dictionnary-like. /// /// Expects the next token to be a braces delimited subtree containing pairs of @@ -214,9 +214,9 @@ pub fn tiny_bitset(i: I) -> R { } /// Parser for the payload of the `#[paralegal_flow::output_type(...)]` annotation. -pub(crate) fn otype_ann_match(ann: &ast::AttrArgs, tcx: TyCtxt) -> Result, String> { +pub(crate) fn otype_ann_match(ann: &AttrArgs, tcx: TyCtxt) -> Result, String> { match ann { - ast::AttrArgs::Delimited(dargs) => { + AttrArgs::Delimited(dargs) => { let mut p = nom::multi::separated_list0( assert_token(TokenKind::Comma), nom::multi::separated_list0( @@ -315,7 +315,7 @@ pub(crate) fn ann_match_fn(ann: &rustc_ast::AttrArgs) -> Result for Args { .iter() .flat_map(|s| s.split(',').map(ToOwned::to_owned)) .collect(); + let build_config = get_build_config()?; if let Some(from_env) = env_var_expect_unicode("PARALEGAL_ANALYZE")? 
{ anactrl .analyze .extend(from_env.split(',').map(ToOwned::to_owned)); } - let build_config_file = std::path::Path::new("Paralegal.toml"); - let build_config = if build_config_file.exists() { - toml::from_str(&std::fs::read_to_string(build_config_file)?)? - } else { - Default::default() - }; let log_level_config = match debug_target { Some(target) if !target.is_empty() => LogLevelConfig::Targeted(target), _ => LogLevelConfig::Disabled, @@ -107,12 +102,22 @@ impl TryFrom for Args { } } +pub fn get_build_config() -> Result { + let build_config_file = std::path::Path::new("Paralegal.toml"); + Ok(if build_config_file.exists() { + toml::from_str(&std::fs::read_to_string(build_config_file)?)? + } else { + Default::default() + }) +} + #[derive(serde::Serialize, serde::Deserialize, clap::ValueEnum, Clone, Copy)] pub enum Debugger { /// The CodeLLDB debugger. Learn more at . CodeLldb, } +/// Post-processed command line and environment arguments. #[derive(serde::Serialize, serde::Deserialize)] pub struct Args { /// Print additional logging output (up to the "info" level) @@ -210,7 +215,7 @@ pub struct ClapArgs { } #[derive(Clone, clap::Args)] -pub struct ParseableDumpArgs { +struct ParseableDumpArgs { /// Generate intermediate of various formats and at various stages of /// compilation. A short description of each value is provided here, for a /// more comprehensive explanation refer to the [notion page on @@ -328,6 +333,7 @@ impl std::fmt::Display for LogLevelConfig { } impl Args { + /// Are we targeting a specific crate pub fn target(&self) -> Option<&str> { self.target.as_deref() } @@ -438,7 +444,7 @@ struct ClapAnalysisCtrl { unconstrained_depth: bool, } -#[derive(serde::Serialize, serde::Deserialize)] +#[derive(serde::Serialize, serde::Deserialize, Default)] pub struct AnalysisCtrl { /// Target this function as analysis target. Command line version of /// `#[paralegal::analyze]`). 
Must be a full rust path and resolve to a @@ -450,15 +456,6 @@ pub struct AnalysisCtrl { inlining_depth: InliningDepth, } -impl Default for AnalysisCtrl { - fn default() -> Self { - Self { - analyze: Vec::new(), - inlining_depth: InliningDepth::Adaptive, - } - } -} - impl TryFrom for AnalysisCtrl { type Error = Error; fn try_from(value: ClapAnalysisCtrl) -> Result { @@ -494,6 +491,12 @@ pub enum InliningDepth { Adaptive, } +impl Default for InliningDepth { + fn default() -> Self { + Self::Adaptive + } +} + impl AnalysisCtrl { /// Externally (via command line) selected analysis targets pub fn selected_targets(&self) -> &[String] { @@ -542,4 +545,11 @@ pub struct DepConfig { pub struct BuildConfig { /// Dependency specific configuration pub dep: crate::HashMap, + /// Overrides what is reported if this tool is called like `rustc + /// --version`. This is sometimes needed when crates attempt to detect the + /// rust version being used. + /// + /// Set this to "inherent" to use the rustc version that paralegal will be + /// using internally. + pub imitate_compiler: Option, } diff --git a/crates/paralegal-flow/src/consts.rs b/crates/paralegal-flow/src/consts.rs index f7841bebf5..4ef5da4f4d 100644 --- a/crates/paralegal-flow/src/consts.rs +++ b/crates/paralegal-flow/src/consts.rs @@ -34,3 +34,5 @@ lazy_static! { /// [`MetaItemMatch::match_extract`](crate::utils::MetaItemMatch::match_extract) pub static ref EXCEPTION_MARKER: AttrMatchT = sym_vec!["paralegal_flow", "exception"]; } + +pub const INTERMEDIATE_ARTIFACT_EXT: &str = "para"; diff --git a/crates/paralegal-flow/src/dbg.rs b/crates/paralegal-flow/src/dbg.rs deleted file mode 100644 index 0d46c5a21a..0000000000 --- a/crates/paralegal-flow/src/dbg.rs +++ /dev/null @@ -1,25 +0,0 @@ -//! Helpers for debugging -//! -//! Defines pretty printers and dot graph output. -//! -//! Often times the pretty printers wrappers around references to graph structs, -//! like [PrintableMatrix]. These wrappers have -//! 
`Debug` and/or `Display` implementations so that you can flexibly print them -//! to stdout, a file or a log statement. Some take additional information (such -//! as [TyCtxt]) to get contextual information that is used to make the output -//! more useful. -use crate::rust::mir; - -/// All locations that a body has (helper) -pub fn locations_of_body<'a: 'tcx, 'tcx>( - body: &'a mir::Body<'tcx>, -) -> impl Iterator + 'a + 'tcx { - body.basic_blocks - .iter_enumerated() - .flat_map(|(block, dat)| { - (0..=dat.statements.len()).map(move |statement_index| mir::Location { - block, - statement_index, - }) - }) -} diff --git a/crates/paralegal-flow/src/discover.rs b/crates/paralegal-flow/src/discover.rs index f59f350033..fd9644f3f3 100644 --- a/crates/paralegal-flow/src/discover.rs +++ b/crates/paralegal-flow/src/discover.rs @@ -3,21 +3,19 @@ //! //! Essentially this discovers all local `paralegal_flow::*` annotations. -use crate::{ - ana::SPDGGenerator, ann::db::MarkerDatabase, consts, desc::*, rust::*, stats::Stats, utils::*, -}; +use std::rc::Rc; + +use crate::{ana::MetadataLoader, ann::db::MarkerDatabase, consts, utils::*}; -use hir::{ - def_id::DefId, +use rustc_hir::{ + def_id::LocalDefId, intravisit::{self, FnKind}, BodyId, }; -use rustc_middle::hir::nested_filter::OnlyBodies; +use rustc_middle::{hir::nested_filter::OnlyBodies, ty::TyCtxt}; use rustc_span::{symbol::Ident, Span, Symbol}; -use anyhow::Result; - -use self::resolve::expect_resolve_string_to_def_id; +use self::resolve::resolve_string_to_def_id; /// Values of this type can be matched against Rust attributes pub type AttrMatchT = Vec; @@ -36,16 +34,16 @@ pub struct CollectingVisitor<'tcx> { /// later perform the analysis pub functions_to_analyze: Vec, - stats: Stats, - pub marker_ctx: MarkerDatabase<'tcx>, + + pub emit_target_collector: Vec, } /// A function we will be targeting to analyze with /// [`CollectingVisitor::handle_target`]. 
pub struct FnToAnalyze { pub name: Ident, - pub def_id: DefId, + pub def_id: LocalDefId, } impl FnToAnalyze { @@ -56,46 +54,43 @@ impl FnToAnalyze { } impl<'tcx> CollectingVisitor<'tcx> { - pub(crate) fn new(tcx: TyCtxt<'tcx>, opts: &'static crate::Args, stats: Stats) -> Self { + pub(crate) fn new( + tcx: TyCtxt<'tcx>, + opts: &'static crate::Args, + loader: Rc>, + ) -> Self { let functions_to_analyze = opts .anactrl() .selected_targets() .iter() .filter_map(|path| { - let def_id = expect_resolve_string_to_def_id(tcx, path, opts.relaxed())?; - if !def_id.is_local() { + let def_id = resolve_string_to_def_id(tcx, path).ok()?; + if let Some(local) = def_id.as_local() { + Some(FnToAnalyze { + def_id: local, + name: tcx.opt_item_ident(def_id).expect("analysis target does not have a name"), + }) + } else { tcx.sess.span_err(tcx.def_span(def_id), "found an external function as analysis target. Analysis targets are required to be local."); - return None; + None } - Some(FnToAnalyze { - def_id, - name: tcx.opt_item_ident(def_id).unwrap(), - }) }) .collect(); Self { tcx, opts, functions_to_analyze, - marker_ctx: MarkerDatabase::init(tcx, opts), - stats, + marker_ctx: MarkerDatabase::init(tcx, opts, loader), + emit_target_collector: vec![], } } - /// After running the discovery with `visit_all_item_likes_in_crate`, create - /// the read-only [`SPDGGenerator`] upon which the analysis will run. - fn into_generator(self) -> SPDGGenerator<'tcx> { - SPDGGenerator::new(self.marker_ctx.into(), self.opts, self.tcx, self.stats) - } - /// Driver function. Performs the data collection via visit, then calls /// [`Self::analyze`] to construct the Forge friendly description of all /// endpoints. 
- pub fn run(mut self) -> Result { + pub fn run(&mut self) { let tcx = self.tcx; - tcx.hir().visit_all_item_likes_in_crate(&mut self); - let targets = std::mem::take(&mut self.functions_to_analyze); - self.into_generator().analyze(targets) + tcx.hir().visit_all_item_likes_in_crate(self) } /// Does the function named by this id have the `paralegal_flow::analyze` annotation @@ -132,14 +127,15 @@ impl<'tcx> intravisit::Visitor<'tcx> for CollectingVisitor<'tcx> { _s: Span, id: LocalDefId, ) { + self.emit_target_collector.push(id); match &kind { - FnKind::ItemFn(name, _, _) | FnKind::Method(name, _) - if self.should_analyze_function(id) => - { - self.functions_to_analyze.push(FnToAnalyze { - name: *name, - def_id: id.to_def_id(), - }); + FnKind::ItemFn(name, _, _) | FnKind::Method(name, _) => { + if self.should_analyze_function(id) { + self.functions_to_analyze.push(FnToAnalyze { + name: *name, + def_id: id, + }); + } } _ => (), } diff --git a/crates/paralegal-flow/src/lib.rs b/crates/paralegal-flow/src/lib.rs index 595f3ee5a3..fca946cf54 100644 --- a/crates/paralegal-flow/src/lib.rs +++ b/crates/paralegal-flow/src/lib.rs @@ -23,83 +23,69 @@ extern crate petgraph; extern crate num_derive; extern crate num_traits; -pub extern crate rustc_index; +extern crate rustc_abi; +extern crate rustc_arena; +extern crate rustc_ast; +extern crate rustc_borrowck; +extern crate rustc_data_structures; +extern crate rustc_driver; +extern crate rustc_hash; +extern crate rustc_hir; +extern crate rustc_index; +extern crate rustc_interface; +extern crate rustc_macros; +extern crate rustc_middle; +extern crate rustc_mir_dataflow; +extern crate rustc_query_system; extern crate rustc_serialize; +extern crate rustc_span; +extern crate rustc_target; +extern crate rustc_type_ir; -pub mod rust { - //! Exposes the rustc external crates (this mod is just to tidy things up). 
- pub extern crate rustc_abi; - pub extern crate rustc_arena; - pub extern crate rustc_ast; - pub extern crate rustc_borrowck; - pub extern crate rustc_data_structures; - pub extern crate rustc_driver; - pub extern crate rustc_hir; - pub extern crate rustc_interface; - pub extern crate rustc_middle; - pub extern crate rustc_mir_dataflow; - pub extern crate rustc_query_system; - pub extern crate rustc_serialize; - pub extern crate rustc_span; - pub extern crate rustc_target; - pub extern crate rustc_type_ir; - pub use super::rustc_index; - pub use rustc_type_ir::sty; - - pub use rustc_ast as ast; - pub mod mir { - pub use super::rustc_abi::FieldIdx as Field; - pub use super::rustc_middle::mir::*; - } - pub use rustc_hir as hir; - pub use rustc_middle::ty; - - pub use rustc_middle::dep_graph::DepGraph; - pub use ty::TyCtxt; +extern crate either; - pub use hir::def_id::{DefId, LocalDefId}; - pub use hir::BodyId; - pub use mir::Location; -} +use std::borrow::Cow; +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; +use std::{fmt::Display, time::Instant}; -use args::{ClapArgs, Debugger, LogLevelConfig}; -use desc::{utils::write_sep, ProgramDescription}; -use rust::*; +use desc::ProgramDescription; +use rustc_hir::def_id::{DefId, LocalDefId}; +use rustc_middle::ty; +use rustc_span::Symbol; +use ty::TyCtxt; +use rustc_driver::Compilation; use rustc_plugin::CrateFilter; use rustc_utils::mir::borrowck_facts; -pub use std::collections::{HashMap, HashSet}; -use std::{fmt::Display, time::Instant}; +use anyhow::{anyhow, Context as _, Result}; +use either::Either; // This import is sort of special because it comes from the private rustc // dependencies and not from our `Cargo.toml`. 
-pub extern crate either; -pub use either::Either; - -pub use rustc_span::Symbol; -pub mod ana; -pub mod ann; +mod ana; +mod ann; mod args; -pub mod dbg; mod discover; mod stats; //mod sah; #[macro_use] -pub mod utils; -pub mod consts; +mod utils; +mod consts; #[cfg(feature = "test")] pub mod test_utils; pub use paralegal_spdg as desc; pub use crate::ann::db::MarkerCtx; -pub use args::{AnalysisCtrl, Args, BuildConfig, DepConfig, DumpArgs, MarkerControl}; -use crate::{ - stats::{Stats, TimedStat}, - utils::Print, -}; +use ana::{MetadataLoader, SPDGGenerator}; +use args::{AnalysisCtrl, Args, ClapArgs, Debugger, LogLevelConfig}; +use consts::INTERMEDIATE_ARTIFACT_EXT; +use desc::utils::write_sep; +use stats::{Stats, TimedStat}; +use utils::Print; /// A struct so we can implement [`rustc_plugin::RustcPlugin`] pub struct DfppPlugin; @@ -129,24 +115,96 @@ struct ArgWrapper { struct Callbacks { opts: &'static Args, stats: Stats, - start: Instant, + persist_metadata: bool, } -struct NoopCallbacks {} - -impl rustc_driver::Callbacks for NoopCallbacks {} +/// Create the name of the file in which to store intermediate artifacts. +/// +/// HACK(Justus): `TyCtxt::output_filenames` returns a file stem of +/// `lib-`, whereas `OutputFiles::with_extension` returns a file +/// stem of `-`. I haven't found a clean way to get the same +/// name in both places, so i just assume that these two will always have this +/// relation and prepend the `"lib"` here. 
+fn intermediate_out_file_path(tcx: TyCtxt) -> Result { + let rustc_out_file = tcx + .output_filenames(()) + .with_extension(INTERMEDIATE_ARTIFACT_EXT); + let dir = rustc_out_file + .parent() + .ok_or_else(|| anyhow!("{} has no parent", rustc_out_file.display()))?; + let file = rustc_out_file + .file_name() + .ok_or_else(|| anyhow!("has no file name")) + .and_then(|s| s.to_str().ok_or_else(|| anyhow!("not utf8"))) + .with_context(|| format!("{}", rustc_out_file.display()))?; + + let file = if file.starts_with("lib") { + Cow::Borrowed(file) + } else { + format!("lib{file}").into() + }; + + Ok(dir.join(file.as_ref())) +} impl Callbacks { - pub fn run(&self, tcx: TyCtxt) -> anyhow::Result { + fn in_context(&mut self, tcx: TyCtxt) -> Result { + let compilation = if let Some(desc) = self.run_compilation(tcx)? { + if self.opts.dbg().dump_spdg() { + let out = std::fs::File::create("call-only-flow.gv").unwrap(); + paralegal_spdg::dot::dump(&desc, out).unwrap(); + } + + let ser = Instant::now(); + desc.canonical_write(self.opts.result_path()).unwrap(); + println!("Wrote graph to {}", self.opts.result_path().display()); + self.stats + .record_timed(TimedStat::Serialization, ser.elapsed()); + + println!("Analysis finished with timing: {}", self.stats); + if self.opts.abort_after_analysis() { + rustc_driver::Compilation::Stop + } else { + rustc_driver::Compilation::Continue + } + } else { + println!("No compilation artifact"); + rustc_driver::Compilation::Continue + }; + Ok(compilation) + } + + fn run_compilation(&self, tcx: TyCtxt) -> Result> { + tcx.sess.abort_if_errors(); + + let loader = MetadataLoader::new(tcx); + + let (analysis_targets, mctx) = loader.clone().collect_and_emit_metadata( + self.opts, + self.persist_metadata + .then(|| intermediate_out_file_path(tcx)) + .transpose()?, + ); tcx.sess.abort_if_errors(); - discover::CollectingVisitor::new(tcx, self.opts, self.stats.clone()).run() + + let mut gen = SPDGGenerator::new(mctx, self.opts, tcx, loader); + + 
(!analysis_targets.is_empty()) + .then(|| gen.analyze(analysis_targets)) + .transpose() } +} +struct NoopCallbacks {} + +impl rustc_driver::Callbacks for NoopCallbacks {} + +impl Callbacks { pub fn new(opts: &'static Args) -> Self { Self { opts, stats: Default::default(), - start: Instant::now(), + persist_metadata: true, } } } @@ -166,34 +224,10 @@ impl rustc_driver::Callbacks for Callbacks { _compiler: &rustc_interface::interface::Compiler, queries: &'tcx rustc_interface::Queries<'tcx>, ) -> rustc_driver::Compilation { - self.stats - .record_timed(TimedStat::Rustc, self.start.elapsed()); queries .global_ctxt() .unwrap() - .enter(|tcx| { - let desc = self.run(tcx)?; - info!("All elems walked"); - tcx.sess.abort_if_errors(); - - if self.opts.dbg().dump_spdg() { - let out = std::fs::File::create("call-only-flow.gv").unwrap(); - paralegal_spdg::dot::dump(&desc, out).unwrap(); - } - - let ser = Instant::now(); - desc.canonical_write(self.opts.result_path()).unwrap(); - self.stats - .record_timed(TimedStat::Serialization, ser.elapsed()); - - println!("Analysis finished with timing: {}", self.stats); - - anyhow::Ok(if self.opts.abort_after_analysis() { - rustc_driver::Compilation::Stop - } else { - rustc_driver::Compilation::Continue - }) - }) + .enter(|tcx| self.in_context(tcx)) .unwrap() } } @@ -221,6 +255,13 @@ fn add_to_rustflags(new: impl IntoIterator) -> Result<(), std::en Ok(()) } +pub const PARALEGAL_RUSTC_FLAGS: [&str; 4] = [ + "--cfg", + "paralegal", + "-Zcrate-attr=feature(register_tool)", + "-Zcrate-attr=register_tool(paralegal_flow)", +]; + impl rustc_plugin::RustcPlugin for DfppPlugin { type Args = Args; @@ -232,6 +273,10 @@ impl rustc_plugin::RustcPlugin for DfppPlugin { "paralegal-flow".into() } + fn reported_driver_version(&self) -> Cow<'static, str> { + env!("RUSTC_VERSION").into() + } + fn args( &self, _target_dir: &rustc_plugin::Utf8Path, @@ -346,14 +391,15 @@ impl rustc_plugin::RustcPlugin for DfppPlugin { .as_ref() .map_or(false, |n| n == 
"build_script_build"); + debug!("Handling {}", crate_name.unwrap_or("".to_owned())); + if !is_target || is_build_script { + debug!("Is not target, skipping"); return rustc_driver::RunCompiler::new(&compiler_args, &mut NoopCallbacks {}).run(); } plugin_args.setup_logging(); - let opts = Box::leak(Box::new(plugin_args)); - const RERUN_VAR: &str = "RERUN_WITH_PROFILER"; if let Ok(debugger) = std::env::var(RERUN_VAR) { info!("Restarting with debugger '{debugger}'"); @@ -365,12 +411,9 @@ impl rustc_plugin::RustcPlugin for DfppPlugin { std::process::exit(cmd.status().unwrap().code().unwrap_or(0)); } - compiler_args.extend([ - "--cfg".into(), - "paralegal".into(), - "-Zcrate-attr=feature(register_tool)".into(), - "-Zcrate-attr=register_tool(paralegal_flow)".into(), - ]); + let opts = Box::leak(Box::new(plugin_args)); + + compiler_args.extend(PARALEGAL_RUSTC_FLAGS.iter().copied().map(ToOwned::to_owned)); if let Some(dbg) = opts.attach_to_debugger() { dbg.attach() diff --git a/crates/paralegal-flow/src/stats.rs b/crates/paralegal-flow/src/stats.rs index d88816bc74..4eca2560eb 100644 --- a/crates/paralegal-flow/src/stats.rs +++ b/crates/paralegal-flow/src/stats.rs @@ -45,6 +45,7 @@ impl Stats { self.inner_mut().record_timed(stat, duration) } + #[allow(dead_code)] pub fn get_timed(&self, stat: TimedStat) -> Duration { self.0.lock().unwrap().timed[stat].unwrap_or(Duration::ZERO) } diff --git a/crates/paralegal-flow/src/test_utils.rs b/crates/paralegal-flow/src/test_utils.rs index 9fa1d4e6b1..8a97fc24b5 100644 --- a/crates/paralegal-flow/src/test_utils.rs +++ b/crates/paralegal-flow/src/test_utils.rs @@ -6,7 +6,8 @@ extern crate rustc_span; use crate::{ desc::{Identifier, ProgramDescription}, - HashSet, + utils::Print, + Callbacks, HashSet, PARALEGAL_RUSTC_FLAGS, }; use std::fmt::{Debug, Formatter}; use std::hash::{Hash, Hasher}; @@ -15,16 +16,17 @@ use std::process::Command; use paralegal_spdg::{ rustc_portable::DefId, traverse::{generic_flows_to, EdgeSelection}, - DefInfo, 
EdgeInfo, Node, SPDG, + DefInfo, EdgeInfo, Endpoint, Node, SPDG, }; -use flowistry_pdg::rustc_portable::LocalDefId; +use clap::Parser; use flowistry_pdg::CallString; use itertools::Itertools; use petgraph::visit::{Control, Data, DfsEvent, EdgeRef, FilterEdge, GraphBase, IntoEdges}; use petgraph::visit::{IntoNeighbors, IntoNodeReferences}; use petgraph::visit::{NodeRef as _, Visitable}; use petgraph::Direction; +use rustc_utils::test_utils::CompileResult; use std::path::Path; lazy_static! { @@ -229,20 +231,11 @@ impl InlineTestBuilder { args.setup_logging(); rustc_utils::test_utils::CompileBuilder::new(&self.input) - .with_args( - [ - "--cfg", - "paralegal", - "-Zcrate-attr=feature(register_tool)", - "-Zcrate-attr=register_tool(paralegal_flow)", - ] - .into_iter() - .map(ToOwned::to_owned), - ) - .compile(move |result| { - let tcx = result.tcx; - let memo = crate::Callbacks::new(Box::leak(Box::new(args))); - let pdg = memo.run(tcx).unwrap(); + .with_args(PARALEGAL_RUSTC_FLAGS.iter().copied().map(ToOwned::to_owned)) + .compile(move |CompileResult { tcx }| { + let mut memo = Callbacks::new(Box::leak(Box::new(args))); + memo.persist_metadata = false; + let pdg = memo.run_compilation(tcx).unwrap().unwrap(); let graph = PreFrg::from_description(pdg); let cref = graph.ctrl(&self.ctrl_name); check(cref) @@ -269,7 +262,10 @@ pub trait HasGraph<'g>: Sized + Copy { let name = Identifier::new_intern(name.as_ref()); let id = match self.graph().name_map.get(&name).map(Vec::as_slice) { Some([one]) => *one, - Some([]) | None => panic!("Did not find name {name}"), + Some([]) | None => panic!( + "Did not find name {name}. 
Known names:\n{:?}", + self.graph().name_map.keys().collect::>() + ), _ => panic!("Found too many function matching name {name}"), }; FnRef { @@ -295,7 +291,7 @@ pub trait HasGraph<'g>: Sized + Copy { } } - fn ctrl_hashed(self, name: &str) -> LocalDefId { + fn ctrl_hashed(self, name: &str) -> DefId { let candidates = self .graph() .desc @@ -364,7 +360,7 @@ impl PreFrg { #[derive(Clone)] pub struct CtrlRef<'g> { graph: &'g PreFrg, - id: LocalDefId, + id: Endpoint, ctrl: &'g SPDG, } @@ -412,7 +408,7 @@ impl<'g> CtrlRef<'g> { } } - pub fn id(&self) -> LocalDefId { + pub fn id(&self) -> Endpoint { self.id } pub fn spdg(&self) -> &'g SPDG { @@ -429,7 +425,20 @@ impl<'g> CtrlRef<'g> { .map(|v| v.at) .chain(self.ctrl.graph.node_weights().map(|info| info.at)) .filter(|m| { - instruction_info[&m.leaf()] + instruction_info + .get(m) + .unwrap_or_else(|| { + panic!( + "Could not find instruction {} in\n{}", + m.leaf(), + Print(|fmt| { + for (k, v) in instruction_info.iter() { + writeln!(fmt, " {k}: {v:?}")?; + } + Ok(()) + }) + ) + }) .kind .as_function_call() .map_or(false, |i| i.id == fun.ident) diff --git a/crates/paralegal-flow/src/utils/mod.rs b/crates/paralegal-flow/src/utils/mod.rs index 2509353063..9088750433 100644 --- a/crates/paralegal-flow/src/utils/mod.rs +++ b/crates/paralegal-flow/src/utils/mod.rs @@ -1,33 +1,30 @@ //! 
Utility functions, general purpose structs and extension traits extern crate smallvec; +use rustc_target::spec::abi::Abi; use thiserror::Error; -use crate::{ - desc::Identifier, - rust::{ - ast, - hir::{ - self, - def::Res, - def_id::{DefId, LocalDefId}, - hir_id::HirId, - BodyId, - }, - mir::{self, Location, Place, ProjectionElem}, - rustc_borrowck::consumers::BodyWithBorrowckFacts, - rustc_data_structures::intern::Interned, - rustc_span::Span as RustSpan, - rustc_span::{symbol::Ident, Span}, - rustc_target::spec::abi::Abi, - ty, - }, - rustc_span::ErrorGuaranteed, - Either, Symbol, TyCtxt, -}; -pub use flowistry_pdg_construction::{is_non_default_trait_method, FnResolution}; +use crate::{desc::Identifier, rustc_span::ErrorGuaranteed, Either, Symbol, TyCtxt}; + +pub use flowistry_pdg_construction::utils::is_non_default_trait_method; pub use paralegal_spdg::{ShortHash, TinyBitSet}; +use rustc_ast as ast; +use rustc_borrowck::consumers::BodyWithBorrowckFacts; +use rustc_data_structures::intern::Interned; +use rustc_hir::{ + self as hir, + def::Res, + def_id::{DefId, LocalDefId}, + hir_id::HirId, + BodyId, +}; +use rustc_middle::{ + mir::{self, Location, Place, ProjectionElem}, + ty::{self, Instance}, +}; +use rustc_span::{symbol::Ident, Span as RustSpan}; + use std::{cmp::Ordering, hash::Hash}; mod print; @@ -206,7 +203,7 @@ impl<'tcx> DfppBodyExt<'tcx> for mir::Body<'tcx> { } } -pub trait FnResolutionExt<'tcx> { +pub trait InstanceExt<'tcx> { /// Get the most precise type signature we can for this function, erase any /// regions and discharge binders. 
/// @@ -217,15 +214,15 @@ pub trait FnResolutionExt<'tcx> { fn sig(self, tcx: TyCtxt<'tcx>) -> Result, ErrorGuaranteed>; } -impl<'tcx> FnResolutionExt<'tcx> for FnResolution<'tcx> { +impl<'tcx> InstanceExt<'tcx> for Instance<'tcx> { fn sig(self, tcx: TyCtxt<'tcx>) -> Result, ErrorGuaranteed> { let sess = tcx.sess; let def_id = self.def_id(); let def_span = tcx.def_span(def_id); let fn_kind = FunctionKind::for_def_id(tcx, def_id)?; - let late_bound_sig = match (self, fn_kind) { - (FnResolution::Final(sub), FunctionKind::Generator) => { - let gen = sub.args.as_generator(); + let late_bound_sig = match fn_kind { + FunctionKind::Generator => { + let gen = self.args.as_generator(); ty::Binder::dummy(ty::FnSig { inputs_and_output: tcx.mk_type_list(&[gen.resume_ty(), gen.return_ty()]), c_variadic: false, @@ -233,41 +230,8 @@ impl<'tcx> FnResolutionExt<'tcx> for FnResolution<'tcx> { abi: Abi::Rust, }) } - (FnResolution::Final(sub), FunctionKind::Closure) => sub.args.as_closure().sig(), - (FnResolution::Final(sub), FunctionKind::Plain) => { - sub.ty(tcx, ty::ParamEnv::reveal_all()).fn_sig(tcx) - } - (FnResolution::Partial(_), FunctionKind::Closure) => { - if let Some(local) = def_id.as_local() { - sess.span_warn( - def_span, - "Precise variable instantiation for \ - closure not known, using user type annotation.", - ); - let sig = tcx.closure_user_provided_sig(local); - Ok(sig.value) - } else { - Err(sess.span_err( - def_span, - format!( - "Could not determine type signature for external closure {def_id:?}" - ), - )) - }? 
- } - (FnResolution::Partial(_), FunctionKind::Generator) => Err(sess.span_err( - def_span, - format!( - "Cannot determine signature of generator {def_id:?} without monomorphization" - ), - ))?, - (FnResolution::Partial(_), FunctionKind::Plain) => { - let sig = tcx.fn_sig(def_id); - sig.no_bound_vars().unwrap_or_else(|| { - sess.span_warn(def_span, format!("Cannot discharge bound variables for {sig:?}, they will not be considered by the analysis")); - sig.skip_binder() - }) - } + FunctionKind::Closure => self.args.as_closure().sig(), + FunctionKind::Plain => self.ty(tcx, ty::ParamEnv::reveal_all()).fn_sig(tcx), }; Ok(tcx .try_normalize_erasing_late_bound_regions(ty::ParamEnv::reveal_all(), late_bound_sig) @@ -346,14 +310,7 @@ pub trait AsFnAndArgs<'tcx> { fn as_instance_and_args( &self, tcx: TyCtxt<'tcx>, - ) -> Result< - ( - FnResolution<'tcx>, - SimplifiedArguments<'tcx>, - mir::Place<'tcx>, - ), - AsFnAndArgsErr<'tcx>, - >; + ) -> Result<(Instance<'tcx>, SimplifiedArguments<'tcx>, mir::Place<'tcx>), AsFnAndArgsErr<'tcx>>; } #[derive(Debug, Error)] @@ -362,8 +319,6 @@ pub enum AsFnAndArgsErr<'tcx> { NotAConstant, #[error("is not a function type: {0:?}")] NotFunctionType(ty::TyKind<'tcx>), - #[error("is not a `Val` constant: {0}")] - NotValueLevelConstant(ty::Const<'tcx>), #[error("terminator is not a `Call`")] NotAFunctionCall, #[error("function instance could not be resolved")] @@ -374,14 +329,8 @@ impl<'tcx> AsFnAndArgs<'tcx> for mir::Terminator<'tcx> { fn as_instance_and_args( &self, tcx: TyCtxt<'tcx>, - ) -> Result< - ( - FnResolution<'tcx>, - SimplifiedArguments<'tcx>, - mir::Place<'tcx>, - ), - AsFnAndArgsErr<'tcx>, - > { + ) -> Result<(Instance<'tcx>, SimplifiedArguments<'tcx>, mir::Place<'tcx>), AsFnAndArgsErr<'tcx>> + { let mir::TerminatorKind::Call { func, args, @@ -408,12 +357,12 @@ impl<'tcx> AsFnAndArgs<'tcx> for mir::Terminator<'tcx> { using partial resolution." 
), ); - FnResolution::Partial(*defid) + return Err(AsFnAndArgsErr::InstanceResolutionErr); } Ok(_) => ty::Instance::resolve(tcx, ty::ParamEnv::reveal_all(), *defid, gargs) .map_err(|_| AsFnAndArgsErr::InstanceResolutionErr)? - .map_or(FnResolution::Partial(*defid), FnResolution::Final), - }; + .ok_or(AsFnAndArgsErr::InstanceResolutionErr), + }?; Ok(( instance, args.iter().map(|a| a.place()).collect(), @@ -462,6 +411,7 @@ pub enum Overlap<'tcx> { } impl<'tcx> Overlap<'tcx> { + #[allow(dead_code)] pub fn contains_other(self) -> bool { matches!(self, Overlap::Equal | Overlap::Parent(_)) } @@ -680,10 +630,6 @@ pub enum BodyResolutionError { #[error("not a function-like object")] /// The provided id did not refer to a function-like object. NotAFunction, - #[error("body not available")] - /// The provided id refers to an external entity and we have no access to - /// its body - External, /// The function refers to a trait item (not an `impl` item or raw `fn`) #[error("is associated function of trait {0:?}")] IsTraitAssocFn(DefId), @@ -745,7 +691,6 @@ impl<'tcx> TyCtxtExt<'tcx> for TyCtxt<'tcx> { Err(e) => { let sess = self.sess; match e { - BodyResolutionError::External => (), BodyResolutionError::IsTraitAssocFn(r#trait) => { sess.struct_span_warn( self.def_span(local_def_id.to_def_id()), @@ -786,6 +731,7 @@ pub fn with_temporary_logging_level R>(filter: log::LevelFilte r } +#[allow(dead_code)] pub fn time R>(msg: &str, f: F) -> R { info!("Starting {msg}"); let time = std::time::Instant::now(); @@ -822,43 +768,7 @@ impl IntoBodyId for DefId { } } -pub trait Spanned<'tcx> { - fn span(&self, tcx: TyCtxt<'tcx>) -> Span; -} - -impl<'tcx> Spanned<'tcx> for mir::Terminator<'tcx> { - fn span(&self, _tcx: TyCtxt<'tcx>) -> Span { - self.source_info.span - } -} - -impl<'tcx> Spanned<'tcx> for mir::Statement<'tcx> { - fn span(&self, _tcx: TyCtxt<'tcx>) -> Span { - self.source_info.span - } -} - -impl<'tcx> Spanned<'tcx> for (&mir::Body<'tcx>, mir::Location) { - fn span(&self, tcx: 
TyCtxt<'tcx>) -> Span { - self.0 - .stmt_at(self.1) - .either(|e| e.span(tcx), |e| e.span(tcx)) - } -} - -impl<'tcx> Spanned<'tcx> for DefId { - fn span(&self, tcx: TyCtxt<'tcx>) -> Span { - tcx.def_span(*self) - } -} - -impl<'tcx> Spanned<'tcx> for (LocalDefId, mir::Location) { - fn span(&self, tcx: TyCtxt<'tcx>) -> Span { - let body = tcx.body_for_def_id(self.0).unwrap(); - (&body.body, self.1).span(tcx) - } -} - +#[allow(dead_code)] pub fn map_either( either: Either, f: impl FnOnce(A) -> C, diff --git a/crates/paralegal-flow/src/utils/resolve.rs b/crates/paralegal-flow/src/utils/resolve.rs index 122e4d078e..a6ba09907f 100644 --- a/crates/paralegal-flow/src/utils/resolve.rs +++ b/crates/paralegal-flow/src/utils/resolve.rs @@ -1,12 +1,16 @@ -use crate::{ast, hir, ty, DefId, Symbol, TyCtxt}; use ast::Mutability; -use hir::{ +use rustc_ast as ast; +use rustc_hir::def_id::DefId; +use rustc_hir::{ def::{self, DefKind}, def_id::CrateNum, def_id::LocalDefId, def_id::LOCAL_CRATE, ImplItemRef, ItemKind, Node, PrimTy, TraitItemRef, }; +use rustc_middle::ty::{self, TyCtxt}; +use rustc_span::Symbol; +use thiserror::Error; use ty::{fast_reject::SimplifiedType, FloatTy, IntTy, UintTy}; #[derive(Debug, Clone, Copy)] @@ -15,18 +19,22 @@ pub enum Res { PrimTy(PrimTy), } -#[derive(Clone, Debug)] -pub enum ResolutionError<'a> { +#[derive(Clone, Debug, Error)] +pub enum ResolutionError { + #[error("cannot resolve primitive type {}", .0)] CannotResolvePrimitiveType(Symbol), + #[error("path is empty")] PathIsEmpty, + #[error("could not find child {segment} in {item:?} (which is a {search_space:?})")] CouldNotFindChild { item: DefId, - segment: &'a str, + segment: String, search_space: SearchSpace, }, + #[error("empty start segments")] EmptyStarts, + #[error("non-convertible resolution {:?}", .0)] UnconvertibleRes(def::Res), - CouldNotResolveCrate(&'a str), } #[derive(Clone, Debug)] @@ -36,7 +44,7 @@ pub enum SearchSpace { } impl Res { - fn from_def_res<'a>(res: def::Res) -> Result> { 
+ fn from_def_res(res: def::Res) -> Result { match res { def::Res::Def(k, i) => Ok(Res::Def(k, i)), def::Res::PrimTy(t) => Ok(Res::PrimTy(t)), @@ -113,13 +121,25 @@ pub fn expect_resolve_string_to_def_id(tcx: TyCtxt, path: &str, relaxed: bool) - } } +pub fn resolve_string_to_def_id(tcx: TyCtxt, path: &str) -> anyhow::Result { + let segment_vec = path.split("::").collect::>(); + + let res = def_path_res(tcx, &segment_vec)?; + match res { + Res::Def(_, did) => Ok(did), + other => { + anyhow::bail!("expected {path} to resolve to an item, got {other:?}") + } + } +} + /// Lifted from `clippy_utils` -pub fn def_path_res<'a>(tcx: TyCtxt, path: &[&'a str]) -> Result> { - fn item_child_by_name<'a>( +pub fn def_path_res(tcx: TyCtxt, path: &[&str]) -> Result { + fn item_child_by_name( tcx: TyCtxt<'_>, def_id: DefId, name: &str, - ) -> Option>> { + ) -> Option> { if let Some(local_id) = def_id.as_local() { local_item_children_by_name(tcx, local_id, name) } else { @@ -127,11 +147,11 @@ pub fn def_path_res<'a>(tcx: TyCtxt, path: &[&'a str]) -> Result( + fn non_local_item_children_by_name( tcx: TyCtxt<'_>, def_id: DefId, name: &str, - ) -> Option>> { + ) -> Option> { match tcx.def_kind(def_id) { DefKind::Mod | DefKind::Enum | DefKind::Trait => tcx .module_children(def_id) @@ -148,11 +168,11 @@ pub fn def_path_res<'a>(tcx: TyCtxt, path: &[&'a str]) -> Result( + fn local_item_children_by_name( tcx: TyCtxt<'_>, local_id: LocalDefId, name: &str, - ) -> Option>> { + ) -> Option> { let hir = tcx.hir(); let root_mod; @@ -232,13 +252,13 @@ pub fn def_path_res<'a>(tcx: TyCtxt, path: &[&'a str]) -> Result { assert!(!output.is_empty()); assert!(input.flows_to_data(&output)); }); + +#[test] +fn await_on_generic() { + InlineTestBuilder::new(stringify!( + use std::{ + future::{Future}, + task::{Context, Poll}, + pin::Pin + }; + struct AFuture; + + impl Future for AFuture { + type Output = usize; + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + unimplemented!() + } + } + + 
trait Trait { + fn method(&mut self) -> AFuture; + } + + async fn main(mut t: T) -> usize { + t.method().await + } + )) + .check(|_ctrl| {}) +} + +#[test] +fn await_with_inner_generic() { + InlineTestBuilder::new(stringify!( + use std::{ + future::{Future}, + task::{Context, Poll}, + pin::Pin, + }; + struct AFuture<'a, T: ?Sized>(&'a mut T); + + impl<'a, T> Future for AFuture<'a, T> { + type Output = usize; + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + unimplemented!() + } + } + + trait Trait { + fn method(&mut self) -> AFuture<'_, Self> { + AFuture(self) + } + } + + async fn main(mut t: T) -> usize { + t.method().await + } + )) + .check(|_ctrl| {}) +} + +#[test] +#[ignore = "https://github.com/brownsys/paralegal/issues/159"] +fn await_with_inner_generic_constrained() { + InlineTestBuilder::new(stringify!( + use std::{ + future::{Future}, + task::{Context, Poll}, + pin::Pin, + }; + struct AFuture<'a, T: ?Sized>(&'a mut T); + + impl<'a, T: Trait + Unpin + ?Sized> Future for AFuture<'a, T> { + type Output = usize; + fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { + unimplemented!() + } + } + + trait Trait: Send + Unpin + 'static { + fn method(&mut self) -> AFuture<'_, Self> + where + Self: Unpin + Sized, + { + AFuture(self) + } + } + + async fn main(mut t: T) -> usize { + t.method().await + } + )) + .check(|_ctrl| {}) +} + +#[test] +fn async_through_another_layer() { + InlineTestBuilder::new(stringify!( + async fn maker(x: u32, y: u32) -> u32 { + x + } + + fn get_async(x: u32, y: u32) -> impl std::future::Future { + maker(y, x) + } + + #[paralegal_flow::marker(source, return)] + fn mark_source(t: T) -> T { + t + } + + #[paralegal_flow::marker(source_2, return)] + fn mark_source_2(t: T) -> T { + t + } + + #[paralegal_flow::marker(sink, arguments = [0])] + fn sink(t: T) {} + + async fn main() { + let src = mark_source(1); + let src2 = mark_source_2(2); + sink(get_async(src, src2).await) + } + )) + .check(|ctrl| { + assert!(!ctrl + 
.marked(Identifier::new_intern("source")) + .flows_to_any(&ctrl.marked(Identifier::new_intern("sink")))); + assert!(ctrl + .marked(Identifier::new_intern("source_2")) + .flows_to_any(&ctrl.marked(Identifier::new_intern("sink")))); + }) +} diff --git a/crates/paralegal-flow/tests/call_chain_analysis_tests.rs b/crates/paralegal-flow/tests/call_chain_analysis_tests.rs index 91f38fb74b..70f915f161 100644 --- a/crates/paralegal-flow/tests/call_chain_analysis_tests.rs +++ b/crates/paralegal-flow/tests/call_chain_analysis_tests.rs @@ -28,16 +28,33 @@ define_test!(without_return: ctrl -> { assert!(src.output().flows_to_data(&dest)); }); -define_test!(with_return: ctrl -> { - let src_fn = ctrl.function("source"); - let src = ctrl.call_site(&src_fn); - let ctrl = ctrl.ctrl("with_return"); - let dest_fn = ctrl.function("receiver"); - let dest_sink = ctrl.call_site(&dest_fn); - let dest = dest_sink.input().nth(0).unwrap(); +#[test] +fn with_return() { + InlineTestBuilder::new(stringify!( + #[paralegal_flow::marker(hello, return)] + fn source() -> i32 { + 0 + } + fn callee(x: i32) -> i32 { + source() + } + #[paralegal_flow::marker(there, arguments = [0])] + fn receiver(x: i32) {} - assert!(src.output().flows_to_data(&dest)); -}); + fn main(x: i32) { + receiver(callee(x)); + } + )) + .check(|ctrl| { + let src_fn = ctrl.function("source"); + let src = ctrl.call_site(&src_fn); + let dest_fn = ctrl.function("receiver"); + let dest_sink = ctrl.call_site(&dest_fn); + let dest = dest_sink.input().nth(0).unwrap(); + + assert!(src.output().flows_to_data(&dest)); + }) +} define_test!(on_mut_var: ctrl -> { let src_fn = ctrl.function("source"); diff --git a/crates/paralegal-flow/tests/clone-test.rs b/crates/paralegal-flow/tests/clone-test.rs new file mode 100644 index 0000000000..faf0764b89 --- /dev/null +++ b/crates/paralegal-flow/tests/clone-test.rs @@ -0,0 +1,78 @@ +use paralegal_flow::test_utils::InlineTestBuilder; + +#[test] +fn clone_nesting() { + InlineTestBuilder::new(stringify!( + 
#[derive(Clone)] + enum Opt { + Empty, + Filled(T), + } + + #[derive(Clone)] + struct AStruct { + f: usize, + g: usize, + } + + #[derive(Clone)] + enum AnEnum { + Var1(usize), + Var2(String), + } + + fn main() { + let v0 = Opt::Filled(AStruct { f: 0, g: 0 }).clone(); + let v2 = Opt::Filled(AnEnum::Var1(0)).clone(); + } + )) + .check(|_ctr| {}) +} + +#[test] +fn clone_test_2() { + InlineTestBuilder::new(stringify!( + #[derive(Clone)] + pub(crate) enum IdOrNestedObject { + Id(Url), + NestedObject(Kind), + } + + #[derive(Clone)] + struct Url(String); + + #[derive(Clone)] + pub struct Vote { + pub(crate) to: Vec, + } + + #[derive(Clone)] + struct VoteUrl(String); + + #[derive(Clone)] + struct TombstoneUrl(String); + + #[derive(Clone)] + pub struct AnnounceActivity { + pub(crate) object: IdOrNestedObject, + } + #[derive(Clone)] + pub struct Tombstone { + pub(crate) id: TombstoneUrl, + } + + #[derive(Clone)] + pub struct Delete { + pub(crate) object: IdOrNestedObject, + } + + #[derive(Clone)] + pub enum AnnouncableActivities { + Vote(Vote), + Delete(Delete), + } + + fn main() {} + )) + .check(|_g| {}) +} diff --git a/crates/paralegal-flow/tests/cross-crate.rs b/crates/paralegal-flow/tests/cross-crate.rs new file mode 100644 index 0000000000..2c60cd3dde --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate.rs @@ -0,0 +1,71 @@ +#![feature(rustc_private)] +#[macro_use] +extern crate lazy_static; + +use paralegal_flow::test_utils::*; +use paralegal_spdg::Identifier; + +const CRATE_DIR: &str = "tests/cross-crate"; + +lazy_static! { + static ref TEST_CRATE_ANALYZED: bool = run_paralegal_flow_with_flow_graph_dump(CRATE_DIR); +} + +macro_rules! 
define_test { + ($name:ident: $ctrl:ident -> $block:block) => { + define_test!($name: $ctrl, $name -> $block); + }; + ($name:ident: $ctrl:ident, $ctrl_name:ident -> $block:block) => { + paralegal_flow::define_flow_test_template!(TEST_CRATE_ANALYZED, CRATE_DIR, $name: $ctrl, $ctrl_name -> $block); + }; +} + +define_test!(basic : graph -> { + let src_fn = graph.function("src"); + let src = graph.call_site(&src_fn); + let not_src_fn = graph.function("not_src"); + let not_src = graph.call_site(¬_src_fn); + let target_fn = graph.function("target"); + let target = graph.call_site(&target_fn); + assert!(src.output().flows_to_data(&target.input())); + assert!(!not_src.output().flows_to_data(&target.input())); +}); + +define_test!(basic_marker: graph -> { + + let marker = Identifier::new_intern("mark"); + assert!(dbg!(&graph.spdg().markers).iter().any(|(_, markers)| markers.contains(&marker))) +}); + +define_test!(assigns_marker: graph -> { + let sources = graph.marked(Identifier::new_intern("source")); + let mark = graph.marked(Identifier::new_intern("mark")); + let target = graph.marked(Identifier::new_intern("target")); + assert!(!sources.is_empty()); + assert!(!mark.is_empty()); + assert!(!target.is_empty()); + assert!(sources.flows_to_data(&mark)); + assert!(mark.flows_to_data(&target)); +}); + +define_test!(basic_generic : graph -> { + let src_fn = graph.function("src"); + let src = graph.call_site(&src_fn); + let not_src_fn = graph.function("not_src"); + let not_src = graph.call_site(¬_src_fn); + let target_fn = graph.function("target"); + let target = graph.call_site(&target_fn); + assert!(src.output().flows_to_data(&target.input())); + assert!(!not_src.output().flows_to_data(&target.input())); +}); + +define_test!(assigns_marker_generic: graph -> { + let sources = graph.marked(Identifier::new_intern("source")); + let mark = graph.marked(Identifier::new_intern("mark")); + let target = graph.marked(Identifier::new_intern("target")); + assert!(!sources.is_empty()); + 
assert!(!mark.is_empty()); + assert!(!target.is_empty()); + assert!(sources.flows_to_data(&mark)); + assert!(mark.flows_to_data(&target)); +}); diff --git a/crates/paralegal-flow/tests/cross-crate/Cargo.lock b/crates/paralegal-flow/tests/cross-crate/Cargo.lock new file mode 100644 index 0000000000..6478b31dbc --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/Cargo.lock @@ -0,0 +1,57 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "dependency" +version = "0.0.1" +dependencies = [ + "paralegal", +] + +[[package]] +name = "entry" +version = "0.0.1" +dependencies = [ + "dependency", + "paralegal", +] + +[[package]] +name = "paralegal" +version = "0.1.0" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro2" +version = "1.0.82" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ad3d49ab951a01fbaafe34f2ec74122942fe18a3f9814c3268f1bb72042131b" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" diff --git a/crates/paralegal-flow/tests/cross-crate/Cargo.toml b/crates/paralegal-flow/tests/cross-crate/Cargo.toml new file mode 100644 index 0000000000..6e5676e925 --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/Cargo.toml @@ -0,0 +1,2 @@ +[workspace] +members = ["dependency", "entry"] diff 
--git a/crates/paralegal-flow/tests/cross-crate/dependency/Cargo.toml b/crates/paralegal-flow/tests/cross-crate/dependency/Cargo.toml new file mode 100644 index 0000000000..9cd274b7d0 --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/dependency/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "dependency" +version = "0.0.1" + +[dependencies] +paralegal = { path = "../../../../paralegal" } diff --git a/crates/paralegal-flow/tests/cross-crate/dependency/src/lib.rs b/crates/paralegal-flow/tests/cross-crate/dependency/src/lib.rs new file mode 100644 index 0000000000..194dd768be --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/dependency/src/lib.rs @@ -0,0 +1,30 @@ +pub fn find_me(a: usize, _b: usize) -> usize { + a +} + +#[paralegal::marker(mark, return)] +pub fn source() -> usize { + 0 +} + +#[paralegal::marker(mark, return)] +fn taint_it(_: A) -> A { + unimplemented!() +} + +pub fn assign_marker(a: usize) -> usize { + taint_it(a) +} + +pub fn find_me_generic(a: A, _b: A) -> A { + a +} + +#[paralegal::marker(mark, return)] +pub fn generic_source() -> A { + unimplemented!() +} + +pub fn assign_marker_generic(a: A) -> A { + taint_it(a) +} diff --git a/crates/paralegal-flow/tests/cross-crate/entry/Cargo.toml b/crates/paralegal-flow/tests/cross-crate/entry/Cargo.toml new file mode 100644 index 0000000000..52bdc7a83f --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/entry/Cargo.toml @@ -0,0 +1,7 @@ +[package] +name = "entry" +version = "0.0.1" + +[dependencies] +dependency = { path = "../dependency" } +paralegal = { path = "../../../../paralegal" } diff --git a/crates/paralegal-flow/tests/cross-crate/entry/src/main.rs b/crates/paralegal-flow/tests/cross-crate/entry/src/main.rs new file mode 100644 index 0000000000..66f47a24c4 --- /dev/null +++ b/crates/paralegal-flow/tests/cross-crate/entry/src/main.rs @@ -0,0 +1,43 @@ +extern crate dependency; + +use dependency::{assign_marker, assign_marker_generic, find_me, find_me_generic, source}; + 
+#[paralegal::marker(source, return)] +fn src() -> usize { + 0 +} + +#[paralegal::marker(not_source)] +fn not_src() -> usize { + 1 +} + +#[paralegal::marker(target, arguments = [0])] +fn target(u: usize) {} + +#[paralegal::analyze] +fn basic() { + target(find_me(src(), not_src())) +} + +#[paralegal::analyze] +fn basic_marker() { + target(source()); +} + +#[paralegal::analyze] +fn assigns_marker() { + target(assign_marker(src())); +} + +#[paralegal::analyze] +fn basic_generic() { + target(find_me_generic(src(), not_src())) +} + +#[paralegal::analyze] +fn assigns_marker_generic() { + target(assign_marker_generic(src())); +} + +fn main() {} diff --git a/crates/paralegal-flow/tests/marker_tests.rs b/crates/paralegal-flow/tests/marker_tests.rs index 8ae77b08f9..83f6480e2c 100644 --- a/crates/paralegal-flow/tests/marker_tests.rs +++ b/crates/paralegal-flow/tests/marker_tests.rs @@ -26,7 +26,7 @@ define_test!(use_wrapper: ctrl -> { let cs = ctrl.call_site(&uwf); println!("{:?}", &ctrl.graph().desc.type_info); let tp = cs.output().as_singles().any(|n| - dbg!(ctrl.types_for(n.node())).iter().any(|t| + dbg!(ctrl.types_for(dbg!(n.node()))).iter().any(|t| ctrl.graph().desc.type_info[t].rendering.contains("::Wrapper") ) ); @@ -41,7 +41,7 @@ define_test!(trait_method_marker: ctrl -> { .iter() .any(|(node, markers)| { let weight = spdg.graph.node_weight(*node).unwrap(); - !matches!(ctrl.graph().desc.instruction_info[&weight.at.leaf()].kind, + !matches!(ctrl.graph().desc.instruction_info[&weight.at].kind, InstructionKind::FunctionCall(fun) if fun.id == method.ident) || markers.contains(&marker) })); diff --git a/crates/paralegal-policy/src/algo/ahb.rs b/crates/paralegal-policy/src/algo/ahb.rs index b17ea76729..9ca762bf4e 100644 --- a/crates/paralegal-policy/src/algo/ahb.rs +++ b/crates/paralegal-policy/src/algo/ahb.rs @@ -5,7 +5,7 @@ use std::{collections::HashSet, sync::Arc}; pub use paralegal_spdg::rustc_portable::{DefId, LocalDefId}; -use paralegal_spdg::{GlobalNode, 
Identifier, Node, SPDGImpl}; +use paralegal_spdg::{Endpoint, GlobalNode, Identifier, Node, SPDGImpl}; use anyhow::{ensure, Result}; use itertools::Itertools; @@ -143,7 +143,7 @@ impl crate::Context { .map(|i| (i.controller_id(), i.local_node())) .into_group_map(); - let mut trace = Trace::new(self.config.always_happens_before_tracing); + let mut trace = Trace::new(self.config.lock().unwrap().always_happens_before_tracing); let select_data = |e: <&SPDGImpl as IntoEdgeReferences>::EdgeRef| e.weight().is_data(); @@ -196,7 +196,7 @@ pub enum TraceLevel { struct Tracer<'a> { tree: Box<[Node]>, trace: &'a mut Trace, - ctrl_id: LocalDefId, + ctrl_id: Endpoint, } enum Trace { @@ -279,7 +279,7 @@ impl<'a> Tracer<'a> { trace: &'a mut Trace, node_bound: usize, initials: impl IntoIterator, - ctrl_id: LocalDefId, + ctrl_id: Endpoint, ) -> Self { Self { tree: if matches!(trace, Trace::None(_)) { diff --git a/crates/paralegal-policy/src/context.rs b/crates/paralegal-policy/src/context.rs index 24006ab727..a82357e858 100644 --- a/crates/paralegal-policy/src/context.rs +++ b/crates/paralegal-policy/src/context.rs @@ -1,11 +1,11 @@ use std::collections::BTreeMap; use std::fs::File; -use std::io::{BufRead, BufReader}; +use std::io::BufReader; +use std::sync::Mutex; use std::time::{Duration, Instant}; use std::vec; use std::{io::Write, process::exit, sync::Arc}; -use paralegal_spdg::rustc_portable::defid_as_local; pub use paralegal_spdg::rustc_portable::{DefId, LocalDefId}; use paralegal_spdg::traverse::{generic_flows_to, EdgeSelection}; use paralegal_spdg::{ @@ -30,8 +30,6 @@ use crate::{assert_warning, diagnostics::DiagnosticsRecorder}; /// User-defined PDG markers. pub type Marker = Identifier; -/// The type identifying a controller -pub type ControllerId = LocalDefId; /// The type identifying a function that is used in call sites. 
pub type FunctionId = DefId; @@ -39,7 +37,7 @@ pub type FunctionId = DefId; pub type MarkableId = GlobalNode; type MarkerIndex = HashMap; -type FlowsTo = HashMap; +type FlowsTo = HashMap; /// Collection of entities a particular marker has been applied to #[derive(Clone, Debug, Default)] @@ -68,7 +66,7 @@ fn bfs_iter< G: IntoNeighbors + GraphRef + Visitable::Map>, >( g: G, - controller_id: LocalDefId, + controller_id: Endpoint, start: impl IntoIterator, ) -> impl Iterator { let mut discovered = g.visit_map(); @@ -102,14 +100,13 @@ fn bfs_iter< /// [`Self::emit_diagnostics`]. If you used /// [`super::GraphLocation::with_context`] this will be done automatically for /// you. -#[derive(Debug)] pub struct Context { marker_to_ids: MarkerIndex, desc: ProgramDescription, flows_to: Option, pub(crate) diagnostics: DiagnosticsRecorder, name_map: HashMap>, - pub(crate) config: Arc, + pub(crate) config: Arc>, pub(crate) stats: ContextStats, } @@ -145,7 +142,7 @@ impl Context { flows_to, diagnostics: Default::default(), name_map, - config: Arc::new(config), + config: Arc::new(Mutex::new(config)), stats: ContextStats { pdg_construction: None, precomputation: start.elapsed(), @@ -261,8 +258,9 @@ impl Context { } /// Dispatch and drain all queued diagnostics without aborting the program. - pub fn emit_diagnostics(&self, w: impl Write) -> std::io::Result { - self.diagnostics.emit(w) + pub fn emit_diagnostics(&self) -> std::io::Result { + self.diagnostics + .emit(&mut self.config.lock().unwrap().output_writer) } /// Returns all nodes that are in any of the PDGs @@ -367,7 +365,7 @@ impl Context { /// /// If the controller with this id does not exist *or* the controller has /// fewer than `index` arguments. 
- pub fn controller_argument(&self, ctrl_id: ControllerId, index: u32) -> Option { + pub fn controller_argument(&self, ctrl_id: Endpoint, index: u32) -> Option { let ctrl = self.desc.controllers.get(&ctrl_id)?; let inner = *ctrl.arguments.get(index as usize)?; @@ -430,10 +428,7 @@ impl Context { } /// Returns all DataSources, DataSinks, and CallSites for a Controller as Nodes. - pub fn all_nodes_for_ctrl( - &self, - ctrl_id: ControllerId, - ) -> impl Iterator + '_ { + pub fn all_nodes_for_ctrl(&self, ctrl_id: Endpoint) -> impl Iterator + '_ { let ctrl = &self.desc.controllers[&ctrl_id]; ctrl.graph .node_indices() @@ -443,7 +438,7 @@ impl Context { /// Returns an iterator over the data sources within controller `c` that have type `t`. pub fn srcs_with_type( &self, - ctrl_id: ControllerId, + ctrl_id: Endpoint, t: DefId, ) -> impl Iterator + '_ { self.desc.controllers[&ctrl_id] @@ -459,7 +454,7 @@ impl Context { /// Returns an iterator over all nodes that do not have any influencers of the given edge_type. 
pub fn roots( &self, - ctrl_id: ControllerId, + ctrl_id: Endpoint, edge_type: EdgeSelection, ) -> impl Iterator + '_ { let g = &self.desc.controllers[&ctrl_id].graph; @@ -478,9 +473,7 @@ impl Context { continue; } let w = g.node_weight(n).unwrap(); - if self.desc.instruction_info[&w.at.leaf()] - .kind - .is_function_call() + if self.desc.instruction_info[&w.at].kind.is_function_call() || w.at.leaf().location.is_start() { roots.push(GlobalNode::from_local_node(ctrl_id, n)); @@ -530,7 +523,7 @@ impl Context { } /// Iterate over all defined controllers - pub fn all_controllers(&self) -> impl Iterator { + pub fn all_controllers(&self) -> impl Iterator { self.desc().controllers.iter().map(|(k, v)| (*k, v)) } @@ -622,6 +615,8 @@ impl Context { mut out: impl Write, include_signatures: bool, ) -> std::io::Result<()> { + use std::io::BufRead; + let ordered_span_set = self .desc .analyzed_spans @@ -633,11 +628,7 @@ impl Context { self.desc .def_info .iter() - .filter(|(did, _)| { - !matches!(defid_as_local(**did), Some(local) - if self.desc.analyzed_spans.contains_key(&local) - ) - }) + .filter(|(did, _)| self.desc.analyzed_spans.contains_key(did)) .map(|(_, i)| (&i.src_info, matches!(i.kind, DefKind::Type))) }) .into_iter() @@ -879,7 +870,7 @@ impl NodeExt for GlobalNode { } fn instruction(self, ctx: &Context) -> &InstructionInfo { - &ctx.desc.instruction_info[&self.info(ctx).at.leaf()] + &ctx.desc.instruction_info[&self.info(ctx).at] } fn successors(self, ctx: &Context) -> Box + '_> { diff --git a/crates/paralegal-policy/src/diagnostics.rs b/crates/paralegal-policy/src/diagnostics.rs index e49b0ef965..3ddf132953 100644 --- a/crates/paralegal-policy/src/diagnostics.rs +++ b/crates/paralegal-policy/src/diagnostics.rs @@ -95,9 +95,9 @@ use indexmap::IndexMap; use std::rc::Rc; use std::{io::Write, sync::Arc}; -use paralegal_spdg::{GlobalNode, Identifier, Span, SpanCoord, SPDG}; +use paralegal_spdg::{Endpoint, GlobalNode, Identifier, Span, SpanCoord, SPDG}; -use 
crate::{Context, ControllerId, NodeExt}; +use crate::{Context, NodeExt}; /// Check the condition and emit a [`Diagnostics::error`] if it fails. #[macro_export] @@ -785,7 +785,7 @@ impl PolicyContext { /// diagnostic context management. pub fn named_controller( self: Arc, - id: ControllerId, + id: Endpoint, policy: impl FnOnce(Arc) -> A, ) -> A { policy(Arc::new(ControllerContext { @@ -820,7 +820,7 @@ impl HasDiagnosticsBase for PolicyContext { /// See the [module level documentation][self] for more information on /// diagnostic context management. pub struct ControllerContext { - id: ControllerId, + id: Endpoint, inner: Arc, } @@ -863,7 +863,7 @@ impl ControllerContext { } /// Access the id for the controller of this context - pub fn id(&self) -> ControllerId { + pub fn id(&self) -> Endpoint { self.id } @@ -974,7 +974,7 @@ impl Context { /// diagnostic context management. pub fn named_controller( self: Arc, - id: ControllerId, + id: Endpoint, policy: impl FnOnce(Arc) -> A, ) -> A { policy(Arc::new(ControllerContext { diff --git a/crates/paralegal-policy/src/lib.rs b/crates/paralegal-policy/src/lib.rs index be8729c0a8..c064dc2bc9 100644 --- a/crates/paralegal-policy/src/lib.rs +++ b/crates/paralegal-policy/src/lib.rs @@ -49,8 +49,6 @@ #![warn(missing_docs)] -extern crate core; - use anyhow::{ensure, Result}; pub use paralegal_spdg; use paralegal_spdg::utils::TruncatedHumanTime; @@ -245,7 +243,7 @@ impl GraphLocation { let start = Instant::now(); let result = prop(ctx.clone())?; - let success = ctx.emit_diagnostics(std::io::stdout())?; + let success = ctx.emit_diagnostics()?; Ok(PolicyReturn { success, result, @@ -276,13 +274,14 @@ impl GraphLocation { } /// Configuration for the framework -#[derive(Clone, Debug)] pub struct Config { /// How much information to retain for error messages in `always_happens_before` pub always_happens_before_tracing: algo::ahb::TraceLevel, /// Whether tho precompute an index for `flows_to` queries with /// `EdgeSelection::Data` or 
whether to use a new DFS every time. pub use_flows_to_index: bool, + /// Where to write output to + pub output_writer: Box, } impl Default for Config { @@ -290,6 +289,7 @@ impl Default for Config { Config { always_happens_before_tracing: algo::ahb::TraceLevel::StartAndEnd, use_flows_to_index: false, + output_writer: Box::new(std::io::stdout()), } } } diff --git a/crates/paralegal-policy/src/test_utils.rs b/crates/paralegal-policy/src/test_utils.rs index df07e9d7dd..5797d037aa 100644 --- a/crates/paralegal-policy/src/test_utils.rs +++ b/crates/paralegal-policy/src/test_utils.rs @@ -1,6 +1,6 @@ use crate::Context; -use crate::ControllerId; use paralegal_flow::test_utils::PreFrg; +use paralegal_spdg::Endpoint; use paralegal_spdg::IntoIterGlobalNodes; use paralegal_spdg::NodeCluster; use paralegal_spdg::{Identifier, InstructionKind, Node as SPDGNode, SPDG}; @@ -21,7 +21,7 @@ pub fn test_ctx() -> Arc { pub fn get_callsite_or_datasink_node<'a>( ctx: &'a Context, - controller: ControllerId, + controller: Endpoint, name: &'a str, ) -> NodeCluster { get_callsite_node(ctx, controller, name) @@ -29,11 +29,7 @@ pub fn get_callsite_or_datasink_node<'a>( .unwrap() } -pub fn get_callsite_node<'a>( - ctx: &'a Context, - controller: ControllerId, - name: &'a str, -) -> NodeCluster { +pub fn get_callsite_node<'a>(ctx: &'a Context, controller: Endpoint, name: &'a str) -> NodeCluster { let name = Identifier::new_intern(name); let ctrl = &ctx.desc().controllers[&controller]; let inner = ctrl @@ -49,7 +45,7 @@ fn is_at_function_call_with_name( node: SPDGNode, ) -> bool { let weight = ctrl.graph.node_weight(node).unwrap().at; - let instruction = &ctx.desc().instruction_info[&weight.leaf()]; + let instruction = &ctx.desc().instruction_info[&weight]; matches!( instruction.kind, InstructionKind::FunctionCall(call) if @@ -57,7 +53,7 @@ fn is_at_function_call_with_name( ) } -pub fn get_sink_node<'a>(ctx: &'a Context, controller: ControllerId, name: &'a str) -> NodeCluster { +pub fn 
get_sink_node<'a>(ctx: &'a Context, controller: Endpoint, name: &'a str) -> NodeCluster { let name = Identifier::new_intern(name); let ctrl = &ctx.desc().controllers[&controller]; let inner = ctrl diff --git a/crates/paralegal-policy/tests/entrypoint-generics.rs b/crates/paralegal-policy/tests/entrypoint-generics.rs index 3949b07e13..f61a442691 100644 --- a/crates/paralegal-policy/tests/entrypoint-generics.rs +++ b/crates/paralegal-policy/tests/entrypoint-generics.rs @@ -52,6 +52,43 @@ fn simple_parent() -> Result<()> { #[test] fn default_method() -> Result<()> { + let test = Test::new(stringify!( + #[paralegal::marker(source, return)] + fn actual_source() -> usize { + 0 + } + + trait Src { + fn source(&self) -> usize { + actual_source() + } + } + + #[paralegal::marker(sink, arguments = [0])] + fn actual_sink(t: T) {} + + trait Snk { + fn sink(&self, t: usize) { + actual_sink(t) + } + } + + struct Wrap(T); + + impl Wrap { + #[paralegal::analyze] + fn main(&self, s: &S) { + s.sink(self.0.source()) + } + } + ))?; + + test.run(simple_policy) +} + +#[test] +#[ignore = "Default methods with generics don't resolve properly. 
See https://github.com/brownsys/paralegal/issues/152"] +fn default_method_with_generic() -> Result<()> { let test = Test::new(stringify!( #[paralegal::marker(source, return)] fn actual_source() -> usize { diff --git a/crates/paralegal-policy/tests/freedit.rs b/crates/paralegal-policy/tests/freedit.rs index 26e2b79dfb..1674ecfdff 100644 --- a/crates/paralegal-policy/tests/freedit.rs +++ b/crates/paralegal-policy/tests/freedit.rs @@ -132,7 +132,7 @@ fn simple_monomorphization() -> Result<()> { #[paralegal::analyze] fn unconnected() { - Receiver.target(Donator.source()) + Donator.target(Receiver.source()) } ))?; test.run(|ctx| { @@ -146,7 +146,7 @@ fn simple_monomorphization() -> Result<()> { .filter(|n| n.controller_id() == ctx.id()) .collect(); - let expect_connect = ctx.current().name.as_str() != "connected"; + let expect_connect = ctx.current().name.as_str() == "connected"; assert_error!( ctx, @@ -263,3 +263,74 @@ fn markers_on_generic_calls() -> Result<()> { Ok(()) }) } + +#[test] +fn finding_utc_now() -> Result<()> { + let mut test = Test::new(stringify!( + use sled::Db; + use chrono::Utc; + use thiserror::Error; + + #[derive(Error, Debug)] + pub enum AppError { + #[error("Sled db error: {}", .0)] + SledError(#[from] sled::Error), + #[error(transparent)] + Utf8Error(#[from] std::str::Utf8Error), + } + + pub async fn clear_invalid(db: &Db, tree_name: &str) -> Result<(), AppError> { + // let tree = db.open_tree(tree_name)?; + // for i in tree.iter() { + // let (k, _) = i?; + // let k_str = std::str::from_utf8(&k)?; + // let time_stamp = k_str + // .split_once('_') + // .and_then(|s| i64::from_str_radix(s.0, 16).ok()); + let time_stamp = Some(0_i64); + if let Some(time_stamp) = time_stamp { + if time_stamp < Utc::now().timestamp() { + panic!() + //tree.remove(k)?; + } + } + //} + Ok(()) + } + + #[paralegal::analyze] + pub async fn user_chron_job() -> ! 
{ + let db = sled::Config::default().open().unwrap(); + loop { + //sleep_seconds(600).await; + clear_invalid(&db, "dummy").await.unwrap() + //sleep_seconds(3600 * 4).await; + } + } + ))?; + test.with_external_annotations( + " + [[\"chrono::Utc::now\"]] + marker = \"time\" + on_return = true + ", + ) + .with_dep([ + "chrono@0.4.38", + "--no-default-features", + "--features", + "clock", + ]) + .with_dep(["sled@0.34.7"]) + .with_dep(["thiserror@1"]); + test.run(|ctx| { + assert_error!( + ctx, + ctx.marked_nodes(Identifier::new_intern("time")) + .next() + .is_some(), + "No time found" + ); + Ok(()) + }) +} diff --git a/crates/paralegal-policy/tests/helpers/mod.rs b/crates/paralegal-policy/tests/helpers/mod.rs index 181eb34616..00dad610c3 100644 --- a/crates/paralegal-policy/tests/helpers/mod.rs +++ b/crates/paralegal-policy/tests/helpers/mod.rs @@ -8,7 +8,7 @@ use std::{ path::{Path, PathBuf}, process::Command, sync::Arc, - time::SystemTime, + time::{SystemTime, UNIX_EPOCH}, }; use anyhow::anyhow; @@ -34,14 +34,16 @@ lazy_static::lazy_static! { fn temporary_directory(to_hash: &impl Hash) -> Result { let tmpdir = env::temp_dir(); - let secs = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?; let mut hasher = DefaultHasher::new(); - secs.hash(&mut hasher); to_hash.hash(&mut hasher); + let t = SystemTime::now().duration_since(UNIX_EPOCH)?; + t.hash(&mut hasher); let hash = hasher.finish(); let short_hash = hash % 0x1_000_000; let path = tmpdir.join(format!("test-crate-{short_hash:06x}")); - fs::create_dir(&path)?; + if !path.exists() { + fs::create_dir(&path)?; + } Ok(path) } @@ -70,6 +72,16 @@ impl Test { pub fn new(code: impl Into) -> Result { let code = code.into(); let tempdir = temporary_directory(&code)?; + for entry in fs::read_dir(&tempdir)? 
{ + let f = entry?; + let typ = f.file_type()?; + if typ.is_dir() { + fs::remove_dir_all(f.path())?; + } else if typ.is_file() { + fs::remove_file(f.path())?; + } + } + println!("Running in {}", tempdir.display()); Ok(Self { code, external_ann_file_name: tempdir.join("external_annotations.toml"), diff --git a/crates/paralegal-policy/tests/lemmy.rs b/crates/paralegal-policy/tests/lemmy.rs index a9ca283fc1..2e99cf7cff 100644 --- a/crates/paralegal-policy/tests/lemmy.rs +++ b/crates/paralegal-policy/tests/lemmy.rs @@ -4,7 +4,7 @@ use std::{collections::hash_map::RandomState, sync::Arc}; use helpers::{Result, Test}; use paralegal_policy::{ - assert_error, assert_warning, Context, Diagnostics, EdgeSelection, NodeExt, + assert_error, assert_warning, Context, Diagnostics, EdgeSelection, NodeExt, NodeQueries, }; use paralegal_spdg::{GlobalNode, Identifier}; @@ -128,6 +128,58 @@ fn support_calling_async_trait_0_1_53() -> Result<()> { test.run(calling_async_trait_policy) } +fn call_async_trait_single_inline_with_version(v: &str) -> Result<()> { + let mut test = Test::new(stringify!( + #[paralegal::marker(marked, return)] + fn apply_marker(i: T) -> T { + i + } + + struct Ctx; + #[async_trait::async_trait(?Send)] + trait Trait { + async fn transform(&self, i: usize) -> usize; + } + + #[async_trait::async_trait(?Send)] + impl Trait for Ctx { + async fn transform(&self, i: usize) -> usize { + apply_marker(i) + } + } + + #[paralegal::analyze] + async fn main() { + assert_eq!(Ctx.transform(0).await, 0); + } + ))?; + test.with_dep([v]); + test.run(|ctx| { + let marked = ctx + .marked_nodes(Identifier::new_intern("marked")) + .collect::>(); + assert!(!marked.is_empty()); + for src in marked.iter() { + for sink in marked.iter() { + assert!(src == sink || !src.flows_to(*sink, &ctx, EdgeSelection::Data)); + } + } + Ok(()) + }) +} + +#[test] +#[ignore = "No support yet for calling `async_trait` functions, as that requires (a form of) `dyn` handling"] +fn 
call_async_trait_single_inline_0_1_53() -> Result<()> { + call_async_trait_single_inline_with_version("async-trait@=0.1.53") +} + +#[test] +#[ignore = "No support yet for calling `async_trait` functions, as that requires (a form of) `dyn` handling"] +fn call_async_trait_single_inline_latest() -> Result<()> { + call_async_trait_single_inline_with_version("async-trait") +} + #[test] fn transitive_control_flow() -> Result<()> { let test = Test::new(stringify!( diff --git a/crates/paralegal-policy/tests/misc_async.rs b/crates/paralegal-policy/tests/misc_async.rs index ede82774ea..7cc7a27f62 100644 --- a/crates/paralegal-policy/tests/misc_async.rs +++ b/crates/paralegal-policy/tests/misc_async.rs @@ -55,3 +55,20 @@ on_argument = [1] Ok(()) }) } + +#[test] +#[ignore = "https://github.com/brownsys/paralegal/issues/159"] +fn oneshot_channel() -> Result<()> { + let mut test = Test::new(stringify!( + #[paralegal::analyze] + async fn main() { + let (_, receiver) = tokio::sync::oneshot::channel(); + + receiver.await.unwrap() + } + ))?; + + test.with_dep(["tokio", "--features", "sync"]); + + test.run(|_ctx| Ok(())) +} diff --git a/crates/paralegal-spdg/src/dot.rs b/crates/paralegal-spdg/src/dot.rs index a6d14ca504..d93d0f1770 100644 --- a/crates/paralegal-spdg/src/dot.rs +++ b/crates/paralegal-spdg/src/dot.rs @@ -1,8 +1,7 @@ //! 
Display SPDGs as dot graphs -use crate::{GlobalEdge, InstructionKind, Node, ProgramDescription}; +use crate::{Endpoint, GlobalEdge, InstructionKind, Node, ProgramDescription}; use dot::{CompassPoint, Edges, Id, LabelText, Nodes}; -use flowistry_pdg::rustc_portable::LocalDefId; use flowistry_pdg::{CallString, RichLocation}; use petgraph::prelude::EdgeRef; @@ -11,14 +10,14 @@ use std::collections::HashMap; struct DotPrintableProgramDescription<'d> { spdg: &'d ProgramDescription, - call_sites: HashMap)>, - selected_controllers: Vec, + call_sites: HashMap)>, + selected_controllers: Vec, } impl<'d> DotPrintableProgramDescription<'d> { pub fn new_for_selection( spdg: &'d ProgramDescription, - mut selector: impl FnMut(LocalDefId) -> bool, + mut selector: impl FnMut(Endpoint) -> bool, ) -> Self { let selected_controllers: Vec<_> = spdg .controllers @@ -108,7 +107,7 @@ impl<'a, 'd> dot::Labeller<'a, CallString, GlobalEdge> for DotPrintableProgramDe fn node_label(&'a self, n: &CallString) -> LabelText<'a> { let (ctrl_id, nodes) = &self.call_sites[n]; let ctrl = &self.spdg.controllers[ctrl_id]; - let instruction = &self.spdg.instruction_info[&n.leaf()]; + let instruction = &self.spdg.instruction_info[&n]; let write_label = || { use std::fmt::Write; @@ -127,6 +126,7 @@ impl<'a, 'd> dot::Labeller<'a, CallString, GlobalEdge> for DotPrintableProgramDe s.push('*'); } InstructionKind::Return => s.push_str("end"), + InstructionKind::SwitchInt => s.push('C'), }; for &n in nodes { @@ -197,7 +197,7 @@ pub fn dump(spdg: &ProgramDescription, out: W) -> std::io::Re pub fn dump_for_controller( spdg: &ProgramDescription, out: impl std::io::Write, - controller_id: LocalDefId, + controller_id: Endpoint, ) -> std::io::Result<()> { let mut found = false; dump_for_selection(spdg, out, |l| { @@ -218,7 +218,7 @@ pub fn dump_for_controller( pub fn dump_for_selection( spdg: &ProgramDescription, mut out: impl std::io::Write, - selector: impl FnMut(LocalDefId) -> bool, + selector: impl FnMut(Endpoint) 
-> bool, ) -> std::io::Result<()> { let printable = DotPrintableProgramDescription::new_for_selection(spdg, selector); dot::render(&printable, &mut out) diff --git a/crates/paralegal-spdg/src/lib.rs b/crates/paralegal-spdg/src/lib.rs index 25954c91db..9fa21eb6c1 100644 --- a/crates/paralegal-spdg/src/lib.rs +++ b/crates/paralegal-spdg/src/lib.rs @@ -19,6 +19,11 @@ pub(crate) mod rustc { pub use middle::mir; } +#[cfg(feature = "rustc")] +extern crate rustc_macros; +#[cfg(feature = "rustc")] +extern crate rustc_serialize; + extern crate strum; pub use flowistry_pdg::*; @@ -31,6 +36,8 @@ pub mod utils; use internment::Intern; use itertools::Itertools; +#[cfg(feature = "rustc")] +use rustc_macros::{Decodable, Encodable}; use rustc_portable::DefId; use serde::{Deserialize, Serialize}; use std::time::Duration; @@ -49,7 +56,7 @@ pub use std::collections::{HashMap, HashSet}; use std::fmt::{Display, Formatter}; /// The types of identifiers that identify an entrypoint -pub type Endpoint = LocalDefId; +pub type Endpoint = DefId; /// Identifiers for types pub type TypeId = DefId; /// Identifiers for functions @@ -264,9 +271,10 @@ impl Span { /// Metadata on a function call. #[derive(Debug, Clone, Copy, Serialize, Deserialize, Eq, Ord, PartialOrd, PartialEq)] +#[cfg_attr(feature = "rustc", derive(Encodable, Decodable))] pub struct FunctionCallInfo { - /// Has this call been inlined - pub is_inlined: bool, + // /// Has this call been inlined + // pub is_inlined: bool, /// What is the ID of the item that was called here. 
#[cfg_attr(feature = "rustc", serde(with = "rustc_proxies::DefId"))] pub id: DefId, @@ -281,8 +289,10 @@ pub enum InstructionKind { Statement, /// A function call FunctionCall(FunctionCallInfo), - /// A basic block terminator, usually switchInt + /// Some other terminator Terminator, + /// A switch int terminator + SwitchInt, /// The beginning of a function Start, /// The merged exit points of a function @@ -320,7 +330,7 @@ pub type ControllerMap = HashMap; #[derive(Serialize, Deserialize, Debug)] pub struct ProgramDescription { /// Entry points we analyzed and their PDGs - #[cfg_attr(feature = "rustc", serde(with = "ser_localdefid_map"))] + #[cfg_attr(feature = "rustc", serde(with = "ser_defid_map"))] #[cfg_attr(not(feature = "rustc"), serde(with = "serde_map_via_vec"))] pub controllers: ControllerMap, @@ -332,31 +342,29 @@ pub struct ProgramDescription { /// Metadata about the instructions that are executed at all program /// locations we know about. #[serde(with = "serde_map_via_vec")] - pub instruction_info: HashMap, + pub instruction_info: HashMap, #[cfg_attr(not(feature = "rustc"), serde(with = "serde_map_via_vec"))] #[cfg_attr(feature = "rustc", serde(with = "ser_defid_map"))] /// Metadata about the `DefId`s pub def_info: HashMap, - /// How many marker annotations were found - pub marker_annotation_count: u32, - /// How long rustc ran before out plugin executed + + /// INFO: Not implemented, always 0 pub rustc_time: Duration, - /// The number of functions we produced a PDG for + /// INFO: Not implemented, always 0 + pub marker_annotation_count: u32, + /// INFO: Not implemented, always 0 pub dedup_functions: u32, - /// The lines of code corresponding to the functions from - /// [`Self::dedup_functions`]. + /// INFO: Not implemented, always 0 pub dedup_locs: u32, - /// The number of functions we produced PDGs for or we inspected to check - /// for markers. 
- pub seen_functions: u32, - /// The lines of code corresponding to the functions from - /// [`Self::seen_functions`]. This is the sum of all - /// `analyzed_locs` of the controllers but deduplicated. + /// INFO: Not implemented, always 0 pub seen_locs: u32, - #[doc(hidden)] - #[serde(with = "ser_localdefid_map")] - pub analyzed_spans: HashMap, + /// INFO: Not implemented, always 0 + pub seen_functions: u32, + #[cfg_attr(not(feature = "rustc"), serde(with = "serde_map_via_vec"))] + #[cfg_attr(feature = "rustc", serde(with = "ser_defid_map"))] + /// INFO: Not implemented, always empty + pub analyzed_spans: HashMap, } /// Metadata about a type @@ -499,7 +507,7 @@ pub fn hash_pls(t: T) -> u64 { /// Return type of [`IntoIterGlobalNodes::iter_global_nodes`]. pub struct GlobalNodeIter { - controller_id: LocalDefId, + controller_id: DefId, iter: I::Iter, } @@ -526,7 +534,7 @@ pub trait IntoIterGlobalNodes: Sized + Copy { fn iter_nodes(self) -> Self::Iter; /// The controller id all of these nodes are located in. - fn controller_id(self) -> LocalDefId; + fn controller_id(self) -> DefId; /// Iterate all nodes as globally identified one's. /// @@ -565,13 +573,13 @@ pub type Node = NodeIndex; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct GlobalNode { node: Node, - controller_id: LocalDefId, + controller_id: DefId, } impl GlobalNode { /// Create a new node with no guarantee that it exists in the SPDG of the /// controller. - pub fn unsafe_new(ctrl_id: LocalDefId, index: usize) -> Self { + pub fn unsafe_new(ctrl_id: DefId, index: usize) -> Self { GlobalNode { controller_id: ctrl_id, node: crate::Node::new(index), @@ -582,7 +590,7 @@ impl GlobalNode { /// particular SPDG with it's controller id. /// /// Meant for internal use only. 
- pub fn from_local_node(ctrl_id: LocalDefId, node: Node) -> Self { + pub fn from_local_node(ctrl_id: DefId, node: Node) -> Self { GlobalNode { controller_id: ctrl_id, node, @@ -595,7 +603,7 @@ impl GlobalNode { } /// The identifier for the SPDG this node is contained in - pub fn controller_id(self) -> LocalDefId { + pub fn controller_id(self) -> DefId { self.controller_id } } @@ -606,7 +614,7 @@ impl IntoIterGlobalNodes for GlobalNode { std::iter::once(self.local_node()) } - fn controller_id(self) -> LocalDefId { + fn controller_id(self) -> DefId { self.controller_id } } @@ -615,7 +623,7 @@ impl IntoIterGlobalNodes for GlobalNode { pub mod node_cluster { use std::ops::Range; - use flowistry_pdg::rustc_portable::LocalDefId; + use flowistry_pdg::rustc_portable::DefId; use crate::{GlobalNode, IntoIterGlobalNodes, Node}; @@ -626,7 +634,7 @@ pub mod node_cluster { /// individual [`GlobalNode`]s #[derive(Debug, Hash, Clone)] pub struct NodeCluster { - controller_id: LocalDefId, + controller_id: DefId, nodes: Box<[Node]>, } @@ -665,7 +673,7 @@ pub mod node_cluster { self.iter() } - fn controller_id(self) -> LocalDefId { + fn controller_id(self) -> DefId { self.controller_id } } @@ -683,7 +691,7 @@ pub mod node_cluster { impl NodeCluster { /// Create a new cluster. This for internal use. 
- pub fn new(controller_id: LocalDefId, nodes: impl IntoIterator) -> Self { + pub fn new(controller_id: DefId, nodes: impl IntoIterator) -> Self { Self { controller_id, nodes: nodes.into_iter().collect::>().into(), @@ -698,7 +706,7 @@ pub mod node_cluster { } /// Controller that these nodes belong to - pub fn controller_id(&self) -> LocalDefId { + pub fn controller_id(&self) -> DefId { self.controller_id } @@ -731,12 +739,12 @@ pub use node_cluster::NodeCluster; #[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] pub struct GlobalEdge { index: EdgeIndex, - controller_id: LocalDefId, + controller_id: Endpoint, } impl GlobalEdge { /// The id of the controller that this edge is located in - pub fn controller_id(self) -> LocalDefId { + pub fn controller_id(self) -> Endpoint { self.controller_id } } @@ -812,8 +820,8 @@ pub struct SPDG { /// The module path to this controller function pub path: Box<[Identifier]>, /// The id - #[cfg_attr(feature = "rustc", serde(with = "rustc_proxies::LocalDefId"))] - pub id: LocalDefId, + #[cfg_attr(feature = "rustc", serde(with = "rustc_proxies::DefId"))] + pub id: DefId, /// The PDG pub graph: SPDGImpl, /// Nodes to which markers are assigned. @@ -826,11 +834,11 @@ pub struct SPDG { /// that this contains multiple types for a single node, because it hold /// top-level types and subtypes that may be marked. pub type_assigns: HashMap, - /// Statistics + /// INFO: Not Implemented, always zero pub statistics: SPDGStats, } -#[derive(Clone, Serialize, Deserialize, Debug)] +#[derive(Clone, Serialize, Deserialize, Debug, Default)] /// Statistics about the code that produced an SPDG pub struct SPDGStats { /// The number of unique lines of code we generated a PDG for. 
This means diff --git a/crates/paralegal-spdg/src/ser.rs b/crates/paralegal-spdg/src/ser.rs index 59ff56be98..a1e1e96e59 100644 --- a/crates/paralegal-spdg/src/ser.rs +++ b/crates/paralegal-spdg/src/ser.rs @@ -47,7 +47,8 @@ impl ProgramDescription { /// Read `self` using the configured serialization format pub fn canonical_read(path: impl AsRef) -> Result { let path = path.as_ref(); - let in_file = File::open(path)?; + let in_file = File::open(path) + .with_context(|| format!("Reading PDG file from {}", path.display()))?; cfg_if! { if #[cfg(feature = "binenc")] { let read = bincode::deserialize_from( diff --git a/crates/paralegal-spdg/src/tiny_bitset.rs b/crates/paralegal-spdg/src/tiny_bitset.rs index 8fd3b2eca2..3b1343e403 100644 --- a/crates/paralegal-spdg/src/tiny_bitset.rs +++ b/crates/paralegal-spdg/src/tiny_bitset.rs @@ -1,10 +1,14 @@ use crate::utils::display_list; use std::fmt::{Display, Formatter}; +#[cfg(feature = "rustc")] +use rustc_macros::{Decodable, Encodable}; + /// A bit-set implemented with a [`u16`], supporting domains up to 16 elements. #[derive( Clone, Eq, PartialEq, PartialOrd, Ord, Hash, Copy, serde::Serialize, serde::Deserialize, )] +#[cfg_attr(feature = "rustc", derive(Encodable, Decodable))] pub struct TinyBitSet(u16); impl Default for TinyBitSet {