diff --git a/cedar-policy-core/src/est.rs b/cedar-policy-core/src/est.rs index 30b7b7c65..36b2b7142 100644 --- a/cedar-policy-core/src/est.rs +++ b/cedar-policy-core/src/est.rs @@ -166,7 +166,7 @@ impl TryFrom for Clause { }); Err(ParseError::ToAST(ToASTError::EmptyClause(ident)).into()) } - Some(e) => e.try_into(), + Some(ref e) => e.try_into(), }; let expr = match expr { Ok(expr) => Some(expr), diff --git a/cedar-policy-core/src/est/expr.rs b/cedar-policy-core/src/est/expr.rs index 8c2e090ec..258a6cab6 100644 --- a/cedar-policy-core/src/est/expr.rs +++ b/cedar-policy-core/src/est/expr.rs @@ -22,7 +22,7 @@ use crate::entities::{ }; use crate::extensions::Extensions; use crate::parser::cst::{self, Ident}; -use crate::parser::err::{ParseError, ParseErrors, ToASTError}; +use crate::parser::err::{ParseErrors, ToASTError}; use crate::parser::unescape; use crate::parser::ASTNode; use crate::{ast, FromNormalizedStr}; @@ -783,10 +783,10 @@ impl From for Expr { } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(e: ASTNode>) -> Result { - match *e.node.ok_or(ToASTError::MissingNodeData)?.expr { + fn try_from(e: &ASTNode>) -> Result { + match &*e.ok_or_missing()?.expr { cst::ExprData::Or(node) => node.try_into(), cst::ExprData::If(if_node, then_node, else_node) => { let cond_expr = if_node.try_into()?; @@ -798,12 +798,12 @@ impl TryFrom>> for Expr { } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(o: ASTNode>) -> Result { - let o_node = o.node.ok_or(ToASTError::MissingNodeData)?; - let mut expr = o_node.initial.try_into()?; - for node in o_node.extended { + fn try_from(o: &ASTNode>) -> Result { + let o_node = o.ok_or_missing()?; + let mut expr = (&o_node.initial).try_into()?; + for node in &o_node.extended { let rhs = node.try_into()?; expr = Expr::or(expr, rhs); } @@ -811,12 +811,12 @@ impl TryFrom>> for Expr { } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(a: ASTNode>) -> Result { - let a_node = a.node.ok_or(ToASTError::MissingNodeData)?; - let mut expr = a_node.initial.try_into()?; - for node in a_node.extended { + fn try_from(a: &ASTNode>) -> Result { + let a_node = a.ok_or_missing()?; + let mut expr = (&a_node.initial).try_into()?; + for node in &a_node.extended { let rhs = node.try_into()?; expr = Expr::and(expr, rhs); } @@ -824,10 +824,10 @@ impl TryFrom>> for Expr { } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(r: ASTNode>) -> Result { - match r.node.ok_or(ToASTError::MissingNodeData)? { + fn try_from(r: &ASTNode>) -> Result { + match r.ok_or_missing()? { cst::Relation::Common { initial, extended } => { let mut expr = initial.try_into()?; for (op, node) in extended { @@ -860,16 +860,16 @@ impl TryFrom>> for Expr { } cst::Relation::Has { target, field } => { let target_expr = target.try_into()?; - match Expr::try_from(field.clone()) { + match Expr::try_from(field) { Ok(field_expr) => { let field_str = field_expr .into_string_literal() - .map_err(|_| ParseError::ToAST(ToASTError::HasNonLiteralRHS))?; + .map_err(|_| ToASTError::HasNonLiteralRHS)?; Ok(Expr::has_attr(target_expr, field_str)) } - Err(_) => match is_add_name(field.node.ok_or(ToASTError::MissingNodeData)?) { + Err(_) => match is_add_name(field.ok_or_missing()?) { Some(name) => Ok(Expr::has_attr(target_expr, name.to_string().into())), - None => Err(ParseError::ToAST(ToASTError::HasNonLiteralRHS).into()), + None => Err(ToASTError::HasNonLiteralRHS.into()), }, } } @@ -877,9 +877,9 @@ impl TryFrom>> for Expr { let target_expr = target.try_into()?; let pat_expr: Expr = pattern.try_into()?; let pat_str = pat_expr.into_string_literal().map_err(|e| { - ParseError::ToAST(ToASTError::InvalidPattern( + ToASTError::InvalidPattern( serde_json::to_string(&e).unwrap_or_else(|_| "".to_string()), - )) + ) })?; Ok(Expr::like(target_expr, pat_str)) } @@ -889,11 +889,7 @@ impl TryFrom>> for Expr { in_entity, } => { let target = target.try_into()?; - let type_str = entity_type - .node - .ok_or(ToASTError::MissingNodeData)? - .to_string() - .into(); + let type_str = entity_type.ok_or_missing()?.to_string().into(); match in_entity { Some(in_entity) => Ok(Expr::is_entity_type_in( target, @@ -907,12 +903,12 @@ impl TryFrom>> for Expr { } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(a: ASTNode>) -> Result { - let a_node = a.node.ok_or(ToASTError::MissingNodeData)?; - let mut expr = a_node.initial.try_into()?; - for (op, node) in a_node.extended { + fn try_from(a: &ASTNode>) -> Result { + let a_node = a.ok_or_missing()?; + let mut expr = (&a_node.initial).try_into()?; + for (op, node) in &a_node.extended { let rhs = node.try_into()?; match op { cst::AddOp::Plus => { @@ -929,9 +925,9 @@ impl TryFrom>> for Expr { /// Returns `Some` if this is just a cst::Name. For example the /// `foobar` in `context has foobar` -fn is_add_name(add: cst::Add) -> Option { +fn is_add_name(add: &cst::Add) -> Option<&cst::Name> { if add.extended.is_empty() { - match add.initial.node { + match &add.initial.node { Some(mult) => is_mult_name(mult), None => None, } @@ -942,9 +938,9 @@ fn is_add_name(add: cst::Add) -> Option { /// Returns `Some` if this is just a cst::Name. For example the /// `foobar` in `context has foobar` -fn is_mult_name(mult: cst::Mult) -> Option { +fn is_mult_name(mult: &cst::Mult) -> Option<&cst::Name> { if mult.extended.is_empty() { - match mult.initial.node { + match &mult.initial.node { Some(unary) => is_unary_name(unary), None => None, } @@ -955,9 +951,9 @@ fn is_mult_name(mult: cst::Mult) -> Option { /// Returns `Some` if this is just a cst::Name. For example the /// `foobar` in `context has foobar` -fn is_unary_name(unary: cst::Unary) -> Option { +fn is_unary_name(unary: &cst::Unary) -> Option<&cst::Name> { if unary.op.is_none() { - match unary.item.node { + match &unary.item.node { Some(mem) => is_mem_name(mem), None => None, } @@ -968,9 +964,9 @@ fn is_unary_name(unary: cst::Unary) -> Option { /// Returns `Some` if this is just a cst::Name. For example the /// `foobar` in `context has foobar` -fn is_mem_name(mem: cst::Member) -> Option { +fn is_mem_name(mem: &cst::Member) -> Option<&cst::Name> { if mem.access.is_empty() { - match mem.item.node { + match &mem.item.node { Some(primary) => is_primary_name(primary), None => None, } @@ -981,41 +977,37 @@ fn is_mem_name(mem: cst::Member) -> Option { /// Returns `Some` if this is just a cst::Name. For example the /// `foobar` in `context has foobar` -fn is_primary_name(primary: cst::Primary) -> Option { +fn is_primary_name(primary: &cst::Primary) -> Option<&cst::Name> { match primary { - cst::Primary::Name(node) => node.node, + cst::Primary::Name(node) => node.node.as_ref(), _ => None, } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(m: ASTNode>) -> Result { - let m_node = m.node.ok_or(ToASTError::MissingNodeData)?; - let mut expr = m_node.initial.try_into()?; - for (op, node) in m_node.extended { + fn try_from(m: &ASTNode>) -> Result { + let m_node = m.ok_or_missing()?; + let mut expr = (&m_node.initial).try_into()?; + for (op, node) in &m_node.extended { let rhs = node.try_into()?; match op { cst::MultOp::Times => { expr = Expr::mul(expr, rhs); } - cst::MultOp::Divide => { - return Err(ParseError::ToAST(ToASTError::UnsupportedDivision).into()) - } - cst::MultOp::Mod => { - return Err(ParseError::ToAST(ToASTError::UnsupportedModulo).into()) - } + cst::MultOp::Divide => return Err(ToASTError::UnsupportedDivision.into()), + cst::MultOp::Mod => return Err(ToASTError::UnsupportedModulo.into()), } } Ok(expr) } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(u: ASTNode>) -> Result { - let u_node = u.node.ok_or(ToASTError::MissingNodeData)?; - let inner = u_node.item.try_into()?; + fn try_from(u: &ASTNode>) -> Result { + let u_node = u.ok_or_missing()?; + let inner = (&u_node.item).try_into()?; match u_node.op { Some(cst::NegOp::Bang(0)) => Ok(inner), Some(cst::NegOp::Bang(1)) => Ok(Expr::not(inner)), @@ -1063,12 +1055,8 @@ impl TryFrom>> for Expr { } } } - Some(cst::NegOp::OverBang) => { - Err(ParseError::ToAST(ToASTError::UnaryOpLimit(ast::UnaryOp::Not)).into()) - } - Some(cst::NegOp::OverDash) => { - Err(ParseError::ToAST(ToASTError::UnaryOpLimit(ast::UnaryOp::Neg)).into()) - } + Some(cst::NegOp::OverBang) => Err(ToASTError::UnaryOpLimit(ast::UnaryOp::Not).into()), + Some(cst::NegOp::OverDash) => Err(ToASTError::UnaryOpLimit(ast::UnaryOp::Neg).into()), None => Ok(inner), } } @@ -1081,11 +1069,11 @@ impl TryFrom>> for Expr { /// handling, because in that case it is not a valid expression. In all other /// cases a `Primary` can be converted into an `Expr`.) fn interpret_primary( - p: ASTNode>, + p: &ASTNode>, ) -> Result, ParseErrors> { - match p.node.ok_or(ToASTError::MissingNodeData)? { + match p.ok_or_missing()? { cst::Primary::Literal(lit) => Ok(Either::Right(lit.try_into()?)), - cst::Primary::Ref(node) => match node.node.ok_or(ToASTError::MissingNodeData)? { + cst::Primary::Ref(node) => match node.ok_or_missing()? { cst::Ref::Uid { path, eid } => { let mut errs = ParseErrors::new(); let maybe_name = path.to_name(&mut errs); @@ -1103,13 +1091,11 @@ fn interpret_primary( _ => Err(errs), } } - cst::Ref::Ref { .. } => { - Err(ParseError::ToAST(ToASTError::UnsupportedEntityLiterals).into()) - } + cst::Ref::Ref { .. } => Err(ToASTError::UnsupportedEntityLiterals.into()), }, cst::Primary::Name(node) => { - let name = node.node.ok_or(ToASTError::MissingNodeData)?; - let base_name = name.name.node.ok_or(ToASTError::MissingNodeData)?; + let name = node.ok_or_missing()?; + let base_name = name.name.ok_or_missing()?; match (&name.path[..], base_name) { (&[], cst::Ident::Principal) => Ok(Either::Right(Expr::var(ast::Var::Principal))), (&[], cst::Ident::Action) => Ok(Either::Right(Expr::var(ast::Var::Action))), @@ -1118,14 +1104,12 @@ fn interpret_primary( (path, cst::Ident::Ident(id)) => Ok(Either::Left(ast::Name::new( id.parse()?, path.iter() - .map(|ASTNode { node, .. }| { - node.as_ref() - .ok_or_else(|| { - Into::::into(ToASTError::MissingNodeData) - }) + .map(|node| { + node.ok_or_missing() + .map_err(Into::into) .and_then(|id| id.to_string().parse().map_err(Into::into)) }) - .collect::, _>>()?, + .collect::, ParseErrors>>()?, ))), (path, id) => { let (l, r) = match (path.first(), path.last()) { @@ -1135,15 +1119,15 @@ fn interpret_primary( ), (_, _) => (0, 0), }; - Err(ParseError::ToAST(ToASTError::InvalidExpression(cst::Name { + Err(ToASTError::InvalidExpression(cst::Name { path: path.to_vec(), - name: ASTNode::new(Some(id), l, r), - })) + name: ASTNode::new(Some(id.clone()), l, r), + }) .into()) } } } - cst::Primary::Slot(node) => match node.node.ok_or(ToASTError::MissingNodeData)? { + cst::Primary::Slot(node) => match node.ok_or_missing()? { cst::Slot::Principal => Ok(Either::Right(Expr::slot(ast::SlotId::principal()))), cst::Slot::Resource => Ok(Either::Right(Expr::slot(ast::SlotId::resource()))), }, @@ -1157,7 +1141,7 @@ fn interpret_primary( cst::Primary::RInits(nodes) => nodes .into_iter() .map(|node| { - let cst::RecInit(k, v) = node.node.ok_or(ToASTError::MissingNodeData)?; + let cst::RecInit(k, v) = node.ok_or_missing()?; let mut errs = ParseErrors::new(); let s = k .to_expr_or_special(&mut errs) @@ -1177,34 +1161,24 @@ fn interpret_primary( } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(m: ASTNode>) -> Result { - let m_node = m.node.ok_or(ToASTError::MissingNodeData)?; - let mut item: Either = interpret_primary(m_node.item)?; - for access in m_node.access { - match access.node.ok_or(ToASTError::MissingNodeData)? { - cst::MemAccess::Field(node) => { - match node.node.ok_or(ToASTError::MissingNodeData)? { - cst::Ident::Ident(i) => { - item = match item { - Either::Left(name) => { - return Err(ParseError::ToAST(ToASTError::InvalidAccess( - name, i, - )) - .into()) - } - Either::Right(expr) => Either::Right(Expr::get_attr(expr, i)), - }; - } - i => { - return Err(ParseError::ToAST(ToASTError::InvalidIdentifier( - i.to_string(), - )) - .into()) - } + fn try_from(m: &ASTNode>) -> Result { + let m_node = m.ok_or_missing()?; + let mut item: Either = interpret_primary(&m_node.item)?; + for access in &m_node.access { + match access.ok_or_missing()? { + cst::MemAccess::Field(node) => match node.ok_or_missing()? { + cst::Ident::Ident(i) => { + item = match item { + Either::Left(name) => { + return Err(ToASTError::InvalidAccess(name, i.clone()).into()) + } + Either::Right(expr) => Either::Right(Expr::get_attr(expr, i.clone())), + }; } - } + i => return Err(ToASTError::InvalidIdentifier(i.to_string()).into()), + }, cst::MemAccess::Call(args) => { // we have item(args). We hope item is either: // - an `ast::Name`, in which case we have a standard function call @@ -1249,24 +1223,22 @@ impl TryFrom>> for Expr { } } } - _ => return Err(ParseError::ToAST(ToASTError::ExpressionCall).into()), + _ => return Err(ToASTError::ExpressionCall.into()), }; } cst::MemAccess::Index(node) => { let s = Expr::try_from(node)? .into_string_literal() - .map_err(|_| ParseError::ToAST(ToASTError::NonStringIndex))?; + .map_err(|_| ToASTError::NonStringIndex)?; item = match item { - Either::Left(name) => { - return Err(ParseError::ToAST(ToASTError::InvalidIndex(name, s)).into()) - } + Either::Left(name) => return Err(ToASTError::InvalidIndex(name, s).into()), Either::Right(expr) => Either::Right(Expr::get_attr(expr, s)), }; } } } match item { - Either::Left(_) => Err(ParseError::ToAST(ToASTError::MembershipInvariantViolation))?, + Either::Left(_) => Err(ToASTError::MembershipInvariantViolation)?, Either::Right(expr) => Ok(expr), } } @@ -1280,40 +1252,37 @@ fn extract_single_argument( let first = iter.next(); let second = iter.peek(); match (first, second) { - (None, _) => Err(ParseError::ToAST(ToASTError::wrong_arity(fn_name, 1, 0)).into()), - (Some(_), Some(_)) => { - Err(ParseError::ToAST(ToASTError::wrong_arity(fn_name, 1, iter.len())).into()) - } + (None, _) => Err(ToASTError::wrong_arity(fn_name, 1, 0).into()), + (Some(_), Some(_)) => Err(ToASTError::wrong_arity(fn_name, 1, iter.len()).into()), (Some(first), None) => Ok(first), } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(lit: ASTNode>) -> Result { - match lit.node.ok_or(ToASTError::MissingNodeData)? { + fn try_from(lit: &ASTNode>) -> Result { + match lit.ok_or_missing()? { cst::Literal::True => Ok(Expr::lit(CedarValueJson::Bool(true))), cst::Literal::False => Ok(Expr::lit(CedarValueJson::Bool(false))), cst::Literal::Num(n) => Ok(Expr::lit(CedarValueJson::Long( - n.try_into() - .map_err(|_| ParseError::ToAST(ToASTError::IntegerLiteralTooLarge(n)))?, + (*n).try_into() + .map_err(|_| ToASTError::IntegerLiteralTooLarge(*n))?, ))), - cst::Literal::Str(node) => match node.node.ok_or(ToASTError::MissingNodeData)? { - cst::Str::String(s) => Ok(Expr::lit(CedarValueJson::String(s))), - cst::Str::Invalid(invalid_str) => Err(ParseError::ToAST( - ToASTError::InvalidString(invalid_str.to_string()), - ) - .into()), + cst::Literal::Str(node) => match node.ok_or_missing()? { + cst::Str::String(s) => Ok(Expr::lit(CedarValueJson::String(s.clone()))), + cst::Str::Invalid(invalid_str) => { + Err(ToASTError::InvalidString(invalid_str.to_string()).into()) + } }, } } } -impl TryFrom>> for Expr { +impl TryFrom<&ASTNode>> for Expr { type Error = ParseErrors; - fn try_from(name: ASTNode>) -> Result { - let name_node = name.node.ok_or(ToASTError::MissingNodeData)?; - let base_name = name_node.name.node.ok_or(ToASTError::MissingNodeData)?; + fn try_from(name: &ASTNode>) -> Result { + let name_node = name.ok_or_missing()?; + let base_name = name_node.name.ok_or_missing()?; match (&name_node.path[..], base_name) { (&[], cst::Ident::Principal) => Ok(Expr::var(ast::Var::Principal)), (&[], cst::Ident::Action) => Ok(Expr::var(ast::Var::Action)), @@ -1327,10 +1296,10 @@ impl TryFrom>> for Expr { ), (_, _) => (0, 0), }; - Err(ParseError::ToAST(ToASTError::InvalidExpression(cst::Name { + Err(ToASTError::InvalidExpression(cst::Name { path: path.to_vec(), - name: ASTNode::new(Some(id), l, r), - })) + name: ASTNode::new(Some(id.clone()), l, r), + }) .into()) } } @@ -1368,6 +1337,8 @@ fn ident_to_str_len(i: &Ident) -> usize { // PANIC SAFETY: Unit Test Code #[allow(clippy::panic)] mod test { + use crate::parser::err::ParseError; + use super::*; #[test] fn test_invalid_expr_from_cst_name() { @@ -1379,7 +1350,7 @@ mod test { let name = ASTNode::new(Some(cst::Ident::Else), 13, 16); let cst_name = ASTNode::new(Some(cst::Name { path, name }), 0, 16); - match Expr::try_from(cst_name) { + match Expr::try_from(&cst_name) { Ok(_) => panic!("wrong error"), Err(e) => { assert!(e.len() == 1); diff --git a/cedar-policy-core/src/parser/node.rs b/cedar-policy-core/src/parser/node.rs index 9dd09571b..ab6b5b392 100644 --- a/cedar-policy-core/src/parser/node.rs +++ b/cedar-policy-core/src/parser/node.rs @@ -23,6 +23,8 @@ use std::ops::Range; use miette::{Diagnostic, LabeledSpan, Severity, SourceCode}; use serde::{Deserialize, Serialize}; +use super::err::ToASTError; + /// Describes where in policy source code a node in the CST or expression AST /// occurs. #[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)] @@ -292,4 +294,10 @@ impl ASTNode> { { f(self.node?, self.info) } + + /// Get node data if present, or return an error result for `MissingNodeData` + /// if it is `None`. + pub fn ok_or_missing(&self) -> Result<&N, ToASTError> { + self.node.as_ref().ok_or(ToASTError::MissingNodeData) + } }