diff --git a/cedar-policy-validator/src/coreschema.rs b/cedar-policy-validator/src/coreschema.rs index 433ab943b..a19991860 100644 --- a/cedar-policy-validator/src/coreschema.rs +++ b/cedar-policy-validator/src/coreschema.rs @@ -13,13 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - use crate::{ValidatorEntityType, ValidatorSchema}; use cedar_policy_core::extensions::{ExtensionFunctionLookupError, Extensions}; use cedar_policy_core::{ast, entities}; use miette::Diagnostic; use smol_str::SmolStr; -use std::collections::{HashMap, HashSet}; +use std::collections::hash_map::Values; +use std::collections::HashSet; +use std::iter::Cloned; use std::sync::Arc; use thiserror::Error; @@ -28,37 +29,25 @@ use thiserror::Error; pub struct CoreSchema<'a> { /// Contains all the information schema: &'a ValidatorSchema, - /// For easy lookup, this is a map from action name to `Entity` object - /// for each action in the schema. This information is contained in the - /// `ValidatorSchema`, but not efficient to extract -- getting the `Entity` - /// from the `ValidatorSchema` is O(N) as of this writing, but with this - /// cache it's O(1). - actions: HashMap>, } impl<'a> CoreSchema<'a> { /// Create a new `CoreSchema` for the given `ValidatorSchema` pub fn new(schema: &'a ValidatorSchema) -> Self { - Self { - actions: schema - .action_entities_iter() - .map(|e| (e.uid().clone(), Arc::new(e))) - .collect(), - schema, - } + Self { schema } } } impl<'a> entities::Schema for CoreSchema<'a> { type EntityTypeDescription = EntityTypeDescription; - type ActionEntityIterator = Vec>; + type ActionEntityIterator = Cloned>>; fn entity_type(&self, entity_type: &ast::EntityType) -> Option { EntityTypeDescription::new(self.schema, entity_type) } fn action(&self, action: &ast::EntityUID) -> Option> { - self.actions.get(action).cloned() + self.schema.actions.get(action).cloned() } fn entity_types_with_basename<'b>( @@ -79,7 +68,7 @@ impl<'a> entities::Schema for CoreSchema<'a> { } fn action_entities(&self) -> Self::ActionEntityIterator { - self.actions.values().map(Arc::clone).collect() + self.schema.actions.values().cloned() } } diff --git a/cedar-policy-validator/src/schema.rs b/cedar-policy-validator/src/schema.rs index b94a1ad85..2abe5364a 100644 --- a/cedar-policy-validator/src/schema.rs +++ b/cedar-policy-validator/src/schema.rs @@ -20,9 +20,6 @@ //! `member_of` relation from the schema is reversed and the transitive closure is //! computed to obtain a `descendants` relation. -use std::collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; -use std::str::FromStr; - use cedar_policy_core::{ ast::{Entity, EntityType, EntityUID, InternalName, Name, UnreservedId}, entities::{err::EntitiesError, Entities, TCComputation}, @@ -34,6 +31,9 @@ use nonempty::NonEmpty; use serde::{Deserialize, Serialize}; use serde_with::serde_as; use smol_str::ToSmolStr; +use std::collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}; +use std::str::FromStr; +use std::sync::Arc; use crate::{ cedar_schema::SchemaWarning, @@ -132,8 +132,8 @@ impl ValidatorSchemaFragment { )) } - /// Convert this [`ValidatorSchemaFragment`] into a - /// [`ValidatorSchemaFragment`] by fully-qualifying all typenames that + /// Convert this [`ValidatorSchemaFragment`] into a + /// [`ValidatorSchemaFragment`] by fully-qualifying all typenames that /// appear anywhere in any definitions. /// /// `all_defs` needs to contain the full set of all fully-qualified typenames @@ -170,6 +170,14 @@ pub struct ValidatorSchema { /// Map from action id names to the [`ValidatorActionId`] object. #[serde_as(as = "Vec<(_, _)>")] action_ids: HashMap, + + /// For easy lookup, this is a map from action name to `Entity` object + /// for each action in the schema. This information is contained in the + /// `ValidatorSchema`, but not efficient to extract -- getting the `Entity` + /// from the `ValidatorSchema` is O(N) as of this writing, but with this + /// cache it's O(1). + #[serde_as(as = "Vec<(_, _)>")] + pub(crate) actions: HashMap>, } /// Construct [`ValidatorSchema`] from a string containing a schema formatted @@ -288,6 +296,7 @@ impl ValidatorSchema { Self { entity_types: HashMap::new(), action_ids: HashMap::new(), + actions: HashMap::new(), } } @@ -590,9 +599,14 @@ impl ValidatorSchema { common_types.into_values(), )?; + let actions = Self::action_entities_iter(&action_ids) + .map(|e| (e.uid().clone(), Arc::new(e))) + .collect(); + Ok(ValidatorSchema { entity_types, action_ids, + actions, }) } @@ -813,7 +827,7 @@ impl ValidatorSchema { /// Invert the action hierarchy to get the ancestor relation expected for /// the `Entity` datatype instead of descendants as stored by the schema. pub(crate) fn action_entities_iter( - &self, + action_ids: &HashMap, ) -> impl Iterator + '_ { // We could store the un-inverted `memberOf` relation for each action, // but I [john-h-kastner-aws] judge that the current implementation is @@ -821,7 +835,7 @@ impl ValidatorSchema { // structures through some complicated bits of schema construction code, // and avoids computing the TC twice. let mut action_ancestors: HashMap<&EntityUID, HashSet> = HashMap::new(); - for (action_euid, action_def) in &self.action_ids { + for (action_euid, action_def) in action_ids { for descendant in &action_def.descendants { action_ancestors .entry(descendant) @@ -829,7 +843,7 @@ impl ValidatorSchema { .insert(action_euid.clone()); } } - self.action_ids.iter().map(move |(action_id, action)| { + action_ids.iter().map(move |(action_id, action)| { Entity::new_with_attr_partial_value_serialized_as_expr( action_id.clone(), action.attributes.clone(), @@ -842,7 +856,7 @@ impl ValidatorSchema { pub fn action_entities(&self) -> std::result::Result { let extensions = Extensions::all_available(); Entities::from_entities( - self.action_entities_iter(), + self.actions.values().map(|entity| entity.as_ref().clone()), None::<&cedar_policy_core::entities::NoEntitiesSchema>, // we don't want to tell `Entities::from_entities()` to add the schema's action entities, that would infinitely recurse TCComputation::AssumeAlreadyComputed, extensions, @@ -880,6 +894,28 @@ impl From<&proto::ValidatorSchema> for ValidatorSchema { // PANIC SAFETY: experimental feature #[allow(clippy::expect_used)] fn from(v: &proto::ValidatorSchema) -> Self { + let action_ids = v + .action_ids + .iter() + .map(|kvp| { + let k = ast::EntityUID::from( + kvp.key + .as_ref() + .expect("`as_ref()` for field that should exist"), + ); + let v = ValidatorActionId::from( + kvp.value + .as_ref() + .expect("`as_ref()` for field that should exist"), + ); + (k, v) + }) + .collect(); + + let actions = Self::action_entities_iter(&action_ids) + .map(|e| (e.uid().clone(), Arc::new(e))) + .collect(); + Self { entity_types: v .entity_types @@ -898,23 +934,8 @@ impl From<&proto::ValidatorSchema> for ValidatorSchema { (k, v) }) .collect(), - action_ids: v - .action_ids - .iter() - .map(|kvp| { - let k = ast::EntityUID::from( - kvp.key - .as_ref() - .expect("`as_ref()` for field that should exist"), - ); - let v = ValidatorActionId::from( - kvp.value - .as_ref() - .expect("`as_ref()` for field that should exist"), - ); - (k, v) - }) - .collect(), + action_ids, + actions, } } } @@ -2665,8 +2686,7 @@ pub(crate) mod test { let schema_fragment = json_schema::Fragment::from_json_value(src).expect("Failed to parse schema"); let schema: ValidatorSchema = schema_fragment.try_into().expect("Schema should construct"); - let view_photo = schema - .action_entities_iter() + let view_photo = ValidatorSchema::action_entities_iter(&schema.action_ids) .find(|e| e.uid() == &r#"ExampleCo::Personnel::Action::"viewPhoto""#.parse().unwrap()) .unwrap(); let ancestors = view_photo.ancestors().collect::>(); @@ -2726,8 +2746,7 @@ pub(crate) mod test { let schema_fragment = json_schema::Fragment::from_json_value(src).expect("Failed to parse schema"); let schema: ValidatorSchema = schema_fragment.try_into().unwrap(); - let view_photo = schema - .action_entities_iter() + let view_photo = ValidatorSchema::action_entities_iter(&schema.action_ids) .find(|e| e.uid() == &r#"ExampleCo::Personnel::Action::"viewPhoto""#.parse().unwrap()) .unwrap(); let ancestors = view_photo.ancestors().collect::>(); diff --git a/cedar-policy/CHANGELOG.md b/cedar-policy/CHANGELOG.md index 5731529e2..2a7f8864d 100644 --- a/cedar-policy/CHANGELOG.md +++ b/cedar-policy/CHANGELOG.md @@ -22,6 +22,8 @@ Cedar Language Version: TBD includes a suggestion based on available extension functions (#1280, resolving #332). - The error associated with parsing a non-existent extension method additionally includes a suggestion based on available extension methods (#1289, resolving #246). +- Extract action graph inversion from `CoreSchema` to `ValidatorSchema` instantiation + to improve schema validation speeds. (#1290, as part of resolving #1285) ### Fixed