diff --git a/src/ast.rs b/src/ast.rs index b4988d8..323b333 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -166,28 +166,61 @@ impl EvaluateResult { } } -macro_rules! impl_arithmetic { - ($trait:ident, $method:ident, $op:tt) => { - impl $trait for EvaluateResult { - type Output = EvaluateResult; - - fn $method(self, other: EvaluateResult) -> EvaluateResult { - use EvaluateResult::*; - match (self, other) { - (Number(x), Number(y)) => EvaluateResult::Number(x $op y), - (Number(x), Array(y)) => EvaluateResult::Array(x $op y), - (Array(x), Number(y)) => EvaluateResult::Array(x $op y), - (Array(x), Array(y)) => EvaluateResult::Array(x $op y), - } - } +impl Mul for EvaluateResult { + type Output = EvaluateResult; + + fn mul(self, other: EvaluateResult) -> EvaluateResult { + use EvaluateResult::*; + match (self, other) { + (Number(x), Number(y)) => EvaluateResult::Number(x * y), + (Number(x), Array(y)) => EvaluateResult::Array(x * y), + (Array(x), Number(y)) => EvaluateResult::Array(x * y), + (Array(x), Array(y)) => EvaluateResult::Array(x * y), + } + } +} + +impl Div for EvaluateResult { + type Output = EvaluateResult; + + fn div(self, other: EvaluateResult) -> EvaluateResult { + use EvaluateResult::*; + match (self, other) { + (Number(x), Number(y)) => EvaluateResult::Number(x / y), + (Number(x), Array(y)) => EvaluateResult::Array(x / y), + (Array(x), Number(y)) => EvaluateResult::Array(x / y), + (Array(x), Array(y)) => EvaluateResult::Array(x / y), + } + } +} + +impl Add for EvaluateResult { + type Output = EvaluateResult; + + fn add(self, other: EvaluateResult) -> EvaluateResult { + use EvaluateResult::*; + match (self, other) { + (Number(x), Number(y)) => EvaluateResult::Number(x + y), + (Number(x), Array(y)) => EvaluateResult::Array(x + y), + (Array(x), Number(y)) => EvaluateResult::Array(x + y), + (Array(x), Array(y)) => EvaluateResult::Array(x + y), } - }; + } } -impl_arithmetic!(Mul, mul, *); -impl_arithmetic!(Div, div, /); -impl_arithmetic!(Add, add, +); -impl_arithmetic!(Sub, sub, -); +impl Sub for EvaluateResult { + type Output = EvaluateResult; + + fn sub(self, other: EvaluateResult) -> EvaluateResult { + use EvaluateResult::*; + match (self, other) { + (Number(x), Number(y)) => EvaluateResult::Number(x - y), + (Number(x), Array(y)) => EvaluateResult::Array(x - y), + (Array(x), Number(y)) => EvaluateResult::Array(x - y), + (Array(x), Array(y)) => EvaluateResult::Array(x - y), + } + } +} impl Expr<'_> { pub fn get_representation<'a>(self) -> Result<&'a str, Box> { diff --git a/src/lib.rs b/src/lib.rs index b6f8703..0996565 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,3 @@ -#![feature(test)] -extern crate test; - use crate::ast::Expr; use crate::ast::ExprParams; use ast::EvaluateResult; @@ -25,132 +22,66 @@ extern crate lalrpop_util; lalrpop_mod!(formula_parser); mod ast; -#[cfg(test)] -mod tests { - use super::*; +#[test] +fn basic_execution_test() { use crate::ast::Expr::*; use crate::ast::Opcode; - use numpy::ndarray::Array; - use test::Bencher; - - #[test] - fn basic_execution_test() { - assert!(formula_parser::FormulaParser::new() - .parse(" eps = 22") - .is_ok()); - assert!(formula_parser::FormulaParser::new() - .parse("n = (22)") - .is_ok()); - assert!(formula_parser::FormulaParser::new() - .parse("eps = (22)") - .is_ok()); - let expr = formula_parser::FormulaParser::new() - .parse("n = 22 * 44 + 66") - .unwrap(); - assert_eq!( - Index(Box::new(Op( - Box::new(Op( - Box::new(Number(22.)), - Opcode::Mul, - Box::new(Number(44.)) - )), - Opcode::Add, - Box::new(Number(66.)), - ))), - *expr - ); - assert_eq!(&format!("{:?}", expr), "n = ((22.0 * 44.0) + 66.0)"); - let expr = formula_parser::FormulaParser::new() - .parse("eps = 3 * 22 ** 4") - .unwrap(); - assert_eq!(&format!("{:?}", expr), "eps = (3.0 * (22.0 ** 4.0))"); - let expr = formula_parser::FormulaParser::new() - .parse("eps = 3 * lbda ** 4") - .unwrap(); - assert_eq!(&format!("{:?}", expr), "eps = (3.0 * (lbda ** 4.0))"); - let expr = formula_parser::FormulaParser::new() - .parse("eps = sum[param]") - .unwrap(); - assert_eq!(&format!("{:?}", expr), "eps = sum[r_param]"); - assert!(formula_parser::FormulaParser::new() - .parse("n = ((((22))))") - .is_ok()); - assert!(formula_parser::FormulaParser::new() - .parse("n = sum[2 * 3] + sum[4*5]") - .is_ok()); - assert!(formula_parser::FormulaParser::new() - .parse("n = sum[sum [ 2 * lbda ] * 3] + sum[4*5]") - .is_err()); - assert!(formula_parser::FormulaParser::new() - .parse("n = ((((22))))") - .is_ok()); - assert!(formula_parser::FormulaParser::new() - .parse("eps = ((22)") - .is_err()); - assert!(formula_parser::FormulaParser::new() - .parse("something = ((22)") - .is_err()); - assert!(formula_parser::FormulaParser::new().parse("(22)").is_err()); - } - - #[bench] - fn formula_parsing(b: &mut Bencher) { - let arr = Array::linspace(400., 1000., 100000); - let arr_view = arr.view(); - let single_params = HashMap::new(); - let mut rep_params = HashMap::new(); - - rep_params.insert("A", vec![1., 1., 1.]); - rep_params.insert("B", vec![0.1, 0.1, 0.1]); - b.iter(|| { - parse( - "eps = 1 + sum[A * (lbda * 1e-3)**2 / ((lbda * 1e-3)**2 - B)]", - "lbda", - &arr_view, - &single_params, - &rep_params, - ) - }); - } - - #[bench] - fn constant_value(b: &mut Bencher) { - let arr = Array::linspace(400., 1000., 100000); - let arr_view = arr.view(); - let single_params = HashMap::new(); - let rep_params = HashMap::new(); - - b.iter(|| parse("eps = 1", "lbda", &arr_view, &single_params, &rep_params)); - } - - #[bench] - fn return_axis(b: &mut Bencher) { - let arr = Array::linspace(400., 1000., 100000); - let arr_view = arr.view(); - let single_params = HashMap::new(); - let rep_params = HashMap::new(); - - b.iter(|| parse("eps = lbda", "lbda", &arr_view, &single_params, &rep_params)); - } - - #[bench] - fn simple_sum(b: &mut Bencher) { - let arr = Array::linspace(400., 1000., 100000); - let arr_view = arr.view(); - let single_params = HashMap::new(); - let mut rep_params = HashMap::new(); - rep_params.insert("A", vec![1., 2., 3., 4.]); - - b.iter(|| { - parse( - "eps = sum[A*lbda]", - "lbda", - &arr_view, - &single_params, - &rep_params, - ) - }); - } + assert!(formula_parser::FormulaParser::new() + .parse(" eps = 22") + .is_ok()); + assert!(formula_parser::FormulaParser::new() + .parse("n = (22)") + .is_ok()); + assert!(formula_parser::FormulaParser::new() + .parse("eps = (22)") + .is_ok()); + let expr = formula_parser::FormulaParser::new() + .parse("n = 22 * 44 + 66") + .unwrap(); + assert_eq!( + Index(Box::new(Op( + Box::new(Op( + Box::new(Number(22.)), + Opcode::Mul, + Box::new(Number(44.)) + )), + Opcode::Add, + Box::new(Number(66.)), + ))), + *expr + ); + assert_eq!(&format!("{:?}", expr), "n = ((22.0 * 44.0) + 66.0)"); + let expr = formula_parser::FormulaParser::new() + .parse("eps = 3 * 22 ** 4") + .unwrap(); + assert_eq!(&format!("{:?}", expr), "eps = (3.0 * (22.0 ** 4.0))"); + let expr = formula_parser::FormulaParser::new() + .parse("eps = 3 * lbda ** 4") + .unwrap(); + assert_eq!(&format!("{:?}", expr), "eps = (3.0 * (lbda ** 4.0))"); + let expr = formula_parser::FormulaParser::new() + .parse("eps = sum[param]") + .unwrap(); + assert_eq!(&format!("{:?}", expr), "eps = sum[r_param]"); + assert!(formula_parser::FormulaParser::new() + .parse("n = ((((22))))") + .is_ok()); + assert!(formula_parser::FormulaParser::new() + .parse("n = sum[2 * 3] + sum[4*5]") + .is_ok()); + assert!(formula_parser::FormulaParser::new() + .parse("n = sum[sum [ 2 * lbda ] * 3] + sum[4*5]") + .is_err()); + assert!(formula_parser::FormulaParser::new() + .parse("n = ((((22))))") + .is_ok()); + assert!(formula_parser::FormulaParser::new() + .parse("eps = ((22)") + .is_err()); + assert!(formula_parser::FormulaParser::new() + .parse("something = ((22)") + .is_err()); + assert!(formula_parser::FormulaParser::new().parse("(22)").is_err()); } #[cached]