Skip to content

Commit

Permalink
Don't crash on associated types + refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
JustusAdam committed Aug 5, 2024
1 parent d335f19 commit c274519
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 100 deletions.
253 changes: 155 additions & 98 deletions crates/paralegal-flow/src/ana/inline_judge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ use std::rc::Rc;
use flowistry_pdg_construction::{body_cache::BodyCache, CallInfo};
use paralegal_spdg::{utils::write_sep, Identifier};
use rustc_hash::FxHashSet;
use rustc_hir::def_id::{CrateNum, LOCAL_CRATE};
use rustc_hir::{
def::DefKind,
def_id::{CrateNum, DefId, LOCAL_CRATE},
};
use rustc_middle::ty::{
BoundVariableKind, ClauseKind, ImplPolarity, Instance, ParamEnv, TraitPredicate,
AssocKind, BoundVariableKind, Clause, ClauseKind, ImplPolarity, Instance, ParamEnv,
ProjectionPredicate, TraitPredicate, Ty,
};
use rustc_span::{Span, Symbol};
use rustc_type_ir::TyKind;
Expand Down Expand Up @@ -118,111 +122,164 @@ impl<'tcx> InlineJudge<'tcx> {
call_span: Span,
emit_err: bool,
) {
let sess = self.tcx().sess;
let predicates = self
.tcx()
.predicates_of(resolved.def_id())
.instantiate(self.tcx(), resolved.args);
for (clause, span) in &predicates {
let err = move |s: &str| {
let msg = format!("Cannot verify that non-inlined function is safe due to: {s}");
if emit_err {
let mut diagnostic = sess.struct_span_err(span, msg);
diagnostic.span_note(call_span, "Called from here");
diagnostic.emit();
} else {
let mut diagnostic = sess.struct_span_warn(span, msg);
diagnostic.span_note(call_span, "Called from here");
diagnostic.emit();
}
};
let err_markers = |s: &str, markers: &[Identifier]| {
if !markers.is_empty() {
err(&format!(
"{s}: found marker(s) {}",
Print(|fmt| write_sep(fmt, ", ", markers, |elem, fmt| write!(
fmt,
"'{elem}'"
)))
));
}
};
let kind = clause.kind();
for bound in kind.bound_vars() {
match bound {
BoundVariableKind::Ty(t) => err(&format!("bound type {t:?}")),
BoundVariableKind::Const | BoundVariableKind::Region(_) => (),
}
SafetyChecker {
tcx: self.tcx(),
emit_err,
param_env,
resolved,
call_span,
marker_ctx: self.marker_ctx.clone(),
}
.check()
}
}

struct SafetyChecker<'tcx> {
tcx: TyCtxt<'tcx>,
emit_err: bool,
param_env: ParamEnv<'tcx>,
resolved: Instance<'tcx>,
call_span: Span,
marker_ctx: MarkerCtx<'tcx>,
}

impl<'tcx> SafetyChecker<'tcx> {
fn err(&self, s: &str, span: Span) {
let sess = self.tcx.sess;
let msg = format!("Cannot verify that non-inlined function is safe due to: {s}");
if self.emit_err {
let mut diagnostic = sess.struct_span_err(span, msg);
diagnostic.span_note(self.call_span, "Called from here");
diagnostic.emit();
} else {
let mut diagnostic = sess.struct_span_warn(span, msg);
diagnostic.span_note(self.call_span, "Called from here");
diagnostic.emit();
}
}

fn err_markers(&self, s: &str, markers: &[Identifier], span: Span) {
if !markers.is_empty() {
self.err(
&format!(
"{s}: found marker(s) {}",
Print(|fmt| write_sep(fmt, ", ", markers, |elem, fmt| write!(fmt, "'{elem}'")))
),
span,
);
}
}

fn check_projection_predicate(&self, predicate: &ProjectionPredicate<'tcx>, span: Span) {
if let Some(t) = predicate.term.ty() {
let t = self.tcx.normalize_erasing_regions(self.param_env, t);
let markers = self.marker_ctx.deep_type_markers(t);
if !markers.is_empty() {
let markers = markers.iter().map(|t| t.1).collect::<Box<_>>();
self.err_markers(
&format!("type {t:?} is not approximation safe"),
&markers,
span,
);
}
}
}

match kind.skip_binder() {
ClauseKind::TypeOutlives(_)
| ClauseKind::WellFormed(_)
| ClauseKind::ConstArgHasType(..)
| ClauseKind::ConstEvaluatable(_)
| ClauseKind::RegionOutlives(_) => {
// These predicates do not allow for "code injection" since they do not concern things that can be marked.
}
ClauseKind::Projection(p) => {
// ProjectionPredicate(AliasTy { args: [Iter], def_id: IntoIterator::Item }, Term::Ty(<std::array::IntoIter<M, 1> as std::iter::Iterator>::Item))
println!("Found projection {p:?}");
if let Some(t) = p.term.ty() {
let t = self.tcx.normalize_erasing_regions(param_env, t);
let markers = self.marker_ctx().deep_type_markers(t);
if !markers.is_empty() {
let markers = markers.iter().map(|t| t.1).collect::<Box<_>>();
err_markers(&format!("type {t:?} is not approximation safe"), &markers);
fn check_trait_predicate(&self, predicate: &TraitPredicate<'tcx>, span: Span) {
match predicate {
TraitPredicate {
polarity: ImplPolarity::Positive,
trait_ref,
} if !self.tcx.trait_is_auto(trait_ref.def_id) => {
let ref_1 = trait_ref.args[0];
let Some(self_ty) = ref_1.as_type() else {
self.err("expected self type to be type, got {ref_1:?}", span);
return;
};

if self.tcx.is_fn_trait(trait_ref.def_id) {
let instance = match self_ty.kind() {
TyKind::Closure(id, args) | TyKind::FnDef(id, args) => {
Instance::resolve(self.tcx, ParamEnv::reveal_all(), *id, args)
}
}
}
ClauseKind::Trait(TraitPredicate {
polarity: ImplPolarity::Positive,
trait_ref,
}) if !self.tcx().trait_is_auto(trait_ref.def_id) => {
let ref_1 = trait_ref.args[0];
let Some(self_ty) = ref_1.as_type() else {
err("expected self type to be type, got {ref_1:?}");
continue;
};

if self.tcx().is_fn_trait(trait_ref.def_id) {
let instance = match self_ty.kind() {
TyKind::Closure(id, args) | TyKind::FnDef(id, args) => {
Instance::resolve(self.tcx(), ParamEnv::reveal_all(), *id, args)
}
_ => {
err(&format!(
_ => {
self.err(
&format!(
"fn-trait instance for {self_ty:?} not being a function of closure"
));
continue;
}
}
.unwrap()
.unwrap();
let markers = self.marker_ctx().get_reachable_markers(instance);
if !markers.is_empty() {
err_markers(
&format!("closure {instance:?} is not approximation safe"),
markers,
),
span,
);
return;
}
} else {
self.tcx()
.for_each_relevant_impl(trait_ref.def_id, self_ty, |impl_| {
for method in self.tcx().associated_item_def_ids(impl_) {
let markers = self.marker_ctx().get_reachable_markers(*method);
if !markers.is_empty() {
err_markers(
&format!("impl {impl_:?} for {self_ty:?}"),
markers,
)
}
}
})
}
.unwrap()
.unwrap();
let markers = self.marker_ctx.get_reachable_markers(instance);
if !markers.is_empty() {
self.err_markers(
&format!("closure {instance:?} is not approximation safe"),
markers,
span,
);
}
} else {
self.tcx
.for_each_relevant_impl(trait_ref.def_id, self_ty, |r#impl| {
self.check_impl(r#impl, self_ty, span)
})
}
}
_ => (),
}
}

fn check_impl(&self, r#impl: DefId, self_ty: Ty<'tcx>, span: Span) {
for item in self.tcx.associated_items(r#impl).in_definition_order() {
// NOTE: We don't need to check markers on types here, because they
// will be checked if there is a method that produces (or consumes)
// this type.
match item.kind {
AssocKind::Fn => {
let method = item.def_id;
let markers = self.marker_ctx.get_reachable_markers(method);
if !markers.is_empty() {
self.err_markers(&self.tcx.def_path_str(method), markers, span)
}
}
_ => (),
AssocKind::Const | AssocKind::Type => (),
}
}
}

fn check_predicate(&self, clause: Clause<'tcx>, span: Span) {
let kind = clause.kind();
for bound in kind.bound_vars() {
match bound {
BoundVariableKind::Ty(t) => self.err(&format!("bound type {t:?}"), span),
BoundVariableKind::Const | BoundVariableKind::Region(_) => (),
}
}

match &kind.skip_binder() {
ClauseKind::TypeOutlives(_)
| ClauseKind::WellFormed(_)
| ClauseKind::ConstArgHasType(..)
| ClauseKind::ConstEvaluatable(_)
| ClauseKind::RegionOutlives(_) => {
// These predicates do not allow for "code injection" since they do not concern things that can be marked.
}
ClauseKind::Projection(predicate) => self.check_projection_predicate(predicate, span),
ClauseKind::Trait(predicate) => self.check_trait_predicate(predicate, span),
}
}

fn check(&self) {
let predicates = self
.tcx
.predicates_of(self.resolved.def_id())
.instantiate(self.tcx, self.resolved.args);
for (clause, span) in &predicates {
self.check_predicate(clause, span)
}
}
}
2 changes: 0 additions & 2 deletions crates/paralegal-flow/src/ana/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,6 @@ impl<'tcx> CallChangeCallback<'tcx> for MyCallback<'tcx> {
fn on_inline(&self, info: CallInfo<'tcx, '_>) -> CallChanges<'tcx> {
let changes = CallChanges::default();

let mut skip = SkipCall::Skip;

let skip = match self.judge.should_inline(&info) {
InlineJudgement::NoInline => SkipCall::Skip,
InlineJudgement::UseFlowModel(model) => {
Expand Down

0 comments on commit c274519

Please sign in to comment.