Skip to content

Commit

Permalink
Support heterogeneous comparisons for numeric types (#37)
Browse files Browse the repository at this point in the history
* Support heterogeneous comparisons for numeric types

* Fix complex type equality, add tests for heterogeneous equality, add int and uint functions

* Fix float and int comparison behavior

* Use standard error message in max

* Add test for out of bounds compares with ints and uints

* Add max tests for empty list/args

* Simplify max implementation, fix Arguments doc string to use simpler example
  • Loading branch information
fore5fire authored Feb 21, 2024
1 parent 2c8dd61 commit cca79e6
Show file tree
Hide file tree
Showing 5 changed files with 234 additions and 37 deletions.
2 changes: 2 additions & 0 deletions interpreter/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ impl<'a> Default for Context<'a> {
ctx.add_function("double", functions::double);
ctx.add_function("exists", functions::exists);
ctx.add_function("exists_one", functions::exists_one);
ctx.add_function("int", functions::int);
ctx.add_function("uint", functions::uint);
ctx
}
}
106 changes: 94 additions & 12 deletions interpreter/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use crate::resolvers::{Argument, Resolver};
use crate::ExecutionError;
use cel_parser::Expression;
use chrono::{DateTime, Duration, FixedOffset};
use std::cmp::Ordering;
use std::convert::TryInto;
use std::sync::Arc;

Expand Down Expand Up @@ -49,7 +50,7 @@ impl<'context> FunctionContext<'context> {
}

/// Returns an execution error for the currently execution function.
pub fn error(&self, message: &str) -> ExecutionError {
pub fn error<M: ToString>(&self, message: M) -> ExecutionError {
ExecutionError::function_error(self.name.as_str(), message)
}
}
Expand All @@ -74,7 +75,7 @@ pub fn size(ftx: &FunctionContext, value: Value) -> Result<i64> {
Value::Map(m) => m.map.len(),
Value::String(s) => s.len(),
Value::Bytes(b) => b.len(),
value => return Err(ftx.error(&format!("cannot determine the size of {:?}", value))),
value => return Err(ftx.error(format!("cannot determine the size of {:?}", value))),
};
Ok(size as i64)
}
Expand Down Expand Up @@ -153,18 +154,62 @@ pub fn string(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
Value::UInt(v) => Value::String(v.to_string().into()),
Value::Float(v) => Value::String(v.to_string().into()),
Value::Bytes(v) => Value::String(Arc::new(String::from_utf8_lossy(v.as_slice()).into())),
v => return Err(ftx.error(&format!("cannot convert {:?} to string", v))),
v => return Err(ftx.error(format!("cannot convert {:?} to string", v))),
})
}

// Performs a type conversion on the target.
pub fn double(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
Ok(match this {
Value::String(v) => v.parse::<f64>().map(Value::Float).unwrap(),
Value::String(v) => v
.parse::<f64>()
.map(Value::Float)
.map_err(|e| ftx.error(format!("string parse error: {e}")))?,
Value::Float(v) => Value::Float(v),
Value::Int(v) => Value::Float(v as f64),
Value::UInt(v) => Value::Float(v as f64),
v => return Err(ftx.error(&format!("cannot convert {:?} to double", v))),
v => return Err(ftx.error(format!("cannot convert {:?} to double", v))),
})
}

// Performs a type conversion on the target.
pub fn uint(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
Ok(match this {
Value::String(v) => v
.parse::<u64>()
.map(Value::UInt)
.map_err(|e| ftx.error(format!("string parse error: {e}")))?,
Value::Float(v) => {
if v > u64::MAX as f64 || v < u64::MIN as f64 {
return Err(ftx.error("unsigned integer overflow"));
}
Value::UInt(v as u64)
}
Value::Int(v) => Value::UInt(
v.try_into()
.map_err(|_| ftx.error("unsigned integer overflow"))?,
),
Value::UInt(v) => Value::UInt(v),
v => return Err(ftx.error(format!("cannot convert {:?} to uint", v))),
})
}

// Performs a type conversion on the target.
pub fn int(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
Ok(match this {
Value::String(v) => v
.parse::<i64>()
.map(Value::Int)
.map_err(|e| ftx.error(format!("string parse error: {e}")))?,
Value::Float(v) => {
if v > i64::MAX as f64 || v < i64::MIN as f64 {
return Err(ftx.error("integer overflow"));
}
Value::Int(v as i64)
}
Value::Int(v) => Value::Int(v),
Value::UInt(v) => Value::Int(v.try_into().map_err(|_| ftx.error("integer overflow"))?),
v => return Err(ftx.error(format!("cannot convert {:?} to int", v))),
})
}

Expand Down Expand Up @@ -434,13 +479,26 @@ pub fn timestamp(value: Arc<String>) -> Result<Value> {

pub fn max(Arguments(args): Arguments) -> Result<Value> {
// If items is a list of values, then operate on the list
if args.len() == 1 {
return Ok(match args[0] {
Value::List(ref values) => values.iter().max().cloned().unwrap_or(Value::Null),
_ => args[0].clone(),
});
}
args.iter().max().cloned().unwrap_or(Value::Null).into()
let items = if args.len() == 1 {
match &args[0] {
Value::List(values) => values,
_ => return Ok(args[0].clone()),
}
} else {
&args
};

items
.iter()
.skip(1)
.try_fold(items.first().unwrap_or(&Value::Null), |acc, x| {
match acc.partial_cmp(x) {
Some(Ordering::Greater) => Ok(acc),
Some(_) => Ok(x),
None => Err(ExecutionError::ValuesNotComparable(acc.clone(), x.clone())),
}
})
.map(|v| v.clone())
}

/// A wrapper around [`parse_duration`] that converts errors into [`ExecutionError`].
Expand Down Expand Up @@ -557,6 +615,8 @@ mod tests {
("max negative", "max(-1, 0) == 0"),
("max float", "max(-1.0, 0.0) == 0.0"),
("max list", "max([1, 2, 3]) == 3"),
("max empty list", "max([]) == null"),
("max no args", "max() == null"),
]
.iter()
.for_each(assert_script);
Expand Down Expand Up @@ -650,4 +710,26 @@ mod tests {
.iter()
.for_each(assert_script);
}

#[test]
fn test_uint() {
[
("string", "'10'.uint() == 10.uint()"),
("double", "10.5.uint() == 10.uint()"),
]
.iter()
.for_each(assert_script);
}

#[test]
fn test_int() {
[
("string", "'10'.int() == 10"),
("int", "10.int() == 10"),
("uint", "10.uint().int() == 10"),
("double", "10.5.int() == 10"),
]
.iter()
.for_each(assert_script);
}
}
5 changes: 4 additions & 1 deletion interpreter/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ pub enum ExecutionError {
/// called with at least one parameter.
#[error("Missing argument or target")]
MissingArgumentOrTarget,
/// Indicates that a comparison could not be performed.
#[error("{0:?} can not be compared to {1:?}")]
ValuesNotComparable(Value, Value),
/// Indicates that a function had an error during execution.
#[error("Error executing function '{function}': {message}")]
FunctionError { function: String, message: String },
Expand All @@ -76,7 +79,7 @@ impl ExecutionError {
ExecutionError::InvalidArgumentCount { expected, actual }
}

pub fn function_error(function: &str, error: &str) -> Self {
pub fn function_error<E: ToString>(function: &str, error: E) -> Self {
ExecutionError::FunctionError {
function: function.to_string(),
message: error.to_string(),
Expand Down
13 changes: 9 additions & 4 deletions interpreter/src/magic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,19 +217,24 @@ impl FromValue for List {

/// An argument extractor that extracts all the arguments passed to a function, resolves their
/// expressions and returns a vector of [`Value`]. This is useful for functions that accept a
/// variable number of arguments rather than known arguments and types (for example the `max`
/// variable number of arguments rather than known arguments and types (for example a `sum`
/// function).
///
/// # Example
/// ```javascript
/// max(1, 2, 3) == 3
/// sum(1, 2.0, uint(3)) == 5.0
/// ```
///
/// ```rust
/// # use cel_interpreter::{Value};
/// use cel_interpreter::extractors::Arguments;
/// pub fn max(Arguments(args): Arguments) -> Value {
/// args.iter().max().cloned().unwrap_or(Value::Null).into()
/// pub fn sum(Arguments(args): Arguments) -> Value {
/// args.iter().fold(0.0, |acc, val| match val {
/// Value::Int(x) => *x as f64 + acc,
/// Value::UInt(x) => *x as f64 + acc,
/// Value::Float(x) => *x + acc,
/// _ => acc,
/// }).into()
/// }
/// ```
#[derive(Clone)]
Expand Down
Loading

0 comments on commit cca79e6

Please sign in to comment.