Skip to content

Commit

Permalink
typechecker: cache unlinked request envs and remove typecheck_multi (#1426)
Browse files Browse the repository at this point in the history

Signed-off-by: Craig Disselkoen <[email protected]>
  • Loading branch information
cdisselkoen authored Jan 15, 2025
1 parent 875ba48 commit eafae7e
Showing 1 changed file with 30 additions and 65 deletions.
95 changes: 30 additions & 65 deletions cedar-policy-validator/src/typecheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ mod typecheck_answer;
use itertools::Itertools;
pub(crate) use typecheck_answer::TypecheckAnswer;

use std::{
borrow::Cow,
collections::{HashMap, HashSet},
iter::zip,
};
use std::{borrow::Cow, collections::HashSet, iter::zip};

use crate::{
extension_schema::ExtensionFunctionType,
Expand All @@ -49,7 +45,6 @@ use cedar_policy_core::{
PrincipalOrResourceConstraint, SlotId, Template, UnaryOp, Var,
},
expr_builder::ExprBuilder as _,
parser::Loc,
};

#[cfg(not(target_arch = "wasm32"))]
Expand All @@ -73,6 +68,10 @@ pub struct Typechecker<'a> {
schema: &'a ValidatorSchema,
extensions: &'static ExtensionSchemas<'static>,
mode: ValidationMode,
/// List of valid (unlinked) `RequestEnv`s for this schema.
/// Cached here so it can be computed once (during `Typechecker`
/// construction) and potentially used for many typechecking operations.
unlinked_envs: Vec<RequestEnv<'a>>,
}

impl<'a> Typechecker<'a> {
Expand All @@ -82,6 +81,7 @@ impl<'a> Typechecker<'a> {
schema,
extensions: ExtensionSchemas::all_available(),
mode,
unlinked_envs: Self::unlinked_request_envs(schema, mode).collect(),
}
}

Expand Down Expand Up @@ -143,28 +143,7 @@ impl<'a> Typechecker<'a> {
&'b self,
t: &'b Template,
) -> Vec<(RequestEnv<'b>, PolicyCheck)> {
let map = self.typecheck_multi_by_request_env([t]);
map.into_values()
.filter(|(v, _)| !v.is_empty())
.next()
.unwrap_or_default() // if all the entries have empty vecs, return an empty vec
.0
}

/// Same as `typecheck_by_request_env()`, but typechecks multiple policies
/// at once and returns all the results indexed by policy ID, more
/// efficiently than calling `typecheck_by_request_env()` multiple times.
///
/// The `Loc` of each policy is also returned, for error reporting purposes.
///
/// Callers using this as the toplevel entry point, rather than
/// `typecheck_policy()`, will not get `impossible_policy` validation
/// warnings.
pub fn typecheck_multi_by_request_env<'b>(
&'b self,
ts: impl IntoIterator<Item = &'b Template>,
) -> HashMap<PolicyID, (Vec<(RequestEnv<'b>, PolicyCheck)>, Option<Loc>)> {
self.apply_typecheck_fn_by_request_env(ts, |request_env, policy_id, expr| {
self.apply_typecheck_fn_by_request_env(t, |request_env, policy_id, expr| {
let mut type_errors = Vec::new();
let single_env_typechecker = SingleEnvTypechecker {
schema: self.schema,
Expand Down Expand Up @@ -195,57 +174,43 @@ impl<'a> Typechecker<'a> {
})
}

/// Apply `typecheck_fn` to each policy in every schema-defined request
/// Apply `typecheck_fn` to the given policy in every schema-defined request
/// environment, and collect all the results.
///
/// Results are returned indexed by the policy ID of the policy they belong to.
/// Results for a single policy are returned in no particular order.
/// The `Loc` of each policy is also returned, for error reporting purposes.
/// Results are returned in no particular order.
fn apply_typecheck_fn_by_request_env<'b, F, C>(
&'b self,
ts: impl IntoIterator<Item = &'b Template>,
t: &'b Template,
typecheck_fn: F,
) -> HashMap<PolicyID, (Vec<(RequestEnv<'b>, C)>, Option<Loc>)>
) -> Vec<(RequestEnv<'b>, C)>
where
F: Fn(&RequestEnv<'b>, &PolicyID, &Expr) -> C,
{
let mut ret = HashMap::new();

// compute `.condition()` just once for each policy, and cache it here
let ts: Vec<(&'b Template, Expr)> = ts.into_iter().map(|t| (t, t.condition())).collect();

// initialize the entry for each `PolicyID` by inserting the appropriate loc
for (t, _) in &ts {
ret.insert(t.id().clone(), (Vec::new(), t.loc().cloned()));
}
// compute `.condition()` just once, and cache it here
let cond = t.condition();

// Validate each (principal, resource) pair with the substituted policy
// for the corresponding action.
//
// this ordering of loop nesting is chosen in order to call
// `unlinked_request_envs()` just once for all policies
for unlinked_e in self.unlinked_request_envs() {
for (t, cond) in &ts {
// PANIC SAFETY: already inserted this key above
#[allow(clippy::expect_used)]
ret.get_mut(t.id())
.expect("already inserted this key above")
.0
.extend(self.link_request_env(&unlinked_e, t).map(|linked_e| {
let check = typecheck_fn(&linked_e, t.id(), cond);
(linked_e, check)
}));
}
}
ret
self.unlinked_envs
.iter()
.map(|unlinked_e| {
self.link_request_env(unlinked_e, t).map(|linked_e| {
let check = typecheck_fn(&linked_e, t.id(), &cond);
(linked_e, check)
})
})
.flatten()
.collect()
}

fn unlinked_request_envs(&self) -> impl Iterator<Item = RequestEnv<'_>> + '_ {
fn unlinked_request_envs(
schema: &ValidatorSchema,
mode: ValidationMode,
) -> impl Iterator<Item = RequestEnv<'_>> + '_ {
// Gather all of the actions declared in the schema.
let all_actions = self
.schema
let all_actions = schema
.known_action_ids()
.filter_map(|a| self.schema.get_action_id(a));
.filter_map(|a| schema.get_action_id(a));

// For every action compute the cross product of the principal and
// resource applies_to sets.
Expand All @@ -264,7 +229,7 @@ impl<'a> Typechecker<'a> {
})
})
})
.chain(if self.mode.is_partial() {
.chain(if mode.is_partial() {
// A partial schema might not list all actions, and may not
// include all principal and resource types for the listed ones.
// So we typecheck with a fully unknown request to handle these
Expand Down

0 comments on commit eafae7e

Please sign in to comment.