diff --git a/engine/src/ast/field_expr.rs b/engine/src/ast/field_expr.rs index c243d379..3c7f8c5d 100644 --- a/engine/src/ast/field_expr.rs +++ b/engine/src/ast/field_expr.rs @@ -309,9 +309,9 @@ mod tests { use ast::function_expr::{FunctionCallArgExpr, FunctionCallExpr}; use cidr::{Cidr, IpCidr}; use execution_context::ExecutionContext; - use functions::{Function, FunctionArg, FunctionArgKind, FunctionImpl}; + use functions::{Function, FunctionArg, FunctionArgKind, FunctionImpl, FunctionOptArg}; use lazy_static::lazy_static; - use rhs_types::IpRange; + use rhs_types::{Bytes, IpRange}; use std::net::IpAddr; fn echo_function<'a>(args: &[LhsValue<'a>]) -> LhsValue<'a> { @@ -330,6 +330,20 @@ mod tests { } } + fn concat_function<'a, 'r>(args: &'r [LhsValue<'a>]) -> LhsValue<'a> { + match (&args[0], &args[1]) { + (LhsValue::Bytes(buf1), LhsValue::Bytes(buf2)) => { + let mut vec1 = buf1.to_vec(); + vec1.extend_from_slice(&*buf2); + LhsValue::Bytes(vec1.into()) + } + _ => panic!( + "Invalid types: expected (Bytes, Bytes), got ({:?}, {:?})", + args[0], args[1] + ), + } + } + lazy_static! { static ref SCHEME: Scheme = { let mut scheme: Scheme = Scheme! { @@ -367,6 +381,23 @@ mod tests { ) .unwrap(); scheme + .add_function( + "concat".into(), + Function { + args: vec![FunctionArg { + arg_kind: FunctionArgKind::Field, + val_type: Type::Bytes, + }], + opt_args: vec![FunctionOptArg { + arg_kind: FunctionArgKind::Literal, + default_value: RhsValue::Bytes(Bytes::from("".to_owned())), + }], + return_type: Type::Bytes, + implementation: FunctionImpl::new(concat_function), + }, + ) + .unwrap(); + scheme }; } @@ -889,4 +920,105 @@ mod tests { ctx.set_field_value("http.host", "EXAMPLE.ORG").unwrap(); assert_eq!(expr.execute(ctx), true); } + + #[test] + fn test_bytes_compare_with_concat_function() { + let expr = assert_ok!( + FieldExpr::lex_with(r#"concat(http.host) == "example.org""#, &SCHEME), + FieldExpr { + lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { + name: String::from("concat"), + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field(field("http.host"))), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from("".to_owned()))), + ], + }), + op: FieldOp::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_json!( + expr, + { + "lhs": { + "name": "concat", + "args": [ + { + "kind": "LhsFieldExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": "" + }, + ] + }, + "op": "Equal", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value("http.host", "example.org").unwrap(); + assert_eq!(expr.execute(ctx), true); + + ctx.set_field_value("http.host", "example.co.uk").unwrap(); + assert_eq!(expr.execute(ctx), false); + + let expr = assert_ok!( + FieldExpr::lex_with(r#"concat(http.host, ".org") == "example.org""#, &SCHEME), + FieldExpr { + lhs: LhsFieldExpr::FunctionCallExpr(FunctionCallExpr { + name: String::from("concat"), + function: SCHEME.get_function("concat").unwrap(), + args: vec![ + FunctionCallArgExpr::LhsFieldExpr(LhsFieldExpr::Field(field("http.host"))), + FunctionCallArgExpr::Literal(RhsValue::Bytes(Bytes::from( + ".org".to_owned() + ))), + ], + }), + op: FieldOp::Ordering { + op: OrderingOp::Equal, + rhs: RhsValue::Bytes("example.org".to_owned().into()) + } + } + ); + + assert_json!( + expr, + { + "lhs": { + "name": "concat", + "args": [ + { + "kind": "LhsFieldExpr", + "value": "http.host" + }, + { + "kind": "Literal", + "value": ".org" + }, + ] + }, + "op": "Equal", + "rhs": "example.org" + } + ); + + let expr = expr.compile(); + let ctx = &mut ExecutionContext::new(&SCHEME); + + ctx.set_field_value("http.host", "example").unwrap(); + assert_eq!(expr.execute(ctx), true); + + ctx.set_field_value("http.host", "cloudflare").unwrap(); + assert_eq!(expr.execute(ctx), false); + } }