From cca79e653adc0abf3d78d08d9a51d987f15de9e1 Mon Sep 17 00:00:00 2001 From: Landon Smith Date: Wed, 21 Feb 2024 15:07:24 -0700 Subject: [PATCH] Support heterogeneous comparisons for numeric types (#37) * Support heterogeneous comparisons for numeric types * Fix complex type equality, add tests for heterogeneous equality, add int and uint functions * Fix float and int comparison behavior * Use standard error message in max * Add test for out of bounds compares with ints and uints * Add max tests for empty list/args * Simplify max implementation, fix Arguments doc string to use simpler example --- interpreter/src/context.rs | 2 + interpreter/src/functions.rs | 106 ++++++++++++++++++++++--- interpreter/src/lib.rs | 5 +- interpreter/src/magic.rs | 13 +++- interpreter/src/objects.rs | 145 ++++++++++++++++++++++++++++++----- 5 files changed, 234 insertions(+), 37 deletions(-) diff --git a/interpreter/src/context.rs b/interpreter/src/context.rs index fd50b72..4c5d270 100644 --- a/interpreter/src/context.rs +++ b/interpreter/src/context.rs @@ -160,6 +160,8 @@ impl<'a> Default for Context<'a> { ctx.add_function("double", functions::double); ctx.add_function("exists", functions::exists); ctx.add_function("exists_one", functions::exists_one); + ctx.add_function("int", functions::int); + ctx.add_function("uint", functions::uint); ctx } } diff --git a/interpreter/src/functions.rs b/interpreter/src/functions.rs index 18c0d43..8a24bf4 100644 --- a/interpreter/src/functions.rs +++ b/interpreter/src/functions.rs @@ -6,6 +6,7 @@ use crate::resolvers::{Argument, Resolver}; use crate::ExecutionError; use cel_parser::Expression; use chrono::{DateTime, Duration, FixedOffset}; +use std::cmp::Ordering; use std::convert::TryInto; use std::sync::Arc; @@ -49,7 +50,7 @@ impl<'context> FunctionContext<'context> { } /// Returns an execution error for the currently execution function. - pub fn error(&self, message: &str) -> ExecutionError { + pub fn error(&self, message: M) -> ExecutionError { ExecutionError::function_error(self.name.as_str(), message) } } @@ -74,7 +75,7 @@ pub fn size(ftx: &FunctionContext, value: Value) -> Result { Value::Map(m) => m.map.len(), Value::String(s) => s.len(), Value::Bytes(b) => b.len(), - value => return Err(ftx.error(&format!("cannot determine the size of {:?}", value))), + value => return Err(ftx.error(format!("cannot determine the size of {:?}", value))), }; Ok(size as i64) } @@ -153,18 +154,62 @@ pub fn string(ftx: &FunctionContext, This(this): This) -> Result { Value::UInt(v) => Value::String(v.to_string().into()), Value::Float(v) => Value::String(v.to_string().into()), Value::Bytes(v) => Value::String(Arc::new(String::from_utf8_lossy(v.as_slice()).into())), - v => return Err(ftx.error(&format!("cannot convert {:?} to string", v))), + v => return Err(ftx.error(format!("cannot convert {:?} to string", v))), }) } // Performs a type conversion on the target. pub fn double(ftx: &FunctionContext, This(this): This) -> Result { Ok(match this { - Value::String(v) => v.parse::().map(Value::Float).unwrap(), + Value::String(v) => v + .parse::() + .map(Value::Float) + .map_err(|e| ftx.error(format!("string parse error: {e}")))?, Value::Float(v) => Value::Float(v), Value::Int(v) => Value::Float(v as f64), Value::UInt(v) => Value::Float(v as f64), - v => return Err(ftx.error(&format!("cannot convert {:?} to double", v))), + v => return Err(ftx.error(format!("cannot convert {:?} to double", v))), + }) +} + +// Performs a type conversion on the target. +pub fn uint(ftx: &FunctionContext, This(this): This) -> Result { + Ok(match this { + Value::String(v) => v + .parse::() + .map(Value::UInt) + .map_err(|e| ftx.error(format!("string parse error: {e}")))?, + Value::Float(v) => { + if v > u64::MAX as f64 || v < u64::MIN as f64 { + return Err(ftx.error("unsigned integer overflow")); + } + Value::UInt(v as u64) + } + Value::Int(v) => Value::UInt( + v.try_into() + .map_err(|_| ftx.error("unsigned integer overflow"))?, + ), + Value::UInt(v) => Value::UInt(v), + v => return Err(ftx.error(format!("cannot convert {:?} to uint", v))), + }) +} + +// Performs a type conversion on the target. +pub fn int(ftx: &FunctionContext, This(this): This) -> Result { + Ok(match this { + Value::String(v) => v + .parse::() + .map(Value::Int) + .map_err(|e| ftx.error(format!("string parse error: {e}")))?, + Value::Float(v) => { + if v > i64::MAX as f64 || v < i64::MIN as f64 { + return Err(ftx.error("integer overflow")); + } + Value::Int(v as i64) + } + Value::Int(v) => Value::Int(v), + Value::UInt(v) => Value::Int(v.try_into().map_err(|_| ftx.error("integer overflow"))?), + v => return Err(ftx.error(format!("cannot convert {:?} to int", v))), }) } @@ -434,13 +479,26 @@ pub fn timestamp(value: Arc) -> Result { pub fn max(Arguments(args): Arguments) -> Result { // If items is a list of values, then operate on the list - if args.len() == 1 { - return Ok(match args[0] { - Value::List(ref values) => values.iter().max().cloned().unwrap_or(Value::Null), - _ => args[0].clone(), - }); - } - args.iter().max().cloned().unwrap_or(Value::Null).into() + let items = if args.len() == 1 { + match &args[0] { + Value::List(values) => values, + _ => return Ok(args[0].clone()), + } + } else { + &args + }; + + items + .iter() + .skip(1) + .try_fold(items.first().unwrap_or(&Value::Null), |acc, x| { + match acc.partial_cmp(x) { + Some(Ordering::Greater) => Ok(acc), + Some(_) => Ok(x), + None => Err(ExecutionError::ValuesNotComparable(acc.clone(), x.clone())), + } + }) + .map(|v| v.clone()) } /// A wrapper around [`parse_duration`] that converts errors into [`ExecutionError`]. @@ -557,6 +615,8 @@ mod tests { ("max negative", "max(-1, 0) == 0"), ("max float", "max(-1.0, 0.0) == 0.0"), ("max list", "max([1, 2, 3]) == 3"), + ("max empty list", "max([]) == null"), + ("max no args", "max() == null"), ] .iter() .for_each(assert_script); @@ -650,4 +710,26 @@ mod tests { .iter() .for_each(assert_script); } + + #[test] + fn test_uint() { + [ + ("string", "'10'.uint() == 10.uint()"), + ("double", "10.5.uint() == 10.uint()"), + ] + .iter() + .for_each(assert_script); + } + + #[test] + fn test_int() { + [ + ("string", "'10'.int() == 10"), + ("int", "10.int() == 10"), + ("uint", "10.uint().int() == 10"), + ("double", "10.5.int() == 10"), + ] + .iter() + .for_each(assert_script); + } } diff --git a/interpreter/src/lib.rs b/interpreter/src/lib.rs index 2aac94c..a5a50be 100644 --- a/interpreter/src/lib.rs +++ b/interpreter/src/lib.rs @@ -58,6 +58,9 @@ pub enum ExecutionError { /// called with at least one parameter. #[error("Missing argument or target")] MissingArgumentOrTarget, + /// Indicates that a comparison could not be performed. + #[error("{0:?} can not be compared to {1:?}")] + ValuesNotComparable(Value, Value), /// Indicates that a function had an error during execution. #[error("Error executing function '{function}': {message}")] FunctionError { function: String, message: String }, @@ -76,7 +79,7 @@ impl ExecutionError { ExecutionError::InvalidArgumentCount { expected, actual } } - pub fn function_error(function: &str, error: &str) -> Self { + pub fn function_error(function: &str, error: E) -> Self { ExecutionError::FunctionError { function: function.to_string(), message: error.to_string(), diff --git a/interpreter/src/magic.rs b/interpreter/src/magic.rs index 8f89525..28cefd9 100644 --- a/interpreter/src/magic.rs +++ b/interpreter/src/magic.rs @@ -217,19 +217,24 @@ impl FromValue for List { /// An argument extractor that extracts all the arguments passed to a function, resolves their /// expressions and returns a vector of [`Value`]. This is useful for functions that accept a -/// variable number of arguments rather than known arguments and types (for example the `max` +/// variable number of arguments rather than known arguments and types (for example a `sum` /// function). /// /// # Example /// ```javascript -/// max(1, 2, 3) == 3 +/// sum(1, 2.0, uint(3)) == 5.0 /// ``` /// /// ```rust /// # use cel_interpreter::{Value}; /// use cel_interpreter::extractors::Arguments; -/// pub fn max(Arguments(args): Arguments) -> Value { -/// args.iter().max().cloned().unwrap_or(Value::Null).into() +/// pub fn sum(Arguments(args): Arguments) -> Value { +/// args.iter().fold(0.0, |acc, val| match val { +/// Value::Int(x) => *x as f64 + acc, +/// Value::UInt(x) => *x as f64 + acc, +/// Value::Float(x) => *x + acc, +/// _ => acc, +/// }).into() /// } /// ``` #[derive(Clone)] diff --git a/interpreter/src/objects.rs b/interpreter/src/objects.rs index 8005cf4..725d984 100644 --- a/interpreter/src/objects.rs +++ b/interpreter/src/objects.rs @@ -141,7 +141,7 @@ impl TryIntoValue for &Key { } } -#[derive(Debug, Clone, PartialEq)] +#[derive(Debug, Clone)] pub enum Value { List(Arc>), Map(Map), @@ -226,25 +226,73 @@ impl From<&Value> for Value { } } +impl PartialEq for Value { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Value::Map(a), Value::Map(b)) => a == b, + (Value::List(a), Value::List(b)) => a == b, + (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2, + (Value::Int(a), Value::Int(b)) => a == b, + (Value::UInt(a), Value::UInt(b)) => a == b, + (Value::Float(a), Value::Float(b)) => a == b, + (Value::String(a), Value::String(b)) => a == b, + (Value::Bytes(a), Value::Bytes(b)) => a == b, + (Value::Bool(a), Value::Bool(b)) => a == b, + (Value::Null, Value::Null) => true, + (Value::Duration(a), Value::Duration(b)) => a == b, + (Value::Timestamp(a), Value::Timestamp(b)) => a == b, + // Allow different numeric types to be compared without explicit casting. + (Value::Int(a), Value::UInt(b)) => a + .to_owned() + .try_into() + .and_then(|a: u64| Ok(a == *b)) + .unwrap_or(false), + (Value::Int(a), Value::Float(b)) => (*a as f64) == *b, + (Value::UInt(a), Value::Int(b)) => a + .to_owned() + .try_into() + .and_then(|a: i64| Ok(a == *b)) + .unwrap_or(false), + (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b, + (Value::Float(a), Value::Int(b)) => *a == (*b as f64), + (Value::Float(a), Value::UInt(b)) => *a == (*b as f64), + (a, b) => panic!("unable to compare {:?} with {:?}", a, b), + } + } +} + impl Eq for Value {} impl PartialOrd for Value { fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Value { - fn cmp(&self, other: &Self) -> Ordering { match (self, other) { - (Value::Int(a), Value::Int(b)) => a.cmp(b), - (Value::UInt(a), Value::UInt(b)) => a.cmp(b), - (Value::Float(a), Value::Float(b)) => a.partial_cmp(b).unwrap_or(Ordering::Equal), - (Value::String(a), Value::String(b)) => a.cmp(b), - (Value::Bool(a), Value::Bool(b)) => a.cmp(b), - (Value::Null, Value::Null) => Ordering::Equal, - (Value::Duration(a), Value::Duration(b)) => a.cmp(b), - (Value::Timestamp(a), Value::Timestamp(b)) => a.cmp(b), + (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)), + (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)), + (Value::Float(a), Value::Float(b)) => a.partial_cmp(b), + (Value::String(a), Value::String(b)) => Some(a.cmp(b)), + (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)), + (Value::Null, Value::Null) => Some(Ordering::Equal), + (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)), + (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)), + // Allow different numeric types to be compared without explicit casting. + (Value::Int(a), Value::UInt(b)) => Some( + a.to_owned() + .try_into() + .and_then(|a: u64| Ok(a.cmp(b))) + // If the i64 doesn't fit into a u64 it must be less than 0. + .unwrap_or(Ordering::Less), + ), + (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b), + (Value::UInt(a), Value::Int(b)) => Some( + a.to_owned() + .try_into() + .and_then(|a: i64| Ok(a.cmp(b))) + // If the u64 doesn't fit into a i64 it must be greater than i64::MAX. + .unwrap_or(Ordering::Greater), + ), + (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b), + (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)), + (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)), (a, b) => panic!("unable to compare {:?} with {:?}", a, b), } } @@ -366,10 +414,26 @@ impl<'a> Value { let left = Value::resolve(left, ctx)?; let right = Value::resolve(right, ctx)?; let res = match op { - RelationOp::LessThan => left < right, - RelationOp::LessThanEq => left <= right, - RelationOp::GreaterThan => left > right, - RelationOp::GreaterThanEq => left >= right, + RelationOp::LessThan => { + left.partial_cmp(&right) + .ok_or(ExecutionError::ValuesNotComparable(left, right))? + == Ordering::Less + } + RelationOp::LessThanEq => { + left.partial_cmp(&right) + .ok_or(ExecutionError::ValuesNotComparable(left, right))? + != Ordering::Greater + } + RelationOp::GreaterThan => { + left.partial_cmp(&right) + .ok_or(ExecutionError::ValuesNotComparable(left, right))? + == Ordering::Greater + } + RelationOp::GreaterThanEq => { + left.partial_cmp(&right) + .ok_or(ExecutionError::ValuesNotComparable(left, right))? + != Ordering::Less + } RelationOp::Equals => right.eq(&left), RelationOp::NotEquals => right.ne(&left), RelationOp::In => match (left, right) { @@ -696,7 +760,7 @@ impl ops::Rem for Value { #[cfg(test)] mod tests { - use crate::{Context, Program}; + use crate::{objects::Key, Context, Program}; use std::collections::HashMap; #[test] @@ -710,4 +774,45 @@ mod tests { let value = program.execute(&context).unwrap(); assert_eq!(value, "application/json".into()); } + + #[test] + fn test_heterogeneous_compare() { + let context = Context::default(); + + let program = Program::compile("1 < uint(2)").unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!(value, true.into()); + + let program = Program::compile("1 < 1.1").unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!(value, true.into()); + + let program = Program::compile("uint(0) > -10").unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!( + value, + true.into(), + "negative signed ints should be less than uints" + ); + } + + #[test] + fn test_float_compare() { + let context = Context::default(); + + let program = Program::compile("1.0 > 0.0").unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!(value, true.into()); + + let program = Program::compile("double('NaN') == double('NaN')").unwrap(); + let value = program.execute(&context).unwrap(); + assert_eq!(value, false.into(), "NaN should not equal itself"); + + let program = Program::compile("1.0 > double('NaN')").unwrap(); + let result = program.execute(&context); + assert!( + result.is_err(), + "NaN should not be comparable with inequality operators" + ); + } }