From 89d0c20e1c770168d775af3774a7f1aca52bb0f2 Mon Sep 17 00:00:00 2001 From: Craig Disselkoen Date: Mon, 13 Jan 2025 12:49:01 -0500 Subject: [PATCH 1/7] bump language version to 4.2 (#1423) Signed-off-by: Craig Disselkoen --- cedar-policy/src/api.rs | 2 +- cedar-policy/src/tests.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cedar-policy/src/api.rs b/cedar-policy/src/api.rs index dc10a7214..c68664cf3 100644 --- a/cedar-policy/src/api.rs +++ b/cedar-policy/src/api.rs @@ -71,7 +71,7 @@ pub(crate) mod version { static ref SDK_VERSION: Version = env!("CARGO_PKG_VERSION").parse().unwrap(); // Cedar language version // The patch version field may be unnecessary - static ref LANG_VERSION: Version = Version::new(4, 0, 0); + static ref LANG_VERSION: Version = Version::new(4, 2, 0); } /// Get the Cedar SDK Semantic Versioning version #[allow(clippy::module_name_repetitions)] diff --git a/cedar-policy/src/tests.rs b/cedar-policy/src/tests.rs index 3e2e1afbc..6a9f6a764 100644 --- a/cedar-policy/src/tests.rs +++ b/cedar-policy/src/tests.rs @@ -6292,7 +6292,7 @@ mod version_tests { #[test] fn test_lang_version() { - assert_eq!(get_lang_version().to_string(), "4.0.0"); + assert_eq!(get_lang_version().to_string(), "4.2.0"); } } From 4c9c8d892982df5e540fdfe6c569249551525159 Mon Sep 17 00:00:00 2001 From: Craig Disselkoen Date: Mon, 13 Jan 2025 14:56:31 -0500 Subject: [PATCH 2/7] partial fix for 1421 (#1422) Signed-off-by: Craig Disselkoen --- cedar-policy-validator/src/typecheck.rs | 122 ++++++++++++------------ 1 file changed, 59 insertions(+), 63 deletions(-) diff --git a/cedar-policy-validator/src/typecheck.rs b/cedar-policy-validator/src/typecheck.rs index 8f00f5a26..49e781aef 100644 --- a/cedar-policy-validator/src/typecheck.rs +++ b/cedar-policy-validator/src/typecheck.rs @@ -24,7 +24,11 @@ mod typecheck_answer; use itertools::Itertools; pub(crate) use typecheck_answer::TypecheckAnswer; -use std::{borrow::Cow, collections::HashSet, iter::zip}; +use std::{ + borrow::Cow, + collections::{HashMap, HashSet}, + iter::zip, +}; use crate::{ extension_schema::ExtensionFunctionType, @@ -46,6 +50,7 @@ use cedar_policy_core::{ }, expr_builder::ExprBuilder as _, extensions::Extensions, + parser::Loc, }; #[cfg(not(target_arch = "wasm32"))] @@ -143,10 +148,27 @@ impl<'a> Typechecker<'a> { &'b self, t: &'b Template, ) -> Vec<(RequestEnv<'b>, PolicyCheck)> { - self.apply_typecheck_fn_by_request_env(t, |request, expr| { + let map = self.typecheck_multi_by_request_env([t]); + map.into_values() + .filter(|(v, _)| !v.is_empty()) + .next() + .unwrap_or_default() // if all the entries have empty vecs, return an empty vec + .0 + } + + /// Same as `typecheck_by_request_env()`, but typechecks multiple policies + /// at once and returns all the results indexed by policy ID, more + /// efficiently than calling `typecheck_by_request_env()` multiple times. + /// + /// The `Loc` of each policy is also returned, for error reporting purposes. + pub fn typecheck_multi_by_request_env<'b>( + &'b self, + ts: impl IntoIterator, + ) -> HashMap, PolicyCheck)>, Option)> { + self.apply_typecheck_fn_by_request_env(ts, |request, expr| { let mut type_errors = Vec::new(); let empty_prior_capability = CapabilitySet::new(); - let ty = self.expect_type( + let ans = self.expect_type( request, &empty_prior_capability, expr, @@ -155,8 +177,8 @@ impl<'a> Typechecker<'a> { |_| None, ); - let is_false = ty.contains_type(&Type::singleton_boolean(false)); - match (is_false, ty.typechecked(), ty.into_typed_expr()) { + let is_false = ans.contains_type(&Type::singleton_boolean(false)); + match (is_false, ans.typechecked(), ans.into_typed_expr()) { (false, true, None) => PolicyCheck::Fail(type_errors), (false, true, Some(e)) => PolicyCheck::Success(e), (false, false, _) => PolicyCheck::Fail(type_errors), @@ -168,75 +190,49 @@ impl<'a> Typechecker<'a> { }) } - /// Utility abstracting the common logic for strict and regular typechecking - /// by request environment. + /// Apply `typecheck_fn` to each policy in every schema-defined request + /// environment, and collect all the results. + /// + /// Results are returned indexed by the policy ID of the policy they belong to. + /// Results for a single policy are returned in no particular order. + /// The `Loc` of each policy is also returned, for error reporting purposes. fn apply_typecheck_fn_by_request_env<'b, F, C>( &'b self, - t: &'b Template, + ts: impl IntoIterator, typecheck_fn: F, - ) -> Vec<(RequestEnv<'b>, C)> + ) -> HashMap, C)>, Option)> where F: Fn(&RequestEnv<'b>, &Expr) -> C, { - let mut result_checks = Vec::new(); + let mut ret = HashMap::new(); - // Validate each (principal, resource) pair with the substituted policy - // for the corresponding action. Implemented as for loop to make it - // explicit that `expect_type` will be called for every element of - // request_env without short circuiting. - let policy_condition = &t.condition(); - for unlinked_e in self.unlinked_request_envs() { - for linked_e in self.link_request_env(&unlinked_e, t) { - let check = typecheck_fn(&linked_e, policy_condition); - result_checks.push((linked_e, check)) - } - } - result_checks - } + // compute `.condition()` just once for each policy, and cache it here + let ts: Vec<(&'b Template, Expr)> = ts.into_iter().map(|t| (t, t.condition())).collect(); - /// Additional entry point for typechecking requests. This method takes a slice - /// over policies and typechecks each under every schema-defined request environment. - /// - /// The result contains these environments in no particular order, but each list of - /// policy checks will always match the original order. - pub fn multi_typecheck_by_request_env( - &self, - policy_templates: &[&Template], - ) -> Vec<(RequestEnv<'_>, Vec)> { - let mut env_checks = Vec::new(); - for request in self.unlinked_request_envs() { - let mut policy_checks = Vec::new(); - for t in policy_templates.iter() { - let condition_expr = t.condition(); - for linked_env in self.link_request_env(&request, t) { - let mut type_errors = Vec::new(); - let empty_prior_capability = CapabilitySet::new(); - let ty = self.expect_type( - &linked_env, - &empty_prior_capability, - &condition_expr, - Type::primitive_boolean(), - &mut type_errors, - |_| None, - ); + // initialize the entry for each `PolicyID` by inserting the appropriate loc + for (t, _) in &ts { + ret.insert(t.id().clone(), (Vec::new(), t.loc().cloned())); + } - let is_false = ty.contains_type(&Type::singleton_boolean(false)); - match (is_false, ty.typechecked(), ty.into_typed_expr()) { - (false, true, None) => policy_checks.push(PolicyCheck::Fail(type_errors)), - (false, true, Some(e)) => policy_checks.push(PolicyCheck::Success(e)), - (false, false, _) => policy_checks.push(PolicyCheck::Fail(type_errors)), - (true, _, Some(e)) => { - policy_checks.push(PolicyCheck::Irrelevant(type_errors, e)) - } - // PANIC SAFETY: `is_false` implies `e` has a type implies `Some(e)`. - #[allow(clippy::unreachable)] - (true, _, None) => unreachable!(), - } - } + // Validate each (principal, resource) pair with the substituted policy + // for the corresponding action. + // + // this ordering of loop nesting is chosen in order to call + // `unlinked_request_envs()` just once for all policies + for unlinked_e in self.unlinked_request_envs() { + for (t, cond) in &ts { + // PANIC SAFETY: already inserted this key above + #[allow(clippy::expect_used)] + ret.get_mut(t.id()) + .expect("already inserted this key above") + .0 + .extend(self.link_request_env(&unlinked_e, t).map(|linked_e| { + let check = typecheck_fn(&linked_e, cond); + (linked_e, check) + })); } - env_checks.push((request, policy_checks)); } - env_checks + ret } fn unlinked_request_envs(&self) -> impl Iterator> + '_ { From 025298829e8ea255c95b76d5f15c2723a0d89bb7 Mon Sep 17 00:00:00 2001 From: shaobo-he-aws <130499339+shaobo-he-aws@users.noreply.github.com> Date: Mon, 13 Jan 2025 13:20:46 -0800 Subject: [PATCH 3/7] Add APIs to get schema annotations (#1389) Signed-off-by: Shaobo He Co-authored-by: John Kastner <130772734+john-h-kastner-aws@users.noreply.github.com> Co-authored-by: Craig Disselkoen --- cedar-policy/src/api.rs | 166 +++++++++++++++++++ cedar-policy/src/tests.rs | 336 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 502 insertions(+) diff --git a/cedar-policy/src/api.rs b/cedar-policy/src/api.rs index c68664cf3..ea84298fd 100644 --- a/cedar-policy/src/api.rs +++ b/cedar-policy/src/api.rs @@ -30,6 +30,7 @@ use cedar_policy_validator::entity_manifest; pub use cedar_policy_validator::entity_manifest::{ AccessTrie, EntityManifest, EntityRoot, Fields, RootAccessTrie, }; +use cedar_policy_validator::json_schema; use cedar_policy_validator::typecheck::{PolicyCheck, Typechecker}; pub use id::*; @@ -1397,7 +1398,172 @@ pub struct SchemaFragment { lossless: cedar_policy_validator::json_schema::Fragment, } +fn get_annotation_by_key( + annotations: &est::Annotations, + annotation_key: impl AsRef, +) -> Option<&str> { + annotations + .0 + .get(&annotation_key.as_ref().parse().ok()?) + .map(|value| annotation_value_to_str_ref(value.as_ref())) +} + +fn annotation_value_to_str_ref(value: Option<&ast::Annotation>) -> &str { + value.map_or("", |a| a.as_ref()) +} + +fn annotations_to_pairs(annotations: &est::Annotations) -> impl Iterator { + annotations + .0 + .iter() + .map(|(key, value)| (key.as_ref(), annotation_value_to_str_ref(value.as_ref()))) +} + impl SchemaFragment { + /// Get annotations of a non-empty namespace. + /// + /// We do not allow namespace-level annotations on the empty namespace. + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] + pub fn namespace_annotations( + &self, + namespace: EntityNamespace, + ) -> Option> { + self.lossless + .0 + .get(&Some(namespace.0)) + .map(|ns_def| annotations_to_pairs(&ns_def.annotations)) + } + + /// Get annotation value of a non-empty namespace by annotation key + /// `annotation_key` + /// + /// We do not allow namespace-level annotations on the empty namespace. + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] + /// or `annotation_key` is not a valid annotation key + /// or it does not exist + pub fn namespace_annotation( + &self, + namespace: EntityNamespace, + annotation_key: impl AsRef, + ) -> Option<&str> { + let ns = self.lossless.0.get(&Some(namespace.0))?; + get_annotation_by_key(&ns.annotations, annotation_key) + } + + /// Get annotations of a common type declaration + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] or + /// `ty` is not a valid common type ID or `ty` is not found in the + /// corresponding namespace definition + pub fn common_type_annotations( + &self, + namespace: Option, + ty: &str, + ) -> Option> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + let ty = json_schema::CommonTypeId::new(ast::UnreservedId::from_normalized_str(ty).ok()?) + .ok()?; + ns_def + .common_types + .get(&ty) + .map(|ty| annotations_to_pairs(&ty.annotations)) + } + + /// Get annotation value of a common type declaration by annotation key + /// `annotation_key` + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] + /// or `ty` is not a valid common type ID + /// or `ty` is not found in the corresponding namespace definition + /// or `annotation_key` is not a valid annotation key + /// or it does not exist + pub fn common_type_annotation( + &self, + namespace: Option, + ty: &str, + annotation_key: impl AsRef, + ) -> Option<&str> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + let ty = json_schema::CommonTypeId::new(ast::UnreservedId::from_normalized_str(ty).ok()?) + .ok()?; + get_annotation_by_key(&ns_def.common_types.get(&ty)?.annotations, annotation_key) + } + + /// Get annotations of an entity type declaration + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] or + /// `ty` is not a valid entity type name or `ty` is not found in the + /// corresponding namespace definition + pub fn entity_type_annotations( + &self, + namespace: Option, + ty: &str, + ) -> Option> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + let ty = ast::UnreservedId::from_normalized_str(ty).ok()?; + ns_def + .entity_types + .get(&ty) + .map(|ty| annotations_to_pairs(&ty.annotations)) + } + + /// Get annotation value of an entity type declaration by annotation key + /// `annotation_key` + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] + /// or `ty` is not a valid entity type name + /// or `ty` is not found in the corresponding namespace definition + /// or `annotation_key` is not a valid annotation key + /// or it does not exist + pub fn entity_type_annotation( + &self, + namespace: Option, + ty: &str, + annotation_key: impl AsRef, + ) -> Option<&str> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + let ty = ast::UnreservedId::from_normalized_str(ty).ok()?; + get_annotation_by_key(&ns_def.entity_types.get(&ty)?.annotations, annotation_key) + } + + /// Get annotations of an action declaration + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] or + /// `id` is not found in the corresponding namespace definition + pub fn action_annotations( + &self, + namespace: Option, + id: EntityId, + ) -> Option> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + ns_def + .actions + .get(id.as_ref()) + .map(|a| annotations_to_pairs(&a.annotations)) + } + + /// Get annotation value of an action declaration by annotation key + /// `annotation_key` + /// + /// Returns `None` if `namespace` is not found in the [`SchemaFragment`] + /// or `id` is not found in the corresponding namespace definition + /// or `annotation_key` is not a valid annotation key + /// or it does not exist + pub fn action_annotation( + &self, + namespace: Option, + id: EntityId, + annotation_key: impl AsRef, + ) -> Option<&str> { + let ns_def = self.lossless.0.get(&namespace.map(|n| n.0))?; + get_annotation_by_key( + &ns_def.actions.get(id.as_ref())?.annotations, + annotation_key, + ) + } + /// Extract namespaces defined in this [`SchemaFragment`]. /// /// `None` indicates the empty namespace. diff --git a/cedar-policy/src/tests.rs b/cedar-policy/src/tests.rs index 6a9f6a764..e28920026 100644 --- a/cedar-policy/src/tests.rs +++ b/cedar-policy/src/tests.rs @@ -6524,3 +6524,339 @@ mod reserved_keywords_in_policies { } } } + +mod schema_annotations { + use std::collections::BTreeMap; + + use cool_asserts::assert_matches; + + use crate::EntityNamespace; + + use super::SchemaFragment; + + #[track_caller] + fn example_schema() -> SchemaFragment { + SchemaFragment::from_cedarschema_str( + r#" + @a("a") + @b + entity A1,A2 {}; + @c("c") + @d + type T = Long; + @e("e") + @f + action a1, a2 appliesTo { principal: [A1], resource: [A2] }; + + @m("m") + @n + namespace N { + @a("a") + @b + entity A1,A2 {}; + @c("c") + @d + type T = Long; + @e("e") + @f + action a1, a2 appliesTo { principal: [N::A1], resource: [A2] }; + } + "#, + ) + .expect("should be a valid schema fragment") + .0 + } + + #[test] + fn namespace_annotations() { + let schema = example_schema(); + let namespace: EntityNamespace = "N".parse().expect("should be a valid name"); + let annotations = BTreeMap::from_iter( + schema + .namespace_annotations(namespace.clone()) + .expect("should get annotations"), + ); + assert_eq!(annotations, BTreeMap::from_iter([("m", "m"), ("n", "")])); + assert_matches!( + schema + .namespace_annotations("NM".parse().unwrap()) + .map(|_| ()), + None + ); + + assert_matches!( + schema.namespace_annotation(namespace.clone(), "m"), + Some("m") + ); + assert_matches!( + schema.namespace_annotation(namespace.clone(), "n"), + Some("") + ); + assert_matches!(schema.namespace_annotation(namespace.clone(), "x"), None); + assert_matches!( + schema.namespace_annotation("NM".parse().unwrap(), "n"), + None + ); + } + + #[test] + fn entity_type_annotations() { + let schema = example_schema(); + let annotations = BTreeMap::from_iter([("a", "a"), ("b", "")]); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .entity_type_annotations(None, "A1") + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .entity_type_annotations(None, "A2") + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .entity_type_annotations( + Some("N".parse().expect("should be a valid name")), + "A1" + ) + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .entity_type_annotations( + Some("N".parse().expect("should be a valid name")), + "A2" + ) + .expect("should get annotations") + ) + ); + + assert_matches!(schema.entity_type_annotation(None, "A1", "b",), Some("")); + assert_matches!(schema.entity_type_annotation(None, "A2", "a",), Some("a")); + assert_matches!(schema.entity_type_annotation(None, "A3", "a",), None); + assert_matches!(schema.entity_type_annotation(None, "A2", "x",), None); + assert_matches!( + schema.entity_type_annotation( + Some("N".parse().expect("should be a valid name")), + "A1", + "b", + ), + Some("") + ); + assert_matches!( + schema.entity_type_annotation( + Some("N".parse().expect("should be a valid name")), + "A2", + "a", + ), + Some("a") + ); + assert_matches!( + schema.entity_type_annotation( + Some("N".parse().expect("should be a valid name")), + "A3", + "a", + ), + None + ); + assert_matches!( + schema.entity_type_annotation( + Some("N".parse().expect("should be a valid name")), + "A2", + "x", + ), + None + ); + assert_matches!( + schema.entity_type_annotation( + Some("NM".parse().expect("should be a valid name")), + "A1", + "b", + ), + None + ); + } + + #[test] + fn common_type_annotations() { + let schema = example_schema(); + let annotations = BTreeMap::from_iter([("c", "c"), ("d", "")]); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .common_type_annotations(None, "T") + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .common_type_annotations( + Some("N".parse().expect("should be a valid name")), + "T" + ) + .expect("should get annotations") + ) + ); + assert_matches!(schema.common_type_annotation(None, "T", "c",), Some("c")); + assert_matches!(schema.common_type_annotation(None, "T", "d",), Some("")); + assert_matches!(schema.common_type_annotation(None, "T1", "c",), None); + assert_matches!(schema.common_type_annotation(None, "T", "x",), None); + + assert_matches!( + schema.common_type_annotation( + Some("N".parse().expect("should be a valid name")), + "T", + "c", + ), + Some("c") + ); + assert_matches!( + schema.common_type_annotation( + Some("N".parse().expect("should be a valid name")), + "T", + "d", + ), + Some("") + ); + assert_matches!( + schema.common_type_annotation( + Some("N".parse().expect("should be a valid name")), + "T1", + "c", + ), + None + ); + assert_matches!( + schema.common_type_annotation( + Some("N".parse().expect("should be a valid name")), + "T", + "x", + ), + None + ); + assert_matches!( + schema.common_type_annotation( + Some("NM".parse().expect("should be a valid name")), + "T", + "c", + ), + None + ); + } + + #[test] + fn action_type_annotations() { + let schema = example_schema(); + let annotations = BTreeMap::from_iter([("e", "e"), ("f", "")]); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .action_annotations(None, "a1".parse().unwrap(),) + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .action_annotations(None, "a2".parse().unwrap(),) + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .action_annotations( + Some("N".parse().expect("should be a valid name")), + "a1".parse().unwrap(), + ) + .expect("should get annotations") + ) + ); + assert_eq!( + annotations, + BTreeMap::from_iter( + schema + .action_annotations( + Some("N".parse().expect("should be a valid name")), + "a2".parse().unwrap(), + ) + .expect("should get annotations") + ) + ); + + assert_matches!( + schema.action_annotation(None, "a1".parse().unwrap(), "e",), + Some("e") + ); + assert_matches!( + schema.action_annotation(None, "a2".parse().unwrap(), "f",), + Some("") + ); + assert_matches!( + schema.action_annotation(None, "a3".parse().unwrap(), "e",), + None + ); + assert_matches!( + schema.action_annotation(None, "a2".parse().unwrap(), "x",), + None + ); + + assert_matches!( + schema.action_annotation( + Some("N".parse().expect("should be a valid name")), + "a1".parse().unwrap(), + "e", + ), + Some("e") + ); + assert_matches!( + schema.action_annotation( + Some("N".parse().expect("should be a valid name")), + "a2".parse().unwrap(), + "f", + ), + Some("") + ); + assert_matches!( + schema.action_annotation( + Some("N".parse().expect("should be a valid name")), + "a3".parse().unwrap(), + "e", + ), + None + ); + assert_matches!( + schema.action_annotation( + Some("N".parse().expect("should be a valid name")), + "a2".parse().unwrap(), + "x", + ), + None + ); + assert_matches!( + schema.action_annotation( + Some("NM".parse().expect("should be a valid name")), + "a1".parse().unwrap(), + "e", + ), + None + ); + } +} From 44f80788298e7e8caf9c13fb865b2cbbad4893b9 Mon Sep 17 00:00:00 2001 From: shaobo-he-aws <130499339+shaobo-he-aws@users.noreply.github.com> Date: Mon, 13 Jan 2025 13:21:47 -0800 Subject: [PATCH 4/7] Use `cedar-policy`'s version getters for WASM APIs (#1419) Signed-off-by: Shaobo He --- Cargo.lock | 344 -------------------------------- cedar-policy/src/ffi/mod.rs | 2 + cedar-policy/src/ffi/version.rs | 17 ++ cedar-wasm/Cargo.toml | 5 - cedar-wasm/build.rs | 36 ---- cedar-wasm/src/lib.rs | 13 +- 6 files changed, 29 insertions(+), 388 deletions(-) create mode 100644 cedar-policy/src/ffi/version.rs delete mode 100644 cedar-wasm/build.rs diff --git a/Cargo.lock b/Cargo.lock index 59b298fe7..bd7689f63 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -252,18 +252,6 @@ version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "325918d6fe32f23b19878fe4b34794ae41fc19ddbe53b10571a4874d44ffd39b" -[[package]] -name = "cargo-lock" -version = "10.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6469776d007022d505bbcc2be726f5f096174ae76d710ebc609eb3029a45b551" -dependencies = [ - "semver", - "serde", - "toml", - "url", -] - [[package]] name = "cast" version = "0.3.0" @@ -428,19 +416,16 @@ dependencies = [ name = "cedar-wasm" version = "4.3.0" dependencies = [ - "cargo-lock", "cedar-policy", "cedar-policy-core", "cedar-policy-formatter", "cedar-policy-validator", "console_error_panic_hook", "cool_asserts", - "itertools 0.14.0", "serde", "serde-wasm-bindgen", "serde_json", "tsify", - "url", "wasm-bindgen", "wasm-bindgen-test", ] @@ -754,17 +739,6 @@ dependencies = [ "crypto-common", ] -[[package]] -name = "displaydoc" -version = "0.2.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "doc-comment" version = "0.3.3" @@ -867,15 +841,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "form_urlencoded" -version = "1.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" -dependencies = [ - "percent-encoding", -] - [[package]] name = "futures-core" version = "0.3.31" @@ -1050,151 +1015,12 @@ dependencies = [ "cc", ] -[[package]] -name = "icu_collections" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" -dependencies = [ - "displaydoc", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_locid" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" -dependencies = [ - "displaydoc", - "litemap", - "tinystr", - "writeable", - "zerovec", -] - -[[package]] -name = "icu_locid_transform" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_locid_transform_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_locid_transform_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" - -[[package]] -name = "icu_normalizer" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_normalizer_data", - "icu_properties", - "icu_provider", - "smallvec", - "utf16_iter", - "utf8_iter", - "write16", - "zerovec", -] - -[[package]] -name = "icu_normalizer_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" - -[[package]] -name = "icu_properties" -version = "1.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" -dependencies = [ - "displaydoc", - "icu_collections", - "icu_locid_transform", - "icu_properties_data", - "icu_provider", - "tinystr", - "zerovec", -] - -[[package]] -name = "icu_properties_data" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" - -[[package]] -name = "icu_provider" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" -dependencies = [ - "displaydoc", - "icu_locid", - "icu_provider_macros", - "stable_deref_trait", - "tinystr", - "writeable", - "yoke", - "zerofrom", - "zerovec", -] - -[[package]] -name = "icu_provider_macros" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "ident_case" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39" -[[package]] -name = "idna" -version = "1.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" -dependencies = [ - "idna_adapter", - "smallvec", - "utf8_iter", -] - -[[package]] -name = "idna_adapter" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" -dependencies = [ - "icu_normalizer", - "icu_properties", -] - [[package]] name = "indent_write" version = "2.2.0" @@ -1368,12 +1194,6 @@ version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" -[[package]] -name = "litemap" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" - [[package]] name = "lock_api" version = "0.4.12" @@ -1576,12 +1396,6 @@ dependencies = [ "windows-targets", ] -[[package]] -name = "percent-encoding" -version = "2.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - [[package]] name = "petgraph" version = "0.6.5" @@ -2073,9 +1887,6 @@ name = "semver" version = "1.0.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3cb6eb87a131f756572d7fb904f6e7b68633f09cca868c5df1c4b8d1a694bbba" -dependencies = [ - "serde", -] [[package]] name = "serde" @@ -2132,15 +1943,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_spanned" -version = "0.6.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" -dependencies = [ - "serde", -] - [[package]] name = "serde_with" version = "3.12.0" @@ -2238,12 +2040,6 @@ dependencies = [ "serde", ] -[[package]] -name = "stable_deref_trait" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" - [[package]] name = "stacker" version = "0.1.17" @@ -2308,17 +2104,6 @@ dependencies = [ "unicode-ident", ] -[[package]] -name = "synstructure" -version = "0.13.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "tempfile" version = "3.15.0" @@ -2446,16 +2231,6 @@ dependencies = [ "time-core", ] -[[package]] -name = "tinystr" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" -dependencies = [ - "displaydoc", - "zerovec", -] - [[package]] name = "tinytemplate" version = "1.2.1" @@ -2481,26 +2256,11 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" -[[package]] -name = "toml" -version = "0.8.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" -dependencies = [ - "serde", - "serde_spanned", - "toml_datetime", - "toml_edit", -] - [[package]] name = "toml_datetime" version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" -dependencies = [ - "serde", -] [[package]] name = "toml_edit" @@ -2509,8 +2269,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap 2.7.0", - "serde", - "serde_spanned", "toml_datetime", "winnow", ] @@ -2613,29 +2371,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "url" -version = "2.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" -dependencies = [ - "form_urlencoded", - "idna", - "percent-encoding", -] - -[[package]] -name = "utf16_iter" -version = "1.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" - -[[package]] -name = "utf8_iter" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" - [[package]] name = "utf8parse" version = "0.2.2" @@ -2887,42 +2622,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "write16" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" - -[[package]] -name = "writeable" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" - -[[package]] -name = "yoke" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" -dependencies = [ - "serde", - "stable_deref_trait", - "yoke-derive", - "zerofrom", -] - -[[package]] -name = "yoke-derive" -version = "0.7.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - [[package]] name = "zerocopy" version = "0.7.35" @@ -2943,46 +2642,3 @@ dependencies = [ "quote", "syn", ] - -[[package]] -name = "zerofrom" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" -dependencies = [ - "zerofrom-derive", -] - -[[package]] -name = "zerofrom-derive" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" -dependencies = [ - "proc-macro2", - "quote", - "syn", - "synstructure", -] - -[[package]] -name = "zerovec" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" -dependencies = [ - "yoke", - "zerofrom", - "zerovec-derive", -] - -[[package]] -name = "zerovec-derive" -version = "0.10.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] diff --git a/cedar-policy/src/ffi/mod.rs b/cedar-policy/src/ffi/mod.rs index 8a993225c..a95e40ef7 100644 --- a/cedar-policy/src/ffi/mod.rs +++ b/cedar-policy/src/ffi/mod.rs @@ -26,4 +26,6 @@ mod format; pub use format::*; mod convert; pub use convert::*; +mod version; +pub use version::*; mod tests; diff --git a/cedar-policy/src/ffi/version.rs b/cedar-policy/src/ffi/version.rs new file mode 100644 index 000000000..6fee0bac0 --- /dev/null +++ b/cedar-policy/src/ffi/version.rs @@ -0,0 +1,17 @@ +#[cfg(feature = "wasm")] +use wasm_bindgen::prelude::wasm_bindgen; + +use crate::api; + +/// Get language version of Cedar +#[cfg_attr(feature = "wasm", wasm_bindgen(js_name = "getCedarLangVersion"))] +pub fn get_lang_version() -> String { + let version = api::version::get_lang_version(); + format!("{}.{}", version.major, version.minor) +} + +/// Get SDK version of Cedar +pub fn get_sdk_version() -> String { + let version = api::version::get_sdk_version(); + format!("{version}") +} diff --git a/cedar-wasm/Cargo.toml b/cedar-wasm/Cargo.toml index d18aa6b02..2629107d3 100644 --- a/cedar-wasm/Cargo.toml +++ b/cedar-wasm/Cargo.toml @@ -32,10 +32,5 @@ crate-type = ["cdylib", "rlib"] wasm-bindgen-test = "0.3.50" cool_asserts = "2.0" -[build-dependencies] -cargo-lock = "10.0.0" -url = "2.5.4" -itertools = "0.14.0" - [lints] workspace = true diff --git a/cedar-wasm/build.rs b/cedar-wasm/build.rs deleted file mode 100644 index 91b8d0cc6..000000000 --- a/cedar-wasm/build.rs +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright Cedar Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use cargo_lock::Lockfile; - -/// PANIC SAFETY: This is a build script so it's okay for it to panic. Build should fail if underlying assumptions of this script fail -#[allow(clippy::expect_used)] -fn main() { - println!("cargo:rerun-if-changed=../Cargo.lock"); - let lockfile = Lockfile::load("../Cargo.lock").expect("a valid lockfile"); - let mut iter = lockfile - .packages - .into_iter() - .filter(|p| p.name.as_str() == "cedar-policy"); - let version = iter - .next() - .expect("cedar-policy is not found in manifest") - .version; - - assert!(iter.next().is_none()); - - println!("cargo:rustc-env=CEDAR_VERSION={version}"); -} diff --git a/cedar-wasm/src/lib.rs b/cedar-wasm/src/lib.rs index f628820cc..2df253522 100644 --- a/cedar-wasm/src/lib.rs +++ b/cedar-wasm/src/lib.rs @@ -17,13 +17,20 @@ use wasm_bindgen::prelude::*; mod utils; +use cedar_policy::ffi; pub use cedar_policy::ffi::{ check_parse_context, check_parse_entities, check_parse_policy_set, check_parse_schema, format, - is_authorized, policy_to_json, policy_to_text, schema_to_json, schema_to_text, validate, + get_lang_version, is_authorized, policy_to_json, policy_to_text, schema_to_json, + schema_to_text, validate, }; pub use utils::*; #[wasm_bindgen(js_name = "getCedarVersion")] -pub fn get_cedar_version() -> String { - std::env!("CEDAR_VERSION").to_string() +pub fn get_sdk_version_deprecated() -> String { + get_sdk_version() +} + +#[wasm_bindgen(js_name = "getCedarSDKVersion")] +pub fn get_sdk_version() -> String { + ffi::get_sdk_version() } From 569aecfdf1396c64dffb250f62c6a96cf06dfeb6 Mon Sep 17 00:00:00 2001 From: Craig Disselkoen Date: Mon, 13 Jan 2025 16:26:25 -0500 Subject: [PATCH 5/7] fix 1421, usage of `policy_id` in Typechecker (#1424) Signed-off-by: Craig Disselkoen --- cedar-policy-validator/src/entity_manifest.rs | 2 +- cedar-policy-validator/src/level_validate.rs | 22 +- cedar-policy-validator/src/lib.rs | 2 +- cedar-policy-validator/src/schema.rs | 2 +- cedar-policy-validator/src/typecheck.rs | 235 +++++++----------- .../src/typecheck/test/partial.rs | 6 +- .../src/typecheck/test/policy.rs | 14 +- .../src/typecheck/test/strict.rs | 59 +++-- .../src/typecheck/test/test_utils.rs | 31 ++- .../src/typecheck/test/type_annotation.rs | 12 +- cedar-policy/src/api.rs | 6 +- 11 files changed, 165 insertions(+), 226 deletions(-) diff --git a/cedar-policy-validator/src/entity_manifest.rs b/cedar-policy-validator/src/entity_manifest.rs index 4a93ac3bc..8add18ab0 100644 --- a/cedar-policy-validator/src/entity_manifest.rs +++ b/cedar-policy-validator/src/entity_manifest.rs @@ -450,7 +450,7 @@ pub fn compute_entity_manifest( // now, for each policy we add the data it requires to the manifest for policy in policies.policies() { // typecheck the policy and get all the request environments - let typechecker = Typechecker::new(schema, ValidationMode::Strict, policy.id().clone()); + let typechecker = Typechecker::new(schema, ValidationMode::Strict); let request_envs = typechecker.typecheck_by_request_env(policy.template()); for (request_env, policy_check) in request_envs { let new_primary_slice = match policy_check { diff --git a/cedar-policy-validator/src/level_validate.rs b/cedar-policy-validator/src/level_validate.rs index d4e4fbbde..e5c1e388d 100644 --- a/cedar-policy-validator/src/level_validate.rs +++ b/cedar-policy-validator/src/level_validate.rs @@ -41,12 +41,8 @@ impl Validator { // Only perform level validation if validation passed. if peekable_errors.peek().is_none() { - let levels_errors = self.check_entity_deref_level( - p, - mode, - &EntityDerefLevel::from(max_deref_level), - p.id(), - ); + let levels_errors = + self.check_entity_deref_level(p, mode, &EntityDerefLevel::from(max_deref_level)); (peekable_errors.chain(levels_errors), warnings) } else { (peekable_errors.chain(vec![]), warnings) @@ -60,16 +56,14 @@ impl Validator { t: &'a Template, mode: ValidationMode, max_allowed_level: &EntityDerefLevel, - policy_id: &PolicyID, ) -> Vec { - let typechecker = Typechecker::new(&self.schema, mode, t.id().clone()); + let typechecker = Typechecker::new(&self.schema, mode); let type_annotated_asts = typechecker.typecheck_by_request_env(t); let mut errs = vec![]; for (_, policy_check) in type_annotated_asts { match policy_check { PolicyCheck::Success(e) | PolicyCheck::Irrelevant(_, e) => { - let res = - Self::check_entity_deref_level_helper(&e, max_allowed_level, policy_id); + let res = Self::check_entity_deref_level_helper(&e, max_allowed_level, t.id()); if let Some(e) = res.1 { errs.push(ValidationError::EntityDerefLevelViolation(e)) } @@ -302,12 +296,10 @@ mod levels_validation_tests { let p = parser::parse_policy(None, src).unwrap(); set.add_static(p).unwrap(); - let template_name = PolicyID::from_string("policy0"); let result = validator.check_entity_deref_level( - set.get_template(&template_name).unwrap(), + set.get_template(&PolicyID::from_string("policy0")).unwrap(), ValidationMode::default(), &EntityDerefLevel { level: 0 }, - &template_name, ); assert!(result.is_empty()); } @@ -322,12 +314,10 @@ mod levels_validation_tests { let p = parser::parse_policy(None, src).unwrap(); set.add_static(p).unwrap(); - let template_name = PolicyID::from_string("policy0"); let result = validator.check_entity_deref_level( - set.get_template(&template_name).unwrap(), + set.get_template(&PolicyID::from_string("policy0")).unwrap(), ValidationMode::default(), &EntityDerefLevel { level: 0 }, - &template_name, ); assert!(result.len() == 1); } diff --git a/cedar-policy-validator/src/lib.rs b/cedar-policy-validator/src/lib.rs index 201e18db8..711851106 100644 --- a/cedar-policy-validator/src/lib.rs +++ b/cedar-policy-validator/src/lib.rs @@ -255,7 +255,7 @@ impl Validator { impl Iterator + 'a, impl Iterator + 'a, ) { - let typecheck = Typechecker::new(&self.schema, mode, t.id().clone()); + let typecheck = Typechecker::new(&self.schema, mode); let mut errors = HashSet::new(); let mut warnings = HashSet::new(); typecheck.typecheck_policy(t, &mut errors, &mut warnings); diff --git a/cedar-policy-validator/src/schema.rs b/cedar-policy-validator/src/schema.rs index 715941bcb..eb37ddcbb 100644 --- a/cedar-policy-validator/src/schema.rs +++ b/cedar-policy-validator/src/schema.rs @@ -795,7 +795,7 @@ impl ValidatorSchema { /// on `get_entity_types_in`. pub(crate) fn get_entity_types_in_set<'a>( &'a self, - euids: impl IntoIterator + 'a, + euids: impl IntoIterator, ) -> impl Iterator { euids.into_iter().flat_map(|e| self.get_entity_types_in(e)) } diff --git a/cedar-policy-validator/src/typecheck.rs b/cedar-policy-validator/src/typecheck.rs index 49e781aef..e997fc0ff 100644 --- a/cedar-policy-validator/src/typecheck.rs +++ b/cedar-policy-validator/src/typecheck.rs @@ -74,23 +74,15 @@ pub struct Typechecker<'a> { schema: &'a ValidatorSchema, extensions: &'static ExtensionSchemas<'static>, mode: ValidationMode, - policy_id: PolicyID, } impl<'a> Typechecker<'a> { - /// Construct a new typechecker. - pub fn new( - schema: &'a ValidatorSchema, - mode: ValidationMode, - policy_id: PolicyID, - ) -> Typechecker<'a> { - // Set the extensions using `all_available_extension_schemas`. - let extensions = ExtensionSchemas::all_available(); + /// Construct a new typechecker. All extensions are enabled by default. + pub fn new(schema: &'a ValidatorSchema, mode: ValidationMode) -> Typechecker<'a> { Self { schema, - extensions, + extensions: ExtensionSchemas::all_available(), mode, - policy_id, } } @@ -144,6 +136,10 @@ impl<'a> Typechecker<'a> { /// typechecks it under every schema-defined request environment. The result contains /// these environments and the individual typechecking response for each, in no /// particular order. + /// + /// Callers using this as the toplevel entry point, rather than + /// `typecheck_policy()`, will not get `impossible_policy` validation + /// warnings. pub fn typecheck_by_request_env<'b>( &'b self, t: &'b Template, @@ -161,15 +157,25 @@ impl<'a> Typechecker<'a> { /// efficiently than calling `typecheck_by_request_env()` multiple times. /// /// The `Loc` of each policy is also returned, for error reporting purposes. + /// + /// Callers using this as the toplevel entry point, rather than + /// `typecheck_policy()`, will not get `impossible_policy` validation + /// warnings. pub fn typecheck_multi_by_request_env<'b>( &'b self, ts: impl IntoIterator, ) -> HashMap, PolicyCheck)>, Option)> { - self.apply_typecheck_fn_by_request_env(ts, |request, expr| { + self.apply_typecheck_fn_by_request_env(ts, |request_env, policy_id, expr| { let mut type_errors = Vec::new(); + let single_env_typechecker = SingleEnvTypechecker { + schema: self.schema, + extensions: self.extensions, + mode: self.mode, + policy_id, + request_env, + }; let empty_prior_capability = CapabilitySet::new(); - let ans = self.expect_type( - request, + let ans = single_env_typechecker.expect_type( &empty_prior_capability, expr, Type::primitive_boolean(), @@ -202,7 +208,7 @@ impl<'a> Typechecker<'a> { typecheck_fn: F, ) -> HashMap, C)>, Option)> where - F: Fn(&RequestEnv<'b>, &Expr) -> C, + F: Fn(&RequestEnv<'b>, &PolicyID, &Expr) -> C, { let mut ret = HashMap::new(); @@ -227,7 +233,7 @@ impl<'a> Typechecker<'a> { .expect("already inserted this key above") .0 .extend(self.link_request_env(&unlinked_e, t).map(|linked_e| { - let check = typecheck_fn(&linked_e, cond); + let check = typecheck_fn(&linked_e, t.id(), cond); (linked_e, check) })); } @@ -357,16 +363,28 @@ impl<'a> Typechecker<'a> { Box::new(std::iter::once(None)) } } +} + +/// Struct which implements typechecking for policies within a single request +/// env. +struct SingleEnvTypechecker<'a> { + schema: &'a ValidatorSchema, + extensions: &'a ExtensionSchemas<'a>, + mode: ValidationMode, + /// ID of the policy we're typechecking; used for associating any validation + /// errors with the correct policy ID + policy_id: &'a PolicyID, + /// The single env which we're performing typechecking for + request_env: &'a RequestEnv<'a>, +} - /// This method handles the majority of the work. Given an expression, - /// the type for the request, and the prior capability, return the result of - /// typechecking the expression, and add any errors encountered into the - /// `type_errors` list. The result of typechecking contains the type of the - /// expression, any resulting capability after the expression, and a flag - /// indicating whether the expression successfully typechecked. +impl<'a> SingleEnvTypechecker<'a> { + /// This method handles the majority of the work. Given an expression, and + /// the prior capability, return the result of typechecking the expression + /// in the single env this typechecker was constructed for, and add any + /// errors encountered into the `type_errors` list. fn typecheck<'b>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, e: &'b Expr, type_errors: &mut Vec, @@ -380,7 +398,7 @@ impl<'a> Typechecker<'a> { // Principal, resource, and context have types defined by // the request type. ExprKind::Var(Var::Principal) => TypecheckAnswer::success( - ExprBuilder::with_data(Some(request_env.principal_type())) + ExprBuilder::with_data(Some(self.request_env.principal_type())) .with_same_source_loc(e) .var(Var::Principal), ), @@ -389,7 +407,7 @@ impl<'a> Typechecker<'a> { // entity type), so the type of Action is only the entity type name // taken from the euid. ExprKind::Var(Var::Action) => { - match request_env.action_type(self.schema) { + match self.request_env.action_type(self.schema) { Some(ty) => TypecheckAnswer::success( ExprBuilder::with_data(Some(ty)) .with_same_source_loc(e) @@ -407,12 +425,12 @@ impl<'a> Typechecker<'a> { } } ExprKind::Var(Var::Resource) => TypecheckAnswer::success( - ExprBuilder::with_data(Some(request_env.resource_type())) + ExprBuilder::with_data(Some(self.request_env.resource_type())) .with_same_source_loc(e) .var(Var::Resource), ), ExprKind::Var(Var::Context) => TypecheckAnswer::success( - ExprBuilder::with_data(Some(request_env.context_type())) + ExprBuilder::with_data(Some(self.request_env.context_type())) .with_same_source_loc(e) .var(Var::Context), ), @@ -422,13 +440,13 @@ impl<'a> Typechecker<'a> { // Template Slots, always has to be an entity. ExprKind::Slot(slotid) => TypecheckAnswer::success( ExprBuilder::with_data(Some(if slotid.is_principal() { - request_env + self.request_env .principal_slot() .clone() .map(Type::named_entity_reference) .unwrap_or_else(Type::any_entity_reference) } else if slotid.is_resource() { - request_env + self.request_env .resource_slot() .clone() .map(Type::named_entity_reference) @@ -495,7 +513,6 @@ impl<'a> Typechecker<'a> { } => { // The guard expression must be boolean. let ans_test = self.expect_type( - request_env, prior_capability, test_expr, Type::primitive_boolean(), @@ -511,7 +528,6 @@ impl<'a> Typechecker<'a> { // by `test`. This enables an attribute access // `principal.foo` after a condition `principal has foo`. let ans_then = self.typecheck( - request_env, &prior_capability.union(&test_capability), then_expr, type_errors, @@ -530,8 +546,7 @@ impl<'a> Typechecker<'a> { // we know in the `else` branch that the condition // evaluated to `false`. It still can use the original // prior capability. - let ans_else = - self.typecheck(request_env, prior_capability, else_expr, type_errors); + let ans_else = self.typecheck(prior_capability, else_expr, type_errors); ans_else.then_typecheck(|typ_else, else_capability| { TypecheckAnswer::success_with_capability(typ_else, else_capability) @@ -542,14 +557,12 @@ impl<'a> Typechecker<'a> { // prior capability are in their individual cases. let ans_then = self .typecheck( - request_env, &prior_capability.union(&test_capability), then_expr, type_errors, ) .map_capability(|capability| capability.union(&test_capability)); - let ans_else = - self.typecheck(request_env, prior_capability, else_expr, type_errors); + let ans_else = self.typecheck(prior_capability, else_expr, type_errors); // The type of the if expression is then the least // upper bound of the types of the then and else // branches. If either of these fails to typecheck, the @@ -592,7 +605,6 @@ impl<'a> Typechecker<'a> { ExprKind::And { left, right } => { let ans_left = self.expect_type( - request_env, prior_capability, left, Type::primitive_boolean(), @@ -621,7 +633,6 @@ impl<'a> Typechecker<'a> { // the right will only be evaluated after the left // evaluated to `true`. let ans_right = self.expect_type( - request_env, &prior_capability.union(&capability_left), right, Type::primitive_boolean(), @@ -690,7 +701,6 @@ impl<'a> Typechecker<'a> { // capability propagation adjusted as necessary. ExprKind::Or { left, right } => { let ans_left = self.expect_type( - request_env, prior_capability, left, Type::primitive_boolean(), @@ -712,7 +722,6 @@ impl<'a> Typechecker<'a> { // left could have evaluated to either `true` or `false` // when the left is evaluated. let ans_right = self.expect_type( - request_env, prior_capability, right, Type::primitive_boolean(), @@ -778,22 +787,21 @@ impl<'a> Typechecker<'a> { ExprKind::UnaryApp { .. } => { // INVARIANT: typecheck_unary requires a `UnaryApp`, we've just ensured this - self.typecheck_unary(request_env, prior_capability, e, type_errors) + self.typecheck_unary(prior_capability, e, type_errors) } ExprKind::BinaryApp { .. } => { // INVARIANT: typecheck_binary requires a `BinaryApp`, we've just ensured this - self.typecheck_binary(request_env, prior_capability, e, type_errors) + self.typecheck_binary(prior_capability, e, type_errors) } ExprKind::ExtensionFunctionApp { .. } => { // INVARIANT: typecheck_extension requires a `ExtensionFunctionApp`, we've just ensured this - self.typecheck_extension(request_env, prior_capability, e, type_errors) + self.typecheck_extension(prior_capability, e, type_errors) } ExprKind::GetAttr { expr, attr } => { // Accessing an attribute requires either an entity or a record // that has the attribute. let actual = self.expect_one_of_types( - request_env, prior_capability, expr, &[Type::any_entity_reference(), Type::any_record()], @@ -830,7 +838,7 @@ impl<'a> Typechecker<'a> { e.source_loc().cloned(), self.policy_id.clone(), AttributeAccess::from_expr( - request_env, + self.request_env, &typ_expr_actual, attr.clone(), ), @@ -860,7 +868,7 @@ impl<'a> Typechecker<'a> { e.source_loc().cloned(), self.policy_id.clone(), AttributeAccess::from_expr( - request_env, + self.request_env, &typ_expr_actual, attr.clone(), ), @@ -882,7 +890,6 @@ impl<'a> Typechecker<'a> { ExprKind::HasAttr { expr, attr } => { // `has` applies to an entity or a record let actual = self.expect_one_of_types( - request_env, prior_capability, expr, &[Type::any_entity_reference(), Type::any_record()], @@ -992,7 +999,6 @@ impl<'a> Typechecker<'a> { ExprKind::Like { expr, pattern } => { // `like` applies to a string let actual = self.expect_type( - request_env, prior_capability, expr, Type::primitive_string(), @@ -1017,7 +1023,6 @@ impl<'a> Typechecker<'a> { ExprKind::Is { expr, entity_type } => { self.expect_type( - request_env, prior_capability, expr, Type::any_entity_reference(), @@ -1091,7 +1096,7 @@ impl<'a> Typechecker<'a> { ExprKind::Set(exprs) => { let elem_types = exprs .iter() - .map(|elem| self.typecheck(request_env, prior_capability, elem, type_errors)) + .map(|elem| self.typecheck(prior_capability, elem, type_errors)) .collect::>(); // If we cannot compute a least upper bound for the element @@ -1140,7 +1145,7 @@ impl<'a> Typechecker<'a> { // Typecheck each attribute initializer expression individually. let record_attr_tys = map .values() - .map(|value| self.typecheck(request_env, prior_capability, value, type_errors)); + .map(|value| self.typecheck(prior_capability, value, type_errors)); // This will cause the return value to be `TypecheckFail` if any // of the attributes did not typecheck. TypecheckAnswer::sequence_all_then_typecheck( @@ -1190,7 +1195,6 @@ impl<'a> Typechecker<'a> { /// INVARIANT: `bin_expr` must be a `BinaryApp` fn typecheck_binary<'b>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, bin_expr: &'b Expr, type_errors: &mut Vec, @@ -1205,12 +1209,11 @@ impl<'a> Typechecker<'a> { // The arguments to `==` may typecheck with any type, but we will // return false if the types are disjoint. BinaryOp::Eq => { - let lhs_ty = self.typecheck(request_env, prior_capability, arg1, type_errors); - let rhs_ty = self.typecheck(request_env, prior_capability, arg2, type_errors); + let lhs_ty = self.typecheck(prior_capability, arg1, type_errors); + let rhs_ty = self.typecheck(prior_capability, arg2, type_errors); lhs_ty.then_typecheck(|lhs_ty, _| { rhs_ty.then_typecheck(|rhs_ty, _| { let type_of_eq = self.type_of_equality( - request_env, arg1, lhs_ty.data().as_ref(), arg2, @@ -1246,9 +1249,9 @@ impl<'a> Typechecker<'a> { .map(Type::extension) .chain(std::iter::once(Type::primitive_long())) .collect_vec(); - let ans_arg1 = self.typecheck(request_env, prior_capability, arg1, type_errors); + let ans_arg1 = self.typecheck(prior_capability, arg1, type_errors); ans_arg1.then_typecheck(|expr_ty_arg1, _| { - let ans_arg2 = self.typecheck(request_env, prior_capability, arg2, type_errors); + let ans_arg2 = self.typecheck(prior_capability, arg2, type_errors); ans_arg2.then_typecheck(|expr_ty_arg2, _| { let expr = ExprBuilder::with_data(Some(Type::primitive_boolean())) .with_same_source_loc(bin_expr) @@ -1373,7 +1376,6 @@ impl<'a> Typechecker<'a> { _ => None, }; let ans_arg1 = self.expect_type( - request_env, prior_capability, arg1, Type::primitive_long(), @@ -1382,7 +1384,6 @@ impl<'a> Typechecker<'a> { ); ans_arg1.then_typecheck(|expr_ty_arg1, _| { let ans_arg2 = self.expect_type( - request_env, prior_capability, arg2, Type::primitive_long(), @@ -1399,19 +1400,11 @@ impl<'a> Typechecker<'a> { }) } - BinaryOp::In => self.typecheck_in( - request_env, - prior_capability, - bin_expr, - arg1, - arg2, - type_errors, - ), + BinaryOp::In => self.typecheck_in(prior_capability, bin_expr, arg1, arg2, type_errors), BinaryOp::Contains => { // The first argument must be a set. self.expect_type( - request_env, prior_capability, arg1, Type::any_set(), @@ -1433,7 +1426,7 @@ impl<'a> Typechecker<'a> { ) .then_typecheck(|expr_ty_arg1, _| { // The second argument may be any type. We do not care if the element type cannot be in the set. - self.typecheck(request_env, prior_capability, arg2, type_errors) + self.typecheck(prior_capability, arg2, type_errors) .then_typecheck(|expr_ty_arg2, _| { if self.mode.is_strict() { let annotated_expr = @@ -1471,7 +1464,6 @@ impl<'a> Typechecker<'a> { BinaryOp::ContainsAll | BinaryOp::ContainsAny => { // Both arguments to a `containsAll` or `containsAny` must be sets. self.expect_type( - request_env, prior_capability, arg1, Type::any_set(), @@ -1492,14 +1484,9 @@ impl<'a> Typechecker<'a> { }, ) .then_typecheck(|expr_ty_arg1, _| { - self.expect_type( - request_env, - prior_capability, - arg2, - Type::any_set(), - type_errors, - |_| Some(UnexpectedTypeHelp::TryUsingSingleContains), - ) + self.expect_type(prior_capability, arg2, Type::any_set(), type_errors, |_| { + Some(UnexpectedTypeHelp::TryUsingSingleContains) + }) .then_typecheck(|expr_ty_arg2, _| { if self.mode.is_strict() { let annotated_expr = @@ -1527,7 +1514,6 @@ impl<'a> Typechecker<'a> { BinaryOp::HasTag => self .expect_type( - request_env, prior_capability, arg1, Type::any_entity_reference(), @@ -1536,7 +1522,6 @@ impl<'a> Typechecker<'a> { ) .then_typecheck(|expr_ty_arg1, _| { self.expect_type( - request_env, prior_capability, arg2, Type::primitive_string(), @@ -1600,7 +1585,6 @@ impl<'a> Typechecker<'a> { BinaryOp::GetTag => { self.expect_type( - request_env, prior_capability, arg1, Type::any_entity_reference(), @@ -1609,7 +1593,6 @@ impl<'a> Typechecker<'a> { ) .then_typecheck(|expr_ty_arg1, _| { self.expect_type( - request_env, prior_capability, arg2, Type::primitive_string(), @@ -1783,7 +1766,6 @@ impl<'a> Typechecker<'a> { /// Get the type for an `==` expression given the input types. fn type_of_equality<'b>( &self, - request_env: &RequestEnv<'_>, lhs_expr: &'b Expr, lhs_ty: Option<&Type>, rhs_expr: &'b Expr, @@ -1805,8 +1787,8 @@ impl<'a> Typechecker<'a> { // the action variable (which is converted into a literal euid // according to the binding in the request environment), then we // compare the euids on either side. - let lhs_euid = Typechecker::euid_from_euid_literal_or_action(request_env, lhs_expr); - let rhs_euid = Typechecker::euid_from_euid_literal_or_action(request_env, rhs_expr); + let lhs_euid = self.euid_from_euid_literal_or_action(lhs_expr); + let rhs_euid = self.euid_from_euid_literal_or_action(rhs_expr); if let (Some(lhs_euid), Some(rhs_euid)) = (lhs_euid, rhs_euid) { if lhs_euid == rhs_euid { // If lhs and rhs euid are the same, the equality has type `True`. @@ -1861,7 +1843,6 @@ impl<'a> Typechecker<'a> { /// type false, allowing for short circuiting in `if` and `and` expressions. fn typecheck_in<'b>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, in_expr: &Expr, lhs: &'b Expr, @@ -1871,7 +1852,6 @@ impl<'a> Typechecker<'a> { // First, the basic typechecking rules for `in` that apply regardless of // the syntactic special cases that follow. let ty_lhs = self.expect_type( - request_env, prior_capability, lhs, Type::any_entity_reference(), @@ -1879,7 +1859,6 @@ impl<'a> Typechecker<'a> { |_| Some(UnexpectedTypeHelp::TryUsingContains), ); let ty_rhs = self.expect_one_of_types( - request_env, prior_capability, rhs, &[ @@ -1912,8 +1891,8 @@ impl<'a> Typechecker<'a> { } let lhs_ty = lhs_expr.data().clone(); let rhs_ty = rhs_expr.data().clone(); - let lhs_as_euid_lit = Typechecker::replace_action_var_with_euid(request_env, lhs); - let rhs_as_euid_lit = Typechecker::replace_action_var_with_euid(request_env, rhs); + let lhs_as_euid_lit = self.replace_action_var_with_euid(lhs); + let rhs_as_euid_lit = self.replace_action_var_with_euid(rhs); match (lhs_as_euid_lit.expr_kind(), rhs_as_euid_lit.expr_kind()) { // var in EntityLiteral. Lookup the descendant types of the entity // literals. If the principal/resource type is not one of the @@ -1923,7 +1902,6 @@ impl<'a> Typechecker<'a> { ExprKind::Var(var @ (Var::Principal | Var::Resource)), ExprKind::Lit(Literal::EntityUID(_)), ) => self.type_of_var_in_entity_literals( - request_env, *var, [rhs_as_euid_lit.as_ref()], in_expr, @@ -1938,7 +1916,6 @@ impl<'a> Typechecker<'a> { ExprKind::Var(var @ (Var::Principal | Var::Resource)), ExprKind::Set(elems), ) => self.type_of_var_in_entity_literals( - request_env, *var, elems.as_ref(), in_expr, @@ -1956,7 +1933,6 @@ impl<'a> Typechecker<'a> { ExprKind::Lit(Literal::EntityUID(euid0)), ExprKind::Lit(Literal::EntityUID(_)), ) => self.type_of_entity_literal_in_entity_literals( - request_env, euid0, [rhs_as_euid_lit.as_ref()], in_expr, @@ -1967,7 +1943,6 @@ impl<'a> Typechecker<'a> { // As above, with the same complication, but applied to set of entities. (ExprKind::Lit(Literal::EntityUID(euid)), ExprKind::Set(elems)) => self .type_of_entity_literal_in_entity_literals( - request_env, euid, elems.as_ref(), in_expr, @@ -2069,14 +2044,8 @@ impl<'a> Typechecker<'a> { // Given an expression, if that expression is a literal or the `action` // variable, return it as an EntityUID. Return `None` otherwise. - fn euid_from_euid_literal_or_action( - request_env: &RequestEnv<'_>, - e: &Expr, - ) -> Option { - match Typechecker::replace_action_var_with_euid(request_env, e) - .as_ref() - .expr_kind() - { + fn euid_from_euid_literal_or_action(&self, e: &Expr) -> Option { + match self.replace_action_var_with_euid(e).expr_kind() { ExprKind::Lit(Literal::EntityUID(e)) => Some((**e).clone()), _ => None, } @@ -2086,12 +2055,12 @@ impl<'a> Typechecker<'a> { // extracted by `euid_from_uid_literal_or_action`. Return `None` if any // cannot be converted. fn euids_from_euid_literals_or_action<'b>( - request_env: &RequestEnv<'_>, + &self, exprs: impl IntoIterator, ) -> Option> { exprs .into_iter() - .map(|e| Self::euid_from_euid_literal_or_action(request_env, e)) + .map(|e| self.euid_from_euid_literal_or_action(e)) .collect::>>() } @@ -2099,18 +2068,17 @@ impl<'a> Typechecker<'a> { /// entity literal or set of entity literals. fn type_of_var_in_entity_literals<'b, 'c>( &self, - request_env: &RequestEnv<'_>, lhs_var: Var, rhs_elems: impl IntoIterator, in_expr: &Expr, lhs_expr: Expr>, rhs_expr: Expr>, ) -> TypecheckAnswer<'c> { - if let Some(rhs) = Typechecker::euids_from_euid_literals_or_action(request_env, rhs_elems) { + if let Some(rhs) = self.euids_from_euid_literals_or_action(rhs_elems) { let var_etype = if matches!(lhs_var, Var::Principal) { - request_env.principal_entity_type() + self.request_env.principal_entity_type() } else { - request_env.resource_entity_type() + self.request_env.resource_entity_type() }; match var_etype { None => { @@ -2136,7 +2104,7 @@ impl<'a> Typechecker<'a> { .all(|e| self.schema.euid_has_known_entity_type(e)); if self.schema.is_known_entity_type(var_name) && all_rhs_known { let descendants = self.schema.get_entity_types_in_set(rhs.iter()); - Typechecker::entity_in_descendants( + Self::entity_in_descendants( var_name, descendants, in_expr, @@ -2176,14 +2144,13 @@ impl<'a> Typechecker<'a> { fn type_of_entity_literal_in_entity_literals<'b, 'c>( &self, - request_env: &RequestEnv<'_>, lhs_euid: &EntityUID, rhs_elems: impl IntoIterator, in_expr: &Expr, lhs_expr: Expr>, rhs_expr: Expr>, ) -> TypecheckAnswer<'c> { - if let Some(rhs) = Typechecker::euids_from_euid_literals_or_action(request_env, rhs_elems) { + if let Some(rhs) = self.euids_from_euid_literals_or_action(rhs_elems) { let name = lhs_euid.entity_type(); // We don't want to apply the action hierarchy check to // non-action entities, but now we have a set of entities. @@ -2246,7 +2213,7 @@ impl<'a> Typechecker<'a> { ) -> TypecheckAnswer<'b> { let rhs_descendants = self.schema.get_actions_in_set(rhs); if let Some(rhs_descendants) = rhs_descendants { - Typechecker::entity_in_descendants(lhs, rhs_descendants, in_expr, lhs_expr, rhs_expr) + Self::entity_in_descendants(lhs, rhs_descendants, in_expr, lhs_expr, rhs_expr) } else { let annotated_expr = ExprBuilder::with_data(Some(Type::primitive_boolean())) .with_same_source_loc(in_expr) @@ -2264,10 +2231,10 @@ impl<'a> Typechecker<'a> { // based on the precise EUIDs when they're not actions, so we only look at // entity types. The type will be `False` is none of the entities on the rhs // have a type which may be an ancestor of the rhs entity type. - fn type_of_non_action_in_entities<'b>( + fn type_of_non_action_in_entities<'b, 'c>( &self, lhs: &EntityUID, - rhs: &[EntityUID], + rhs: &'c [EntityUID], in_expr: &Expr, lhs_expr: Expr>, rhs_expr: Expr>, @@ -2278,13 +2245,7 @@ impl<'a> Typechecker<'a> { .all(|e| self.schema.euid_has_known_entity_type(e)); if self.schema.is_known_entity_type(lhs_ety) && all_rhs_known { let rhs_descendants = self.schema.get_entity_types_in_set(rhs.iter()); - Typechecker::entity_in_descendants( - lhs_ety, - rhs_descendants, - in_expr, - lhs_expr, - rhs_expr, - ) + Self::entity_in_descendants(lhs_ety, rhs_descendants, in_expr, lhs_expr, rhs_expr) } else { let annotated_expr = ExprBuilder::with_data(Some(Type::primitive_boolean())) .with_same_source_loc(in_expr) @@ -2299,9 +2260,9 @@ impl<'a> Typechecker<'a> { /// Check if the entity is in the list of descendants. Return the singleton /// type false if it is not, and boolean otherwise. - fn entity_in_descendants<'b, K>( + fn entity_in_descendants<'b, 'c, K: 'c>( lhs_entity: &K, - rhs_descendants: impl IntoIterator, + rhs_descendants: impl IntoIterator, in_expr: &Expr, lhs_expr: Expr>, rhs_expr: Expr>, @@ -2326,7 +2287,6 @@ impl<'a> Typechecker<'a> { /// INVARIANT: `unary_expr` must be of kind `UnaryApp` fn typecheck_unary<'b>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, unary_expr: &'b Expr, type_errors: &mut Vec, @@ -2339,7 +2299,6 @@ impl<'a> Typechecker<'a> { match op { UnaryOp::Not => { let ans_arg = self.expect_type( - request_env, prior_capability, arg, Type::primitive_boolean(), @@ -2371,7 +2330,6 @@ impl<'a> Typechecker<'a> { } UnaryOp::Neg => { let ans_arg = self.expect_type( - request_env, prior_capability, arg, Type::primitive_long(), @@ -2388,7 +2346,6 @@ impl<'a> Typechecker<'a> { } UnaryOp::IsEmpty => { let ans_arg = self.expect_type( - request_env, prior_capability, arg, Type::any_set(), @@ -2416,7 +2373,6 @@ impl<'a> Typechecker<'a> { /// Return `TypecheckSuccess` with the type otherwise. fn expect_one_of_types<'b, F>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, expr: &'b Expr, expected: &[Type], @@ -2426,7 +2382,7 @@ impl<'a> Typechecker<'a> { where F: FnOnce(&Type) -> Option, { - let actual = self.typecheck(request_env, prior_capability, expr, type_errors); + let actual = self.typecheck(prior_capability, expr, type_errors); actual.then_typecheck(|mut typ_actual, capability| match typ_actual.data() { Some(actual_ty) => { if !expected.iter().any(|expected_ty| { @@ -2476,7 +2432,6 @@ impl<'a> Typechecker<'a> { /// type. fn expect_type<'b, F>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, expr: &'b Expr, expected: Type, @@ -2487,7 +2442,6 @@ impl<'a> Typechecker<'a> { F: FnOnce(&Type) -> Option, { self.expect_one_of_types( - request_env, prior_capability, expr, &[expected], @@ -2539,12 +2493,9 @@ impl<'a> Typechecker<'a> { /// If the `maybe_action_var` expression is `Expr::Var(Var::Action)`, return /// a expression for the entity uid for the action variable in the request /// environment. Otherwise, return the expression unchanged. - fn replace_action_var_with_euid( - request_env: &RequestEnv<'_>, - maybe_action_var: &'a Expr, - ) -> Cow<'a, Expr> { + fn replace_action_var_with_euid(&self, maybe_action_var: &'a Expr) -> Cow<'a, Expr> { match maybe_action_var.expr_kind() { - ExprKind::Var(Var::Action) => match request_env.action_entity_uid() { + ExprKind::Var(Var::Action) => match self.request_env.action_entity_uid() { Some(action) => Cow::Owned(Expr::val(action.clone())), None => Cow::Borrowed(maybe_action_var), }, @@ -2572,7 +2523,6 @@ impl<'a> Typechecker<'a> { /// INVARIANT `ext_expr` must be a `ExtensionFunctionApp` fn typecheck_extension<'b>( &self, - request_env: &RequestEnv<'_>, prior_capability: &CapabilitySet<'b>, ext_expr: &'b Expr, type_errors: &mut Vec, @@ -2586,7 +2536,7 @@ impl<'a> Typechecker<'a> { let typed_arg_exprs = |type_errors: &mut Vec| { args.iter() .map(|arg| { - self.typecheck(request_env, prior_capability, arg, type_errors) + self.typecheck(prior_capability, arg, type_errors) .into_typed_expr() }) .collect::>>() @@ -2639,14 +2589,7 @@ impl<'a> Typechecker<'a> { } } else { let typechecked_args = zip(args.as_ref(), arg_tys).map(|(arg, ty)| { - self.expect_type( - request_env, - prior_capability, - arg, - ty.clone(), - type_errors, - |_| None, - ) + self.expect_type(prior_capability, arg, ty.clone(), type_errors, |_| None) }); TypecheckAnswer::sequence_all_then_typecheck( typechecked_args, diff --git a/cedar-policy-validator/src/typecheck/test/partial.rs b/cedar-policy-validator/src/typecheck/test/partial.rs index eeda516cf..aa1f896c9 100644 --- a/cedar-policy-validator/src/typecheck/test/partial.rs +++ b/cedar-policy-validator/src/typecheck/test/partial.rs @@ -35,7 +35,7 @@ pub(crate) fn assert_partial_typecheck( policy: StaticPolicy, ) { let schema = schema.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Partial, policy.id().clone()); + let typechecker = Typechecker::new(&schema, ValidationMode::Partial); let mut errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy( @@ -54,7 +54,7 @@ pub(crate) fn assert_partial_typecheck_fails( expected_errors: impl IntoIterator, ) { let schema = schema.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Partial, policy.id().clone()); + let typechecker = Typechecker::new(&schema, ValidationMode::Partial); let mut errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy( @@ -73,7 +73,7 @@ pub(crate) fn assert_partial_typecheck_warns( expected_warnings: impl IntoIterator, ) { let schema = schema.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Partial, policy.id().clone()); + let typechecker = Typechecker::new(&schema, ValidationMode::Partial); let mut errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy( diff --git a/cedar-policy-validator/src/typecheck/test/policy.rs b/cedar-policy-validator/src/typecheck/test/policy.rs index 92b7510ad..3719be271 100644 --- a/cedar-policy-validator/src/typecheck/test/policy.rs +++ b/cedar-policy-validator/src/typecheck/test/policy.rs @@ -135,11 +135,7 @@ fn policy_checked_in_multiple_envs() { let schema = simple_schema_file() .try_into() .expect("Failed to construct schema."); - let typechecker = Typechecker::new( - &schema, - ValidationMode::default(), - PolicyID::from_string("0"), - ); + let typechecker = Typechecker::new(&schema, ValidationMode::default()); let env_checks = typechecker.typecheck_by_request_env(&t); // There are 3 possible envs in schema: // - User, "view_photo", Photo @@ -162,11 +158,7 @@ fn policy_checked_in_multiple_envs() { let schema = simple_schema_file() .try_into() .expect("Failed to construct schema."); - let typechecker = Typechecker::new( - &schema, - ValidationMode::default(), - PolicyID::from_string("0"), - ); + let typechecker = Typechecker::new(&schema, ValidationMode::default()); let env_checks = typechecker.typecheck_by_request_env(&t); // With the new action, policy is always false for the other two assert!( @@ -1050,7 +1042,7 @@ fn extended_has() { y?: { z?: Long, } - } + } }; action "action" appliesTo { diff --git a/cedar-policy-validator/src/typecheck/test/strict.rs b/cedar-policy-validator/src/typecheck/test/strict.rs index 8ec6178ea..3ae934d4a 100644 --- a/cedar-policy-validator/src/typecheck/test/strict.rs +++ b/cedar-policy-validator/src/typecheck/test/strict.rs @@ -29,11 +29,11 @@ use cedar_policy_core::{ }; use crate::{ + extensions::ExtensionSchemas, json_schema, - typecheck::Typechecker, + typecheck::SingleEnvTypechecker, types::{AttributeType, CapabilitySet, OpenTag, RequestEnv, Type}, - validation_errors::LubContext, - validation_errors::LubHelp, + validation_errors::{LubContext, LubHelp}, RawName, ValidationError, ValidationMode, }; @@ -44,21 +44,23 @@ use super::test_utils::{ #[track_caller] // report the caller's location as the location of the panic, not the location in this function fn assert_typechecks_strict( schema: json_schema::Fragment, - env: &RequestEnv<'_>, + request_env: &RequestEnv<'_>, e: Expr, expected_type: Type, ) { let schema = schema.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Strict, expr_id_placeholder()); + let typechecker = SingleEnvTypechecker { + schema: &schema, + extensions: ExtensionSchemas::all_available(), + mode: ValidationMode::Strict, + policy_id: &expr_id_placeholder(), + request_env, + }; let mut errs = Vec::new(); - let answer = typechecker.expect_type( - env, - &CapabilitySet::new(), - &e, - expected_type, - &mut errs, - |_| None, - ); + let answer = + typechecker.expect_type(&CapabilitySet::new(), &e, expected_type, &mut errs, |_| { + None + }); assert_eq!(errs, vec![], "Expression should not contain any errors."); assert_matches!( @@ -70,22 +72,24 @@ fn assert_typechecks_strict( #[track_caller] // report the caller's location as the location of the panic, not the location in this function fn assert_strict_type_error( schema: json_schema::Fragment, - env: &RequestEnv<'_>, + request_env: &RequestEnv<'_>, e: Expr, expected_type: Type, expected_error: ValidationError, ) { let schema = schema.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Strict, expr_id_placeholder()); + let typechecker = SingleEnvTypechecker { + schema: &schema, + extensions: ExtensionSchemas::all_available(), + mode: ValidationMode::Strict, + policy_id: &expr_id_placeholder(), + request_env, + }; let mut errs = Vec::new(); - let answer = typechecker.expect_type( - env, - &CapabilitySet::new(), - &e, - expected_type, - &mut errs, - |_| None, - ); + let answer = + typechecker.expect_type(&CapabilitySet::new(), &e, expected_type, &mut errs, |_| { + None + }); assert_eq!(errs.into_iter().collect::>(), vec![expected_error]); assert_matches!( @@ -168,10 +172,15 @@ where fn strict_typecheck_catches_regular_type_error() { with_simple_schema_and_request(|s, q| { let schema = s.try_into().expect("Failed to construct schema."); - let typechecker = Typechecker::new(&schema, ValidationMode::Strict, expr_id_placeholder()); + let typechecker = SingleEnvTypechecker { + schema: &schema, + extensions: ExtensionSchemas::all_available(), + mode: ValidationMode::Strict, + policy_id: &expr_id_placeholder(), + request_env: &q, + }; let mut errs = Vec::new(); typechecker.expect_type( - &q, &CapabilitySet::new(), &Expr::from_str("1 + false").unwrap(), Type::primitive_long(), diff --git a/cedar-policy-validator/src/typecheck/test/test_utils.rs b/cedar-policy-validator/src/typecheck/test/test_utils.rs index db94a61b6..9768bdbfb 100644 --- a/cedar-policy-validator/src/typecheck/test/test_utils.rs +++ b/cedar-policy-validator/src/typecheck/test/test_utils.rs @@ -28,7 +28,7 @@ use cedar_policy_core::parser::Loc; use crate::{ json_schema, - typecheck::{TypecheckAnswer, Typechecker}, + typecheck::{SingleEnvTypechecker, TypecheckAnswer, Typechecker}, types::{CapabilitySet, OpenTag, RequestEnv, Type}, validation_errors::UnexpectedTypeHelp, NamespaceDefinitionWithActionAttributes, RawName, ValidationError, ValidationMode, @@ -82,9 +82,13 @@ impl Type { impl Typechecker<'_> { /// Typecheck an expression outside the context of a policy. This is /// currently only used for testing. + /// + /// `policy_id`: Policy ID to associate with this `Expr`, for the purposes + /// of reporting the policy ID in validation errors pub(crate) fn typecheck_expr<'a>( &self, e: &'a Expr, + policy_id: &'a PolicyID, unique_type_errors: &mut HashSet, ) -> TypecheckAnswer<'a> { // Using bogus entity type names here for testing. They'll be treated as @@ -102,8 +106,15 @@ impl Typechecker<'_> { principal_slot: None, resource_slot: None, }; + let typechecker = SingleEnvTypechecker { + schema: self.schema, + extensions: self.extensions, + mode: self.mode, + policy_id, + request_env: &request_env, + }; let mut type_errors = Vec::new(); - let ans = self.typecheck(&request_env, &CapabilitySet::new(), e, &mut type_errors); + let ans = typechecker.typecheck(&CapabilitySet::new(), e, &mut type_errors); unique_type_errors.extend(type_errors); ans } @@ -191,7 +202,7 @@ pub(crate) fn assert_policy_typechecks_for_mode( ) { let policy = policy.into(); let schema = schema.schema(); - let mut typechecker = Typechecker::new(&schema, mode, expr_id_placeholder()); + let mut typechecker = Typechecker::new(&schema, mode); let mut type_errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy(&policy, &mut type_errors, &mut warnings); @@ -261,7 +272,7 @@ pub(crate) fn assert_policy_typecheck_fails_for_mode( ) -> HashSet { let policy = policy.into(); let schema = schema.schema(); - let typechecker = Typechecker::new(&schema, mode, policy.id().clone()); + let typechecker = Typechecker::new(&schema, mode); let mut type_errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy(&policy, &mut type_errors, &mut warnings); @@ -279,7 +290,7 @@ pub(crate) fn assert_policy_typecheck_warns_for_mode( ) -> HashSet { let policy = policy.into(); let schema = schema.schema(); - let typechecker = Typechecker::new(&schema, mode, policy.id().clone()); + let typechecker = Typechecker::new(&schema, mode); let mut type_errors: HashSet = HashSet::new(); let mut warnings: HashSet = HashSet::new(); let typechecked = typechecker.typecheck_policy(&policy, &mut type_errors, &mut warnings); @@ -309,9 +320,10 @@ pub(crate) fn assert_typechecks_for_mode( mode: ValidationMode, ) { let schema = schema.schema(); - let typechecker = Typechecker::new(&schema, mode, expr_id_placeholder()); + let typechecker = Typechecker::new(&schema, mode); let mut type_errors = HashSet::new(); - let actual = typechecker.typecheck_expr(&expr, &mut type_errors); + let pid = expr_id_placeholder(); + let actual = typechecker.typecheck_expr(&expr, &pid, &mut type_errors); assert_matches!(actual, TypecheckAnswer::TypecheckSuccess { expr_type, .. } => { assert_types_eq(typechecker.schema, &expected, &expr_type.into_data().expect("Typechecked expression must have type")); }); @@ -352,9 +364,10 @@ pub(crate) fn assert_typecheck_fails_for_mode( mode: ValidationMode, ) -> HashSet { let schema = schema.schema(); - let typechecker = Typechecker::new(&schema, mode, expr_id_placeholder()); + let typechecker = Typechecker::new(&schema, mode); let mut type_errors = HashSet::new(); - let actual = typechecker.typecheck_expr(&expr, &mut type_errors); + let pid = expr_id_placeholder(); + let actual = typechecker.typecheck_expr(&expr, &pid, &mut type_errors); assert_matches!(actual, TypecheckAnswer::TypecheckFail { expr_recovery_type } => { match (expected_ty.as_ref(), expr_recovery_type.data()) { (None, None) => (), diff --git a/cedar-policy-validator/src/typecheck/test/type_annotation.rs b/cedar-policy-validator/src/typecheck/test/type_annotation.rs index 19d798945..752205cd6 100644 --- a/cedar-policy-validator/src/typecheck/test/type_annotation.rs +++ b/cedar-policy-validator/src/typecheck/test/type_annotation.rs @@ -31,13 +31,9 @@ fn assert_expr_has_annotated_ast(e: &Expr, annotated: &Expr>) { let schema = empty_schema_file() .try_into() .expect("Failed to construct schema."); - let typechecker = Typechecker::new( - &schema, - ValidationMode::default(), - PolicyID::from_string("0"), - ); + let typechecker = Typechecker::new(&schema, ValidationMode::default()); let mut errs = HashSet::new(); - assert_matches!(typechecker.typecheck_expr(e, &mut errs), crate::typecheck::TypecheckAnswer::TypecheckSuccess { expr_type, .. } => { + assert_matches!(typechecker.typecheck_expr(e, &PolicyID::from_string("0"), &mut errs), crate::typecheck::TypecheckAnswer::TypecheckSuccess { expr_type, .. } => { assert_eq!(&expr_type, annotated); }); } @@ -146,10 +142,10 @@ fn expr_typechecks_with_correct_annotation() { .unwrap() .try_into() .expect("Failed to construct schema."); - let tc = Typechecker::new(&schema, ValidationMode::default(), expr_id_placeholder()); + let tc = Typechecker::new(&schema, ValidationMode::default()); let mut errs = HashSet::new(); let euid = EntityUID::with_eid_and_type("Foo", "bar").unwrap(); - match tc.typecheck_expr(&Expr::val(euid.clone()), &mut errs) { + match tc.typecheck_expr(&Expr::val(euid.clone()), &expr_id_placeholder(), &mut errs) { crate::typecheck::TypecheckAnswer::TypecheckSuccess { expr_type, .. } => { assert_eq!( &expr_type, diff --git a/cedar-policy/src/api.rs b/cedar-policy/src/api.rs index ea84298fd..69df9eedf 100644 --- a/cedar-policy/src/api.rs +++ b/cedar-policy/src/api.rs @@ -2668,11 +2668,7 @@ impl RequestEnv { // This function is called by [`Template::get_valid_request_envs`] and // [`Policy::get_valid_request_envs`] fn get_valid_request_envs(ast: &ast::Template, s: &Schema) -> impl Iterator { - let tc = Typechecker::new( - &s.0, - cedar_policy_validator::ValidationMode::default(), - ast.id().clone(), - ); + let tc = Typechecker::new(&s.0, cedar_policy_validator::ValidationMode::default()); tc.typecheck_by_request_env(ast) .into_iter() .filter_map(|(env, pc)| { From 8976579daa557480e8ae8a029ca2d20e5cc3a5c3 Mon Sep 17 00:00:00 2001 From: B-Lorentz <44694582+B-Lorentz@users.noreply.github.com> Date: Tue, 14 Jan 2025 15:07:34 +0100 Subject: [PATCH 6/7] Minor refactor to partial eval (#1417) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Lőrinc Bódy --- .../src/authorizer/partial_response.rs | 82 +++++++++---------- cedar-policy-core/src/evaluator.rs | 28 +++++-- 2 files changed, 62 insertions(+), 48 deletions(-) diff --git a/cedar-policy-core/src/authorizer/partial_response.rs b/cedar-policy-core/src/authorizer/partial_response.rs index fdf3b500c..8059babd0 100644 --- a/cedar-policy-core/src/authorizer/partial_response.rs +++ b/cedar-policy-core/src/authorizer/partial_response.rs @@ -337,23 +337,11 @@ impl PartialResponse { ) -> Result { let mut context = self.request.context.clone(); - let principal = if let Some((key, val)) = mapping.get_key_value("principal") { - self.request.principal().concretize(key, val)? - } else { - self.request.principal().clone() - }; + let principal = self.request.principal().concretize("principal", mapping)?; - let action = if let Some((key, val)) = mapping.get_key_value("action") { - self.request.action().concretize(key, val)? - } else { - self.request.action().clone() - }; + let action = self.request.action.concretize("action", mapping)?; - let resource = if let Some((key, val)) = mapping.get_key_value("resource") { - self.request.resource().concretize(key, val)? - } else { - self.request.resource().clone() - }; + let resource = self.request.resource.concretize("resource", mapping)?; if let Some((key, val)) = mapping.get_key_value("context") { if let Ok(attrs) = val.get_as_record() { @@ -406,38 +394,48 @@ impl PartialResponse { } impl EntityUIDEntry { - fn concretize(&self, key: &SmolStr, val: &Value) -> Result { - if let Ok(uid) = val.get_as_entity() { - match self { - EntityUIDEntry::Known { euid, .. } => Err(ConcretizationError::VarConfictError { - id: key.to_owned(), - existing_value: euid.as_ref().clone().into(), - given_value: val.clone(), - }), - EntityUIDEntry::Unknown { ty: None, .. } => { - Ok(EntityUIDEntry::known(uid.clone(), None)) - } - EntityUIDEntry::Unknown { - ty: Some(type_of_unknown), - .. - } => { - if type_of_unknown == uid.entity_type() { - Ok(EntityUIDEntry::known(uid.clone(), None)) - } else { - Err(ConcretizationError::EntityTypeConfictError { - id: key.to_owned(), - existing_value: type_of_unknown.clone(), - given_value: val.to_owned(), + fn concretize( + &self, + key: &str, + mapping: &HashMap, + ) -> Result { + if let Some(val) = mapping.get(key) { + if let Ok(uid) = val.get_as_entity() { + match self { + EntityUIDEntry::Known { euid, .. } => { + Err(ConcretizationError::VarConfictError { + id: key.into(), + existing_value: euid.as_ref().clone().into(), + given_value: val.clone(), }) } + EntityUIDEntry::Unknown { ty: None, .. } => { + Ok(EntityUIDEntry::known(uid.clone(), None)) + } + EntityUIDEntry::Unknown { + ty: Some(type_of_unknown), + .. + } => { + if type_of_unknown == uid.entity_type() { + Ok(EntityUIDEntry::known(uid.clone(), None)) + } else { + Err(ConcretizationError::EntityTypeConfictError { + id: key.into(), + existing_value: type_of_unknown.clone(), + given_value: val.to_owned(), + }) + } + } } + } else { + Err(ConcretizationError::ValueError { + id: key.into(), + expected_type: "entity", + given_value: val.to_owned(), + }) } } else { - Err(ConcretizationError::ValueError { - id: key.to_owned(), - expected_type: "entity", - given_value: val.to_owned(), - }) + Ok(self.clone()) } } } diff --git a/cedar-policy-core/src/evaluator.rs b/cedar-policy-core/src/evaluator.rs index 5eb158a53..a369cea65 100644 --- a/cedar-policy-core/src/evaluator.rs +++ b/cedar-policy-core/src/evaluator.rs @@ -380,20 +380,20 @@ impl<'e> Evaluator<'e> { // NOTE: There are more precise partial eval opportunities here, esp w/ typed unknowns // Current limitations: // Operators are not partially evaluated, except in a few 'simple' cases when comparing a concrete value with an unknown of known type - // implemented in short_circuit_typed_residual + // implemented in short_circuit_* let (arg1, arg2) = match ( self.partial_interpret(arg1, slots)?, self.partial_interpret(arg2, slots)?, ) { (PartialValue::Value(v1), PartialValue::Value(v2)) => (v1, v2), (PartialValue::Value(v1), PartialValue::Residual(e2)) => { - if let Some(val) = self.short_circuit_typed_residual(&v1, &e2, *op) { + if let Some(val) = self.short_circuit_value_and_residual(&v1, &e2, *op) { return Ok(val); } return Ok(PartialValue::Residual(Expr::binary_app(*op, v1.into(), e2))); } (PartialValue::Residual(e1), PartialValue::Value(v2)) => { - if let Some(val) = self.short_circuit_typed_residual(&v2, &e1, *op) { + if let Some(val) = self.short_circuit_residual_and_value(&e1, &v2, *op) { return Ok(val); } return Ok(PartialValue::Residual(Expr::binary_app(*op, e1, v2.into()))); @@ -922,10 +922,26 @@ impl<'e> Evaluator<'e> { } } - /// Evaluate a binary operation between a value and a residual expression. If despite the unknown contained in the residual, concrete result + /// Evaluate a binary operation between a residual expression (left) and a value (right). If despite the unknown contained in the residual, concrete result /// can be obtained (using the type annotation on the residual), it is returned. - /// Since it is not aware which of the inputs is on the left side, and which on the right, it needs to return None for all non-commutative operations. - fn short_circuit_typed_residual( + fn short_circuit_residual_and_value( + &self, + e1: &Expr, + v2: &Value, + op: BinaryOp, + ) -> Option { + match op { + // Since these operators are commutative, we can use just one order, and have one implementation of the actual logic + BinaryOp::Add | BinaryOp::Eq | BinaryOp::Mul | BinaryOp::ContainsAny => { + self.short_circuit_value_and_residual(v2, e1, op) + } + _ => None, + } + } + + /// Evaluate a binary operation between a value (left) and a residual expression (right). If despite the unknown contained in the residual, concrete result + /// can be obtained (using the type annotation on the residual), it is returned. + fn short_circuit_value_and_residual( &self, v1: &Value, e2: &Expr, From 875ba48e25eaa0617c681c942cb533266708b7c6 Mon Sep 17 00:00:00 2001 From: shaobo-he-aws <130499339+shaobo-he-aws@users.noreply.github.com> Date: Tue, 14 Jan 2025 08:46:00 -0800 Subject: [PATCH 7/7] Make datetime extension an experimental feature (#1415) Signed-off-by: Shaobo He Co-authored-by: John Kastner <130772734+john-h-kastner-aws@users.noreply.github.com> --- cedar-policy-core/Cargo.toml | 2 +- cedar-policy-core/src/ast/extension.rs | 36 +- cedar-policy-core/src/evaluator.rs | 320 ++++++--------- cedar-policy-core/src/extensions.rs | 21 +- cedar-policy-core/src/extensions/datetime.rs | 4 + cedar-policy-core/src/extensions/decimal.rs | 5 +- cedar-policy-core/src/extensions/ipaddr.rs | 5 +- .../src/extensions/partial_evaluation.rs | 1 + cedar-policy-validator/Cargo.toml | 2 +- .../src/extension_schema.rs | 11 + cedar-policy-validator/src/extensions.rs | 27 +- .../src/extensions/datetime.rs | 9 +- .../src/extensions/decimal.rs | 2 +- .../src/extensions/ipaddr.rs | 2 +- .../src/extensions/partial_evaluation.rs | 2 +- cedar-policy-validator/src/schema.rs | 51 +-- cedar-policy-validator/src/typecheck.rs | 50 ++- .../src/typecheck/test/expr.rs | 363 +++++++++++------- .../src/typecheck/test/partial.rs | 13 +- cedar-policy-validator/src/types.rs | 9 - cedar-policy/CHANGELOG.md | 1 + cedar-policy/Cargo.toml | 3 +- 22 files changed, 507 insertions(+), 432 deletions(-) diff --git a/cedar-policy-core/Cargo.toml b/cedar-policy-core/Cargo.toml index 0b25a5fff..2870caab5 100644 --- a/cedar-policy-core/Cargo.toml +++ b/cedar-policy-core/Cargo.toml @@ -46,7 +46,7 @@ prost = { version = "0.13", optional = true } [features] # by default, enable all Cedar extensions -default = ["ipaddr", "decimal", "datetime"] +default = ["ipaddr", "decimal"] ipaddr = [] decimal = ["dep:regex"] datetime = ["dep:chrono", "dep:regex"] diff --git a/cedar-policy-core/src/ast/extension.rs b/cedar-policy-core/src/ast/extension.rs index 559af3f28..d8161ca89 100644 --- a/cedar-policy-core/src/ast/extension.rs +++ b/cedar-policy-core/src/ast/extension.rs @@ -14,35 +14,15 @@ * limitations under the License. */ -pub use names::TYPES_WITH_OPERATOR_OVERLOADING; - use crate::ast::*; use crate::entities::SchemaType; use crate::evaluator; use std::any::Any; -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use std::fmt::Debug; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::sync::Arc; -// PANIC SAFETY: `Name`s in here are valid `Name`s -#[allow(clippy::expect_used)] -mod names { - use std::collections::BTreeSet; - - use super::Name; - - lazy_static::lazy_static! { - /// Extension type names that support operator overloading - // INVARIANT: this set must not be empty. - pub static ref TYPES_WITH_OPERATOR_OVERLOADING : BTreeSet = - BTreeSet::from_iter( - [Name::parse_unqualified_name("datetime").expect("valid identifier"), - Name::parse_unqualified_name("duration").expect("valid identifier")] - ); - } -} - /// Cedar extension. /// /// An extension can define new types and functions on those types. (Currently, @@ -54,14 +34,21 @@ pub struct Extension { name: Name, /// Extension functions. These are legal to call in Cedar expressions. functions: HashMap, + /// Types with operator overloading + types_with_operator_overloading: BTreeSet, } impl Extension { /// Create a new `Extension` with the given name and extension functions - pub fn new(name: Name, functions: impl IntoIterator) -> Self { + pub fn new( + name: Name, + functions: impl IntoIterator, + types_with_operator_overloading: impl IntoIterator, + ) -> Self { Self { name, functions: functions.into_iter().map(|f| (f.name.clone(), f)).collect(), + types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(), } } @@ -86,6 +73,11 @@ impl Extension { pub fn ext_types(&self) -> impl Iterator + '_ { self.funcs().flat_map(|func| func.ext_types()) } + + /// Iterate over extension types with operator overloading + pub fn types_with_operator_overloading(&self) -> impl Iterator + '_ { + self.types_with_operator_overloading.iter() + } } impl std::fmt::Debug for Extension { diff --git a/cedar-policy-core/src/evaluator.rs b/cedar-policy-core/src/evaluator.rs index a369cea65..eb73c6568 100644 --- a/cedar-policy-core/src/evaluator.rs +++ b/cedar-policy-core/src/evaluator.rs @@ -30,7 +30,7 @@ pub use err::evaluation_errors; pub use err::EvaluationError; pub(crate) use err::*; use evaluation_errors::*; -use itertools::Either; +use itertools::{Either, Itertools}; use nonempty::nonempty; use smol_str::SmolStr; @@ -185,6 +185,16 @@ impl<'e> RestrictedEvaluator<'e> { } } +pub(crate) fn valid_comparison_op_types(extensions: &Extensions<'_>) -> nonempty::NonEmpty { + let mut expected_types = nonempty::NonEmpty::singleton(Type::Long); + expected_types.extend( + extensions + .types_with_operator_overloading() + .map(|n| Type::Extension { name: n.clone() }), + ); + expected_types +} + impl<'e> Evaluator<'e> { /// Create a fresh `Evaluator` for the given `request`, which uses the given /// `Entities` to resolve entity references. Use the given `Extension`s when @@ -432,15 +442,38 @@ impl<'e> Evaluator<'e> { Ok(ext_op(x, y).into()) } // throw type errors - (ValueKind::Lit(Literal::Long(_)), _) => Err(EvaluationError::type_error_single(Type::Long, &arg2)), - (_, ValueKind::Lit(Literal::Long(_))) => Err(EvaluationError::type_error_single(Type::Long, &arg1)), - (ValueKind::ExtensionValue(x), _) if x.supports_operator_overloading() => Err(EvaluationError::type_error_single(Type::Extension { name: x.typename() }, &arg2)), - (_, ValueKind::ExtensionValue(y)) if y.supports_operator_overloading() => Err(EvaluationError::type_error_single(Type::Extension { name: y.typename() }, &arg1)), - (ValueKind::ExtensionValue(x), ValueKind::ExtensionValue(y)) if x.typename() == y.typename() => Err(EvaluationError::type_error_with_advice(Extensions::types_with_operator_overloading().map(|name| Type::Extension { name} ), &arg1, "Only extension types `datetime` and `duration` support operator overloading".to_string())), + (ValueKind::Lit(Literal::Long(_)), _) => { + Err(EvaluationError::type_error_single(Type::Long, &arg2)) + } + (_, ValueKind::Lit(Literal::Long(_))) => { + Err(EvaluationError::type_error_single(Type::Long, &arg1)) + } + (ValueKind::ExtensionValue(x), _) + if x.supports_operator_overloading() => + { + Err(EvaluationError::type_error_single( + Type::Extension { name: x.typename() }, + &arg2, + )) + } + (_, ValueKind::ExtensionValue(y)) + if y.supports_operator_overloading() => + { + Err(EvaluationError::type_error_single( + Type::Extension { name: y.typename() }, + &arg1, + )) + } _ => { - let mut expected_types = Extensions::types_with_operator_overloading().map(|name| Type::Extension { name }); - expected_types.push(Type::Long); - Err(EvaluationError::type_error_with_advice(expected_types, &arg1, "Only `Long` and extension types `datetime`, `duration` support comparison".to_string())) + let expected_types = valid_comparison_op_types(&self.extensions); + Err(EvaluationError::type_error_with_advice( + expected_types.clone(), + &arg1, + format!( + "Only types {} support comparison", + expected_types.into_iter().sorted().join(", ") + ), + )) } } } @@ -2951,7 +2984,18 @@ pub(crate) mod test { fn interpret_compares() { let request = basic_request(); let entities = basic_entities(); - let eval = Evaluator::new(request, &entities, Extensions::all_available()); + let extensions = Extensions::all_available(); + let eval = Evaluator::new(request, &entities, extensions); + let expected_types = valid_comparison_op_types(extensions); + let assert_type_error = |expr, actual_type| { + assert_matches!( + eval.interpret_inline_policy(&expr), + Err(EvaluationError::TypeError(TypeError { expected, actual, .. })) => { + assert_eq!(expected, expected_types.clone()); + assert_eq!(actual, actual_type); + } + ); + }; // 3 < 303 assert_eq!( eval.interpret_inline_policy(&Expr::less(Expr::val(3), Expr::val(303))), @@ -3013,214 +3057,91 @@ pub(crate) mod test { Ok(Value::from(true)) ); // false < true - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val(false), Expr::val(true))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val(false), Expr::val(true)), Type::Bool); + // false < false - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val(false), Expr::val(false))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val(false), Expr::val(false)), Type::Bool); + // true <= false - assert_matches!( - eval.interpret_inline_policy(&Expr::lesseq(Expr::val(true), Expr::val(false))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::lesseq(Expr::val(true), Expr::val(false)), Type::Bool); + // false <= false - assert_matches!( - eval.interpret_inline_policy(&Expr::lesseq(Expr::val(false), Expr::val(false))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected,nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::lesseq(Expr::val(false), Expr::val(false)), Type::Bool); + // false > true - assert_matches!( - eval.interpret_inline_policy(&Expr::greater(Expr::val(false), Expr::val(true))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::greater(Expr::val(false), Expr::val(true)), Type::Bool); + // true > true - assert_matches!( - eval.interpret_inline_policy(&Expr::greater(Expr::val(true), Expr::val(true))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::greater(Expr::val(true), Expr::val(true)), Type::Bool); + // true >= false - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq(Expr::val(true), Expr::val(false))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greatereq(Expr::val(true), Expr::val(false)), + Type::Bool, ); + // true >= true - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq(Expr::val(true), Expr::val(true))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Bool); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greatereq(Expr::val(true), Expr::val(true)), + Type::Bool, ); + // bc < zzz - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("bc"), Expr::val("zzz"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val("bc"), Expr::val("zzz")), Type::String); // banana < zzz - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("banana"), Expr::val("zzz"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::less(Expr::val("banana"), Expr::val("zzz")), + Type::String, ); // "" < zzz - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val(""), Expr::val("zzz"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val(""), Expr::val("zzz")), Type::String); // a < 1 - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("a"), Expr::val("1"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val("a"), Expr::val("1")), Type::String); // a < A - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("a"), Expr::val("A"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val("a"), Expr::val("A")), Type::String); // A < A - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("A"), Expr::val("A"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::less(Expr::val("A"), Expr::val("A")), Type::String); // zebra < zebras - assert_matches!( - eval.interpret_inline_policy(&Expr::less(Expr::val("zebra"), Expr::val("zebras"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::less(Expr::val("zebra"), Expr::val("zebras")), + Type::String, ); // zebra <= zebras - assert_matches!( - eval.interpret_inline_policy(&Expr::lesseq(Expr::val("zebra"), Expr::val("zebras"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::lesseq(Expr::val("zebra"), Expr::val("zebras")), + Type::String, ); // zebras <= zebras - assert_matches!( - eval.interpret_inline_policy(&Expr::lesseq(Expr::val("zebras"), Expr::val("zebras"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::lesseq(Expr::val("zebras"), Expr::val("zebras")), + Type::String, ); // zebras <= Zebras - assert_matches!( - eval.interpret_inline_policy(&Expr::lesseq(Expr::val("zebras"), Expr::val("Zebras"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::lesseq(Expr::val("zebras"), Expr::val("Zebras")), + Type::String, ); // 123 > 78 - assert_matches!( - eval.interpret_inline_policy(&Expr::greater(Expr::val("123"), Expr::val("78"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greater(Expr::val("123"), Expr::val("78")), + Type::String, ); // zebras >= zebras - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq( - Expr::val(" zebras"), - Expr::val("zebras") - )), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greatereq(Expr::val(" zebras"), Expr::val("zebras")), + Type::String, ); // "" >= "" - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq(Expr::val(""), Expr::val(""))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } - ); + assert_type_error(Expr::greatereq(Expr::val(""), Expr::val("")), Type::String); // "" >= _hi - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq(Expr::val(""), Expr::val("_hi"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greatereq(Expr::val(""), Expr::val("_hi")), + Type::String, ); // 🦀 >= _hi - assert_matches!( - eval.interpret_inline_policy(&Expr::greatereq(Expr::val("🦀"), Expr::val("_hi"))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::String); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + assert_type_error( + Expr::greatereq(Expr::val("🦀"), Expr::val("_hi")), + Type::String, ); // 2 < "4" assert_matches!( @@ -3259,18 +3180,22 @@ pub(crate) mod test { } ); // [1, 2] < [47, 0] - assert_matches!( - eval.interpret_inline_policy(&Expr::less( + assert_type_error( + Expr::less( Expr::set(vec![Expr::val(1), Expr::val(2)]), - Expr::set(vec![Expr::val(47), Expr::val(0)]) - )), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}, Type::Long]); - assert_eq!(actual, Type::Set); - assert_eq!(advice, Some("Only `Long` and extension types `datetime`, `duration` support comparison".into())); - } + Expr::set(vec![Expr::val(47), Expr::val(0)]), + ), + Type::Set, ); + } + #[cfg(feature = "datetime")] + #[test] + fn interpret_datetime_extension_compares() { + let request = basic_request(); + let entities = basic_entities(); + let extensions = Extensions::all_available(); + let eval = Evaluator::new(request, &entities, extensions); let datetime_constructor: Name = "datetime".parse().unwrap(); let duration_constructor: Name = "duration".parse().unwrap(); assert_matches!(eval.interpret_inline_policy( @@ -3494,10 +3419,9 @@ pub(crate) mod test { Expr::call_extension_fn( "decimal".parse().unwrap(), vec![Value::from("3.0").into()]))), - Err(EvaluationError::TypeError(TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: datetime_constructor }, Type::Extension { name: duration_constructor }]); + Err(EvaluationError::TypeError(TypeError { expected, actual, .. })) => { + assert_eq!(expected, valid_comparison_op_types(extensions)); assert_eq!(actual, Type::Extension { name: "decimal".parse().unwrap() }); - assert_eq!(advice, Some("Only extension types `datetime` and `duration` support operator overloading".into())); }); } diff --git a/cedar-policy-core/src/extensions.rs b/cedar-policy-core/src/extensions.rs index 57bf8587c..bfb66c916 100644 --- a/cedar-policy-core/src/extensions.rs +++ b/cedar-policy-core/src/extensions.rs @@ -28,11 +28,10 @@ pub mod partial_evaluation; use std::collections::HashMap; -use crate::ast::{Extension, ExtensionFunction, Name, TYPES_WITH_OPERATOR_OVERLOADING}; +use crate::ast::{Extension, ExtensionFunction, Name}; use crate::entities::SchemaType; use crate::parser::Loc; use miette::Diagnostic; -use nonempty::NonEmpty; use thiserror::Error; use self::extension_function_lookup_errors::FuncDoesNotExistError; @@ -98,21 +97,15 @@ impl Extensions<'static> { pub fn none() -> &'static Extensions<'static> { &EXTENSIONS_NONE } - - /// Obtain the non-empty vector of types supporting operator overloading - pub fn types_with_operator_overloading() -> NonEmpty { - // PANIC SAFETY: There are more than one element in `TYPES_WITH_OPERATOR_OVERLOADING` - #[allow(clippy::unwrap_used)] - NonEmpty::collect(TYPES_WITH_OPERATOR_OVERLOADING.iter().cloned()).unwrap() - } - - /// Iterate over extension types that support operator overloading - pub fn iter_type_with_operator_overloading() -> impl Iterator { - TYPES_WITH_OPERATOR_OVERLOADING.iter() - } } impl<'a> Extensions<'a> { + /// Obtain the non-empty vector of types supporting operator overloading + pub fn types_with_operator_overloading(&self) -> impl Iterator + '_ { + self.extensions + .iter() + .flat_map(|ext| ext.types_with_operator_overloading()) + } /// Get a new `Extensions` with these specific extensions enabled. pub fn specific_extensions( extensions: &'a [Extension], diff --git a/cedar-policy-core/src/extensions/datetime.rs b/cedar-policy-core/src/extensions/datetime.rs index b98c11856..3327ee4bf 100644 --- a/cedar-policy-core/src/extensions/datetime.rs +++ b/cedar-policy-core/src/extensions/datetime.rs @@ -718,6 +718,10 @@ pub fn extension() -> Extension { duration_type, ), ], + [ + DATETIME_CONSTRUCTOR_NAME.clone(), + DURATION_CONSTRUCTOR_NAME.clone(), + ], ) } diff --git a/cedar-policy-core/src/extensions/decimal.rs b/cedar-policy-core/src/extensions/decimal.rs index bd601a952..3cd6bd77e 100644 --- a/cedar-policy-core/src/extensions/decimal.rs +++ b/cedar-policy-core/src/extensions/decimal.rs @@ -308,6 +308,7 @@ pub fn extension() -> Extension { (decimal_type.clone(), decimal_type), ), ], + std::iter::empty(), ) } @@ -627,12 +628,12 @@ mod tests { &parse_expr(r#"decimal("1.23") < decimal("1.24")"#).expect("parsing error") ), Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}]); + assert_eq!(expected, nonempty![Type::Long]); assert_eq!(actual, Type::Extension { name: Name::parse_unqualified_name("decimal") .expect("should be a valid identifier") }); - assert_eq!(advice, Some("Only extension types `datetime` and `duration` support operator overloading".into())); + assert_eq!(advice, Some("Only types long support comparison".into())); } ); assert_matches!( diff --git a/cedar-policy-core/src/extensions/ipaddr.rs b/cedar-policy-core/src/extensions/ipaddr.rs index ed88f790b..fa0932f99 100644 --- a/cedar-policy-core/src/extensions/ipaddr.rs +++ b/cedar-policy-core/src/extensions/ipaddr.rs @@ -438,6 +438,7 @@ pub fn extension() -> Extension { (ipaddr_type.clone(), ipaddr_type), ), ], + std::iter::empty(), ) } @@ -609,12 +610,12 @@ mod tests { assert_matches!( eval.interpret_inline_policy(&Expr::less(ip("127.0.0.1"), ip("10.0.0.10"))), Err(EvaluationError::TypeError(evaluation_errors::TypeError { expected, actual, advice, .. })) => { - assert_eq!(expected, nonempty![Type::Extension { name: "datetime".parse().unwrap()}, Type::Extension { name: "duration".parse().unwrap()}]); + assert_eq!(expected, nonempty![Type::Long]); assert_eq!(actual, Type::Extension { name: Name::parse_unqualified_name("ipaddr") .expect("should be a valid identifier") }); - assert_eq!(advice, Some("Only extension types `datetime` and `duration` support operator overloading".into())); + assert_eq!(advice, Some("Only types long support comparison".into())); } ); // test that isIpv4 on a String is an error diff --git a/cedar-policy-core/src/extensions/partial_evaluation.rs b/cedar-policy-core/src/extensions/partial_evaluation.rs index c271c55b2..63ba9af43 100644 --- a/cedar-policy-core/src/extensions/partial_evaluation.rs +++ b/cedar-policy-core/src/extensions/partial_evaluation.rs @@ -41,5 +41,6 @@ pub fn extension() -> Extension { Box::new(create_new_unknown), SchemaType::String, )], + std::iter::empty(), ) } diff --git a/cedar-policy-validator/Cargo.toml b/cedar-policy-validator/Cargo.toml index ec7cc59e2..a5f6de4eb 100644 --- a/cedar-policy-validator/Cargo.toml +++ b/cedar-policy-validator/Cargo.toml @@ -38,7 +38,7 @@ prost = { version = "0.13", optional = true } [features] # by default, enable all Cedar extensions -default = ["ipaddr", "decimal", "datetime"] +default = ["ipaddr", "decimal"] # when enabling a feature, make sure that the Core feature is also enabled ipaddr = ["cedar-policy-core/ipaddr"] decimal = ["cedar-policy-core/decimal"] diff --git a/cedar-policy-validator/src/extension_schema.rs b/cedar-policy-validator/src/extension_schema.rs index 764b5264a..648f017e5 100644 --- a/cedar-policy-validator/src/extension_schema.rs +++ b/cedar-policy-validator/src/extension_schema.rs @@ -14,6 +14,8 @@ * limitations under the License. */ +use std::collections::BTreeSet; + use crate::types::Type; use cedar_policy_core::ast::{Expr, Name}; @@ -23,6 +25,8 @@ pub struct ExtensionSchema { name: Name, /// Type information for extension functions function_types: Vec, + /// Types that support operator overloading + types_with_operator_overloading: BTreeSet, } impl std::fmt::Debug for ExtensionSchema { @@ -36,10 +40,12 @@ impl ExtensionSchema { pub fn new( name: Name, function_types: impl IntoIterator, + types_with_operator_overloading: impl IntoIterator, ) -> Self { Self { name, function_types: function_types.into_iter().collect(), + types_with_operator_overloading: types_with_operator_overloading.into_iter().collect(), } } @@ -51,6 +57,11 @@ impl ExtensionSchema { pub fn function_types(&self) -> impl Iterator { self.function_types.iter() } + + /// Get all extension types that support operator overloading + pub fn types_with_operator_overloading(&self) -> impl Iterator { + self.types_with_operator_overloading.iter() + } } /// The type of a function used to perform custom argument validation on an diff --git a/cedar-policy-validator/src/extensions.rs b/cedar-policy-validator/src/extensions.rs index da393d137..937a0db5c 100644 --- a/cedar-policy-validator/src/extensions.rs +++ b/cedar-policy-validator/src/extensions.rs @@ -16,7 +16,7 @@ //! This module contains type information for all of the standard Cedar extensions. -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap}; use cedar_policy_core::{ ast::{Name, RestrictedExpr, Value}, @@ -57,7 +57,7 @@ lazy_static::lazy_static! { static ref ALL_AVAILABLE_EXTENSION_SCHEMAS : ExtensionSchemas<'static> = ExtensionSchemas::build_all_available(); } -/// Aggregate structure containing function signatures for multiple [`ExtensionSchema`]. +/// Aggregate structure containing information such as function signatures for multiple [`ExtensionSchema`]. /// Ensures that no function name is defined mode than once. /// Intentionally does not derive `Clone` to avoid clones of the `HashMap`. For the /// moment, it's easy to pass this around by reference. We could make this @@ -69,6 +69,8 @@ pub struct ExtensionSchemas<'a> { /// extension function lookup that at most one extension functions exists /// for a name. function_types: HashMap<&'a Name, &'a ExtensionFunctionType>, + /// Extension types that support operator overloading + types_with_operator_overloading: BTreeSet<&'a Name>, } impl<'a> ExtensionSchemas<'a> { @@ -98,7 +100,16 @@ impl<'a> ExtensionSchemas<'a> { ) .map_err(|name| FuncMultiplyDefinedError { name: name.clone() })?; - Ok(Self { function_types }) + // We already ensure that names of extension types do not collide, at the language level + let types_with_operator_overloading = extension_schemas + .iter() + .flat_map(|f| f.types_with_operator_overloading()) + .collect(); + + Ok(Self { + function_types, + types_with_operator_overloading, + }) } /// Get the [`ExtensionFunctionType`] for a function with this [`Name`]. @@ -106,6 +117,16 @@ impl<'a> ExtensionSchemas<'a> { pub fn func_type(&self, name: &Name) -> Option<&ExtensionFunctionType> { self.function_types.get(name).copied() } + + /// Query if `ext_ty_name` supports operator overloading + pub fn has_type_with_operator_overloading(&self, ext_ty_name: &Name) -> bool { + self.types_with_operator_overloading.contains(ext_ty_name) + } + + /// Get all extension types that support operator overloading + pub fn types_with_operator_overloading(&self) -> impl Iterator + '_ { + self.types_with_operator_overloading.iter().cloned() + } } /// Evaluates ane extension function on a single string literal argument. Used diff --git a/cedar-policy-validator/src/extensions/datetime.rs b/cedar-policy-validator/src/extensions/datetime.rs index 5425992bb..a337591c8 100644 --- a/cedar-policy-validator/src/extensions/datetime.rs +++ b/cedar-policy-validator/src/extensions/datetime.rs @@ -97,7 +97,8 @@ pub fn extension_schema() -> ExtensionSchema { let datetime_ty = Type::extension(datetime_ext.name().clone()); //PANIC SAFETY: `duration` is a valid name #[allow(clippy::unwrap_used)] - let duration_ty = Type::extension("duration".parse().unwrap()); + let duration_ty_name: Name = "duration".parse().unwrap(); + let duration_ty = Type::extension(duration_ty_name.clone()); let fun_tys = datetime_ext.funcs().map(|f| { let return_type = get_return_type(f.name(), &datetime_ty, &duration_ty); @@ -112,7 +113,11 @@ pub fn extension_schema() -> ExtensionSchema { get_argument_check(f.name()), ) }); - ExtensionSchema::new(datetime_ext.name().clone(), fun_tys) + ExtensionSchema::new( + datetime_ext.name().clone(), + fun_tys, + [datetime_ext.name().clone(), duration_ty_name], + ) } /// Extra validation step for the `datetime` function. diff --git a/cedar-policy-validator/src/extensions/decimal.rs b/cedar-policy-validator/src/extensions/decimal.rs index 4916b5f56..37758f76a 100644 --- a/cedar-policy-validator/src/extensions/decimal.rs +++ b/cedar-policy-validator/src/extensions/decimal.rs @@ -97,7 +97,7 @@ pub fn extension_schema() -> ExtensionSchema { get_argument_check(f.name()), ) }); - ExtensionSchema::new(decimal_ext.name().clone(), fun_tys) + ExtensionSchema::new(decimal_ext.name().clone(), fun_tys, std::iter::empty()) } /// Extra validation step for the `decimal` function. diff --git a/cedar-policy-validator/src/extensions/ipaddr.rs b/cedar-policy-validator/src/extensions/ipaddr.rs index eef5b4700..f1fd7bf69 100644 --- a/cedar-policy-validator/src/extensions/ipaddr.rs +++ b/cedar-policy-validator/src/extensions/ipaddr.rs @@ -96,7 +96,7 @@ pub fn extension_schema() -> ExtensionSchema { get_argument_check(f.name()), ) }); - ExtensionSchema::new(ipaddr_ext.name().clone(), fun_tys) + ExtensionSchema::new(ipaddr_ext.name().clone(), fun_tys, std::iter::empty()) } /// Extra validation step for the `ip` function. diff --git a/cedar-policy-validator/src/extensions/partial_evaluation.rs b/cedar-policy-validator/src/extensions/partial_evaluation.rs index 450ac5ba3..4bfdc0948 100644 --- a/cedar-policy-validator/src/extensions/partial_evaluation.rs +++ b/cedar-policy-validator/src/extensions/partial_evaluation.rs @@ -67,7 +67,7 @@ pub fn extension_schema() -> ExtensionSchema { None, ) }); - ExtensionSchema::new(pe_ext.name().clone(), fun_tys) + ExtensionSchema::new(pe_ext.name().clone(), fun_tys, std::iter::empty()) } #[cfg(test)] diff --git a/cedar-policy-validator/src/schema.rs b/cedar-policy-validator/src/schema.rs index eb37ddcbb..c55058366 100644 --- a/cedar-policy-validator/src/schema.rs +++ b/cedar-policy-validator/src/schema.rs @@ -3073,32 +3073,35 @@ pub(crate) mod test { .build()); }); - let src: serde_json::Value = json!({ - "": { - "commonTypes": { - "ty": { - "type": "Record", - "attributes": { - "a": { - "type": "Extension", - "name": "partial_evaluation", + #[cfg(feature = "datetime")] + { + let src: serde_json::Value = json!({ + "": { + "commonTypes": { + "ty": { + "type": "Record", + "attributes": { + "a": { + "type": "Extension", + "name": "partial_evaluation", + } } } - } - }, - "entityTypes": { }, - "actions": { }, - } - }); - let schema = ValidatorSchema::from_json_value(src.clone(), Extensions::all_available()); - assert_matches!(schema, Err(e) => { - expect_err( - &src, - &miette::Report::new(e), - &ExpectedErrorMessageBuilder::error("unknown extension type `partial_evaluation`") - .help("did you mean `duration`?") - .build()); - }); + }, + "entityTypes": { }, + "actions": { }, + } + }); + let schema = ValidatorSchema::from_json_value(src.clone(), Extensions::all_available()); + assert_matches!(schema, Err(e) => { + expect_err( + &src, + &miette::Report::new(e), + &ExpectedErrorMessageBuilder::error("unknown extension type `partial_evaluation`") + .help("did you mean `duration`?") + .build()); + }); + } } #[track_caller] diff --git a/cedar-policy-validator/src/typecheck.rs b/cedar-policy-validator/src/typecheck.rs index e997fc0ff..bed4fc4c5 100644 --- a/cedar-policy-validator/src/typecheck.rs +++ b/cedar-policy-validator/src/typecheck.rs @@ -49,7 +49,6 @@ use cedar_policy_core::{ PrincipalOrResourceConstraint, SlotId, Template, UnaryOp, Var, }, expr_builder::ExprBuilder as _, - extensions::Extensions, parser::Loc, }; @@ -1190,6 +1189,33 @@ impl<'a> SingleEnvTypechecker<'a> { } } + // Return if `ty` is a valid comparison operator type + // Currently, only primitive long and certain extension types are valid + fn is_valid_comparison_op_type(&self, ty: &Type) -> bool { + match ty { + Type::Primitive { + primitive_type: Primitive::Long, + } => true, + Type::ExtensionType { name } => { + self.extensions.has_type_with_operator_overloading(name) + } + _ => false, + } + } + + // Get all valid types satisfying `is_valid_comparison_op_type` + // Only used for error message construction + fn expected_comparison_op_types(&self) -> Vec { + let expected_types = self + .extensions + .types_with_operator_overloading() + .cloned() + .map(Type::extension) + .chain(std::iter::once(Type::primitive_long())) + .collect_vec(); + expected_types + } + /// A utility called by the main typecheck method to handle binary operator /// application. /// INVARIANT: `bin_expr` must be a `BinaryApp` @@ -1244,11 +1270,6 @@ impl<'a> SingleEnvTypechecker<'a> { } BinaryOp::Less | BinaryOp::LessEq => { - let expected_types = Extensions::iter_type_with_operator_overloading() - .cloned() - .map(Type::extension) - .chain(std::iter::once(Type::primitive_long())) - .collect_vec(); let ans_arg1 = self.typecheck(prior_capability, arg1, type_errors); ans_arg1.then_typecheck(|expr_ty_arg1, _| { let ans_arg2 = self.typecheck(prior_capability, arg2, type_errors); @@ -1261,13 +1282,13 @@ impl<'a> SingleEnvTypechecker<'a> { match (t1, t2) { (Some(Type::Never), Some(Type::Never)) => TypecheckAnswer::fail(expr), (Some(Type::Never), Some(other)) => { - if expected_types.contains(other) { + if self.is_valid_comparison_op_type(other) { TypecheckAnswer::success(expr) } else { type_errors.push(ValidationError::expected_one_of_types( expr_ty_arg2.source_loc().cloned(), self.policy_id.clone(), - expected_types, + self.expected_comparison_op_types(), other.clone(), None, )); @@ -1275,20 +1296,22 @@ impl<'a> SingleEnvTypechecker<'a> { } } (Some(other), Some(Type::Never)) => { - if expected_types.contains(other) { + if self.is_valid_comparison_op_type(other) { TypecheckAnswer::success(expr) } else { type_errors.push(ValidationError::expected_one_of_types( expr_ty_arg1.source_loc().cloned(), self.policy_id.clone(), - expected_types, + self.expected_comparison_op_types(), other.clone(), None, )); TypecheckAnswer::fail(expr) } } - (Some(t1), Some(t2)) if t1 == t2 && expected_types.contains(t1) => { + (Some(t1), Some(t2)) + if t1 == t2 && self.is_valid_comparison_op_type(t1) => + { TypecheckAnswer::success(expr) } ( @@ -1321,7 +1344,7 @@ impl<'a> SingleEnvTypechecker<'a> { )); TypecheckAnswer::fail(expr) } - (Some(lhs), Some(rhs)) if lhs.support_operator_overloading() => { + (Some(lhs), Some(rhs)) if self.is_valid_comparison_op_type(lhs) => { type_errors.push(ValidationError::expected_one_of_types( expr_ty_arg2.source_loc().cloned(), self.policy_id.clone(), @@ -1331,7 +1354,7 @@ impl<'a> SingleEnvTypechecker<'a> { )); TypecheckAnswer::fail(expr) } - (Some(lhs), Some(rhs)) if rhs.support_operator_overloading() => { + (Some(lhs), Some(rhs)) if self.is_valid_comparison_op_type(rhs) => { type_errors.push(ValidationError::expected_one_of_types( expr_ty_arg1.source_loc().cloned(), self.policy_id.clone(), @@ -1342,6 +1365,7 @@ impl<'a> SingleEnvTypechecker<'a> { TypecheckAnswer::fail(expr) } (Some(lhs), Some(rhs)) => { + let expected_types = self.expected_comparison_op_types(); type_errors.push(ValidationError::expected_one_of_types( expr_ty_arg1.source_loc().cloned(), self.policy_id.clone(), diff --git a/cedar-policy-validator/src/typecheck/test/expr.rs b/cedar-policy-validator/src/typecheck/test/expr.rs index d9b7960c7..1f165ca91 100644 --- a/cedar-policy-validator/src/typecheck/test/expr.rs +++ b/cedar-policy-validator/src/typecheck/test/expr.rs @@ -21,7 +21,7 @@ use std::{str::FromStr, vec}; use cedar_policy_core::{ - ast::{BinaryOp, EntityUID, Expr, Name, Pattern, PatternElem, SlotId, Value, Var}, + ast::{BinaryOp, EntityUID, Expr, Pattern, PatternElem, SlotId, Var}, est::Annotations, extensions::Extensions, }; @@ -1045,80 +1045,20 @@ fn like_typecheck_fails() { ); } -#[inline] -fn get_datetime_constructor_name() -> Name { - "datetime".parse().unwrap() -} - -#[inline] -fn get_duration_constructor_name() -> Name { - "duration".parse().unwrap() -} - #[test] fn less_than_typechecks() { assert_typechecks_empty_schema( Expr::less(Expr::val(1), Expr::val(2)), Type::primitive_boolean(), ); - assert_typechecks_empty_schema( - Expr::less( - Expr::call_extension_fn( - get_datetime_constructor_name(), - vec![Value::from("1970-01-01").into()], - ), - Expr::call_extension_fn( - get_datetime_constructor_name(), - vec![Value::from("1970-01-02").into()], - ), - ), - Type::primitive_boolean(), - ); - assert_typechecks_empty_schema( - Expr::lesseq( - Expr::call_extension_fn( - get_datetime_constructor_name(), - vec![Value::from("1970-01-01").into()], - ), - Expr::call_extension_fn( - get_datetime_constructor_name(), - vec![Value::from("1970-01-02").into()], - ), - ), - Type::primitive_boolean(), - ); - assert_typechecks_empty_schema( - Expr::less( - Expr::call_extension_fn( - get_duration_constructor_name(), - vec![Value::from("1h").into()], - ), - Expr::call_extension_fn( - get_duration_constructor_name(), - vec![Value::from("2h").into()], - ), - ), - Type::primitive_boolean(), - ); - assert_typechecks_empty_schema( - Expr::lesseq( - Expr::call_extension_fn( - get_duration_constructor_name(), - vec![Value::from("1h").into()], - ), - Expr::call_extension_fn( - get_duration_constructor_name(), - vec![Value::from("2h").into()], - ), - ), - Type::primitive_boolean(), - ); } #[test] fn less_than_typecheck_fails() { - let expected_types = Extensions::types_with_operator_overloading() - .into_iter() + let extensions = Extensions::all_available(); + let expected_types = extensions + .types_with_operator_overloading() + .cloned() .map(Type::extension) .chain(std::iter::once(Type::primitive_long())) .collect_vec(); @@ -1181,73 +1121,6 @@ fn less_than_typecheck_fails() { None, )], ); - - let src = r#"true < duration("1h")"#; - let errors = - assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); - assert_sets_equal( - errors, - [ValidationError::expected_type( - get_loc(src, "true"), - expr_id_placeholder(), - Type::ExtensionType { - name: get_duration_constructor_name(), - }, - Type::singleton_boolean(true), - None, - )], - ); - - // Error reporting favors long - let src = r#"duration("1d") < 1"#; - let errors = - assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); - assert_sets_equal( - errors, - [ValidationError::expected_type( - get_loc(src, r#"duration("1d")"#), - expr_id_placeholder(), - Type::primitive_long(), - Type::ExtensionType { - name: get_duration_constructor_name(), - }, - None, - )], - ); - - let src = r#"1 < duration("1d")"#; - let errors = - assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); - assert_sets_equal( - errors, - [ValidationError::expected_type( - get_loc(src, r#"duration("1d")"#), - expr_id_placeholder(), - Type::primitive_long(), - Type::ExtensionType { - name: get_duration_constructor_name(), - }, - None, - )], - ); - - let src = r#"datetime("1970-01-01") < duration("1d")"#; - let errors = - assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); - assert_sets_equal( - errors, - [ValidationError::expected_type( - get_loc(src, r#"duration("1d")"#), - expr_id_placeholder(), - Type::ExtensionType { - name: get_datetime_constructor_name(), - }, - Type::ExtensionType { - name: get_duration_constructor_name(), - }, - None, - )], - ); } #[test] @@ -1476,3 +1349,229 @@ fn is_typechecks() { Type::singleton_boolean(false), ); } + +#[cfg(feature = "datetime")] +mod datetime { + use cedar_policy_core::{ + ast::{Expr, Name, Value}, + extensions::Extensions, + }; + use itertools::Itertools; + + use crate::{ + typecheck::test::test_utils::{expr_id_placeholder, get_loc}, + types::Type, + ValidationError, + }; + + use super::{ + assert_sets_equal, assert_typecheck_fails_empty_schema, assert_typechecks_empty_schema, + }; + + #[inline] + fn get_datetime_constructor_name() -> Name { + "datetime".parse().unwrap() + } + + #[inline] + fn get_duration_constructor_name() -> Name { + "duration".parse().unwrap() + } + + #[test] + fn less_than_typechecks() { + assert_typechecks_empty_schema( + Expr::less(Expr::val(1), Expr::val(2)), + Type::primitive_boolean(), + ); + assert_typechecks_empty_schema( + Expr::less( + Expr::call_extension_fn( + get_datetime_constructor_name(), + vec![Value::from("1970-01-01").into()], + ), + Expr::call_extension_fn( + get_datetime_constructor_name(), + vec![Value::from("1970-01-02").into()], + ), + ), + Type::primitive_boolean(), + ); + assert_typechecks_empty_schema( + Expr::lesseq( + Expr::call_extension_fn( + get_datetime_constructor_name(), + vec![Value::from("1970-01-01").into()], + ), + Expr::call_extension_fn( + get_datetime_constructor_name(), + vec![Value::from("1970-01-02").into()], + ), + ), + Type::primitive_boolean(), + ); + assert_typechecks_empty_schema( + Expr::less( + Expr::call_extension_fn( + get_duration_constructor_name(), + vec![Value::from("1h").into()], + ), + Expr::call_extension_fn( + get_duration_constructor_name(), + vec![Value::from("2h").into()], + ), + ), + Type::primitive_boolean(), + ); + assert_typechecks_empty_schema( + Expr::lesseq( + Expr::call_extension_fn( + get_duration_constructor_name(), + vec![Value::from("1h").into()], + ), + Expr::call_extension_fn( + get_duration_constructor_name(), + vec![Value::from("2h").into()], + ), + ), + Type::primitive_boolean(), + ); + } + + #[test] + fn less_than_typecheck_fails() { + let extensions = Extensions::all_available(); + let expected_types = extensions + .types_with_operator_overloading() + .cloned() + .map(Type::extension) + .chain(std::iter::once(Type::primitive_long())) + .collect_vec(); + let src = "true < false"; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ + ValidationError::expected_one_of_types( + get_loc(src, "true"), + expr_id_placeholder(), + expected_types.clone(), + Type::singleton_boolean(true), + None, + ), + ValidationError::expected_one_of_types( + get_loc(src, "false"), + expr_id_placeholder(), + expected_types.clone(), + Type::singleton_boolean(false), + None, + ), + ], + ); + + let src = "true < \"\""; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ + ValidationError::expected_one_of_types( + get_loc(src, "true"), + expr_id_placeholder(), + expected_types.clone(), + Type::singleton_boolean(true), + None, + ), + ValidationError::expected_one_of_types( + get_loc(src, "\"\""), + expr_id_placeholder(), + expected_types, + Type::primitive_string(), + None, + ), + ], + ); + + let src = "true < 1"; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ValidationError::expected_type( + get_loc(src, "true"), + expr_id_placeholder(), + Type::primitive_long(), + Type::singleton_boolean(true), + None, + )], + ); + + let src = r#"true < duration("1h")"#; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ValidationError::expected_type( + get_loc(src, "true"), + expr_id_placeholder(), + Type::ExtensionType { + name: get_duration_constructor_name(), + }, + Type::singleton_boolean(true), + None, + )], + ); + + // Error reporting favors long + let src = r#"duration("1d") < 1"#; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ValidationError::expected_type( + get_loc(src, r#"duration("1d")"#), + expr_id_placeholder(), + Type::primitive_long(), + Type::ExtensionType { + name: get_duration_constructor_name(), + }, + None, + )], + ); + + let src = r#"1 < duration("1d")"#; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ValidationError::expected_type( + get_loc(src, r#"duration("1d")"#), + expr_id_placeholder(), + Type::primitive_long(), + Type::ExtensionType { + name: get_duration_constructor_name(), + }, + None, + )], + ); + + let src = r#"datetime("1970-01-01") < duration("1d")"#; + let errors = + assert_typecheck_fails_empty_schema(src.parse().unwrap(), Type::primitive_boolean()); + assert_sets_equal( + errors, + [ValidationError::expected_type( + get_loc(src, r#"duration("1d")"#), + expr_id_placeholder(), + Type::ExtensionType { + name: get_datetime_constructor_name(), + }, + Type::ExtensionType { + name: get_duration_constructor_name(), + }, + None, + )], + ); + } +} diff --git a/cedar-policy-validator/src/typecheck/test/partial.rs b/cedar-policy-validator/src/typecheck/test/partial.rs index aa1f896c9..84d7cc693 100644 --- a/cedar-policy-validator/src/typecheck/test/partial.rs +++ b/cedar-policy-validator/src/typecheck/test/partial.rs @@ -398,15 +398,16 @@ mod fails_empty_schema { // We expect to see a type error for the incorrect literal argument to // various operators. No error should be generated for missing // attributes or the type of the attributes. - + let extensions = Extensions::all_available(); let src = r#"permit(principal, action, resource) when { principal.foo > "a" };"#; assert_typecheck_fails_empty_schema( parse_policy(None, src).unwrap(), [ValidationError::expected_one_of_types( get_loc(src, r#""a""#), PolicyID::from_string("policy0"), - Extensions::types_with_operator_overloading() - .into_iter() + extensions + .types_with_operator_overloading() + .cloned() .map(Type::extension) .chain(std::iter::once(Type::primitive_long())) .collect(), @@ -645,6 +646,7 @@ mod fail_partial_schema { #[test] fn error_on_declared_attr() { + let extensions = Extensions::all_available(); // `name` is declared as a `String` in the partial schema, so we can // error even though `principal.unknown` is not declared. let src = r#"permit(principal == User::"alice", action, resource) when { principal.name > principal.unknown };"#; @@ -653,8 +655,9 @@ mod fail_partial_schema { [ValidationError::expected_one_of_types( get_loc(src, "principal.name"), PolicyID::from_string("policy0"), - Extensions::types_with_operator_overloading() - .into_iter() + extensions + .types_with_operator_overloading() + .cloned() .map(Type::extension) .chain(std::iter::once(Type::primitive_long())) .collect(), diff --git a/cedar-policy-validator/src/types.rs b/cedar-policy-validator/src/types.rs index d8e22c4a9..bbd2711d9 100644 --- a/cedar-policy-validator/src/types.rs +++ b/cedar-policy-validator/src/types.rs @@ -671,15 +671,6 @@ impl Type { | Type::EntityOrRecord(EntityRecordKind::ActionEntity { .. }) ) } - - pub(crate) fn support_operator_overloading(&self) -> bool { - match self { - Self::ExtensionType { name } => { - Extensions::iter_type_with_operator_overloading().contains(name) - } - _ => false, - } - } } impl Display for Type { diff --git a/cedar-policy/CHANGELOG.md b/cedar-policy/CHANGELOG.md index 2163a3ebd..9abe40919 100644 --- a/cedar-policy/CHANGELOG.md +++ b/cedar-policy/CHANGELOG.md @@ -15,6 +15,7 @@ Cedar Language Version: TBD ### Added +- Implemented [RFC 80 (`datetime` extension)](https://github.com/strongdm/cedar-rfcs/blob/datetime-rfc/text/0080-datetime-extension.md) as an experimental feature under flag `datetime` (#1276, #1415) - Implemented [RFC 48 (schema annotations)](https://github.com/cedar-policy/rfcs/blob/main/text/0048-schema-annotations.md) (#1316) - New `.isEmpty()` operator on sets (#1358, resolving #1356) - New `Entity::new_with_tags()` and `Entity::tag()` functions (#1402, resolving #1374) diff --git a/cedar-policy/Cargo.toml b/cedar-policy/Cargo.toml index f5c29d119..5b8971071 100644 --- a/cedar-policy/Cargo.toml +++ b/cedar-policy/Cargo.toml @@ -41,6 +41,7 @@ default = ["ipaddr", "decimal"] # Cedar extensions ipaddr = ["cedar-policy-core/ipaddr", "cedar-policy-validator/ipaddr"] decimal = ["cedar-policy-core/decimal", "cedar-policy-validator/decimal"] +datetime = ["cedar-policy-core/datetime", "cedar-policy-validator/datetime"] # Features for memory or runtime profiling heap-profiling = ["dep:dhat"] @@ -48,7 +49,7 @@ corpus-timing = [] # Experimental features. # Enable all experimental features with `cargo build --features "experimental"` -experimental = ["partial-eval", "permissive-validate", "partial-validate", "level-validate", "entity-manifest", "protobufs"] +experimental = ["partial-eval", "permissive-validate", "partial-validate", "level-validate", "entity-manifest", "protobufs", "datetime"] entity-manifest = ["cedar-policy-validator/entity-manifest"] partial-eval = ["cedar-policy-core/partial-eval", "cedar-policy-validator/partial-eval"] permissive-validate = []