From 65f449b7689efab199fa36c5d40cc04542ed4076 Mon Sep 17 00:00:00 2001 From: Chloe Date: Tue, 9 Nov 2021 09:42:59 -0800 Subject: [PATCH] Optimized type converting in DSL filters (#272) * optimized cast in filter queries Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh * added unit tests Signed-off-by: chloe-zh * update Signed-off-by: chloe-zh --- .../sql/data/model/ExprTimestampValue.java | 4 +- .../operator/convert/TypeCastOperator.java | 4 +- .../script/filter/lucene/LuceneQuery.java | 151 +++++++- .../script/filter/lucene/RangeQuery.java | 11 +- .../script/filter/lucene/TermQuery.java | 11 +- .../script/filter/FilterQueryBuilderTest.java | 344 ++++++++++++++++++ 6 files changed, 516 insertions(+), 9 deletions(-) diff --git a/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java b/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java index f3ffdb7c42..a6bb92aca4 100644 --- a/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java +++ b/core/src/main/java/org/opensearch/sql/data/model/ExprTimestampValue.java @@ -56,7 +56,7 @@ public class ExprTimestampValue extends AbstractExprValue { /** * todo. only support timestamp in format yyyy-MM-dd HH:mm:ss. */ - private static final DateTimeFormatter FORMATTER_WITNOUT_NANO = DateTimeFormatter + private static final DateTimeFormatter FORMATTER_WITHOUT_NANO = DateTimeFormatter .ofPattern("yyyy-MM-dd HH:mm:ss"); private final Instant timestamp; @@ -92,7 +92,7 @@ public ExprTimestampValue(String timestamp) { @Override public String value() { - return timestamp.getNano() == 0 ? FORMATTER_WITNOUT_NANO.withZone(ZONE) + return timestamp.getNano() == 0 ? FORMATTER_WITHOUT_NANO.withZone(ZONE) .format(timestamp.truncatedTo(ChronoUnit.SECONDS)) : FORMATTER_VARIABLE_MICROS.withZone(ZONE).format(timestamp); } diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java index c6a84985a0..6eaa75ee46 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java @@ -101,9 +101,9 @@ private static FunctionResolver castToString() { private static FunctionResolver castToByte() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BYTE.getName(), impl(nullMissingHandling( - (v) -> new ExprByteValue(Short.valueOf(v.stringValue()))), BYTE, STRING), + (v) -> new ExprByteValue(Byte.valueOf(v.stringValue()))), BYTE, STRING), impl(nullMissingHandling( - (v) -> new ExprByteValue(v.shortValue())), BYTE, DOUBLE), + (v) -> new ExprByteValue(v.byteValue())), BYTE, DOUBLE), impl(nullMissingHandling( (v) -> new ExprByteValue(v.booleanValue() ? 1 : 0)), BYTE, BOOLEAN) ); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java index 3d2ed8720f..80c58d2ac9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQuery.java @@ -29,14 +29,32 @@ import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD; +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import java.util.function.Function; import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.data.model.ExprBooleanValue; +import org.opensearch.sql.data.model.ExprByteValue; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprDoubleValue; +import org.opensearch.sql.data.model.ExprFloatValue; +import org.opensearch.sql.data.model.ExprIntegerValue; +import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprShortValue; +import org.opensearch.sql.data.model.ExprStringValue; +import org.opensearch.sql.data.model.ExprTimeValue; +import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.FunctionName; /** * Lucene query abstraction that builds Lucene query from function expression. @@ -55,7 +73,8 @@ public abstract class LuceneQuery { public boolean canSupport(FunctionExpression func) { return (func.getArguments().size() == 2) && (func.getArguments().get(0) instanceof ReferenceExpression) - && (func.getArguments().get(1) instanceof LiteralExpression) + && (func.getArguments().get(1) instanceof LiteralExpression + || literalExpressionWrappedByCast(func)) || isMultiParameterQuery(func); } @@ -74,18 +93,144 @@ private boolean isMultiParameterQuery(FunctionExpression func) { return true; } + /** + * Check if the second argument of the function is a literal expression wrapped by cast function. + */ + private boolean literalExpressionWrappedByCast(FunctionExpression func) { + if (func.getArguments().get(1) instanceof FunctionExpression) { + FunctionExpression expr = (FunctionExpression) func.getArguments().get(1); + return castMap.containsKey(expr.getFunctionName()) + && expr.getArguments().get(0) instanceof LiteralExpression; + } + return false; + } + /** * Build Lucene query from function expression. + * The cast function is converted to literal expressions before generating DSL. * * @param func function * @return query */ public QueryBuilder build(FunctionExpression func) { ReferenceExpression ref = (ReferenceExpression) func.getArguments().get(0); - LiteralExpression literal = (LiteralExpression) func.getArguments().get(1); - return doBuild(ref.getAttr(), ref.type(), literal.valueOf(null)); + Expression expr = func.getArguments().get(1); + ExprValue literalValue = expr instanceof LiteralExpression ? expr + .valueOf(null) : cast((FunctionExpression) expr); + return doBuild(ref.getAttr(), ref.type(), literalValue); } + private ExprValue cast(FunctionExpression castFunction) { + return castMap.get(castFunction.getFunctionName()).apply( + (LiteralExpression) castFunction.getArguments().get(0)); + } + + /** + * Type converting map. + */ + private final Map> castMap = ImmutableMap + .>builder() + .put(BuiltinFunctionName.CAST_TO_STRING.getName(), expr -> { + if (!expr.type().equals(ExprCoreType.STRING)) { + return new ExprStringValue(String.valueOf(expr.valueOf(null).value())); + } else { + return expr.valueOf(null); + } + }) + .put(BuiltinFunctionName.CAST_TO_BYTE.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprByteValue(expr.valueOf(null).byteValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprByteValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprByteValue(Byte.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_SHORT.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprShortValue(expr.valueOf(null).shortValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprShortValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprShortValue(Short.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_INT.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprIntegerValue(expr.valueOf(null).integerValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprIntegerValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprIntegerValue(Integer.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_LONG.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprLongValue(expr.valueOf(null).longValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprLongValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprLongValue(Long.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_FLOAT.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprFloatValue(expr.valueOf(null).floatValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprFloatValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprFloatValue(Float.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return new ExprDoubleValue(expr.valueOf(null).doubleValue()); + } else if (expr.type().equals(ExprCoreType.BOOLEAN)) { + return new ExprDoubleValue(expr.valueOf(null).booleanValue() ? 1 : 0); + } else { + return new ExprDoubleValue(Double.valueOf(expr.valueOf(null).stringValue())); + } + }) + .put(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), expr -> { + if (ExprCoreType.numberTypes().contains(expr.type())) { + return expr.valueOf(null).doubleValue() == 1 + ? ExprBooleanValue.of(true) : ExprBooleanValue.of(false); + } else if (expr.type().equals(ExprCoreType.STRING)) { + return ExprBooleanValue.of(Boolean.valueOf(expr.valueOf(null).stringValue())); + } else { + return expr.valueOf(null); + } + }) + .put(BuiltinFunctionName.CAST_TO_DATE.getName(), expr -> { + if (expr.type().equals(ExprCoreType.STRING)) { + return new ExprDateValue(expr.valueOf(null).stringValue()); + } else { + return new ExprDateValue(expr.valueOf(null).dateValue()); + } + }) + .put(BuiltinFunctionName.CAST_TO_TIME.getName(), expr -> { + if (expr.type().equals(ExprCoreType.STRING)) { + return new ExprTimeValue(expr.valueOf(null).stringValue()); + } else { + return new ExprTimeValue(expr.valueOf(null).timeValue()); + } + }) + .put(BuiltinFunctionName.CAST_TO_DATETIME.getName(), expr -> { + if (expr.type().equals(ExprCoreType.STRING)) { + return new ExprDatetimeValue(expr.valueOf(null).stringValue()); + } else { + return new ExprDatetimeValue(expr.valueOf(null).datetimeValue()); + } + }) + .put(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), expr -> { + if (expr.type().equals(ExprCoreType.STRING)) { + return new ExprTimestampValue(expr.valueOf(null).stringValue()); + } else { + return new ExprTimestampValue(expr.valueOf(null).timestampValue()); + } + }) + .build(); + /** * Build method that subclass implements by default which is to build query * from reference and literal in function arguments. diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/RangeQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/RangeQuery.java index 6a09902f65..0be4ce713b 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/RangeQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/RangeQuery.java @@ -32,6 +32,7 @@ import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; /** @@ -51,7 +52,7 @@ public enum Comparison { @Override protected QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue literal) { - Object value = literal.value(); + Object value = value(literal); RangeQueryBuilder query = QueryBuilders.rangeQuery(fieldName); switch (comparison) { @@ -68,4 +69,12 @@ protected QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue l } } + private Object value(ExprValue literal) { + if (literal.type().equals(ExprCoreType.TIMESTAMP)) { + return literal.timestampValue().toEpochMilli(); + } else { + return literal.value(); + } + } + } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/TermQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/TermQuery.java index f784f4077a..7786f2b8f8 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/TermQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/TermQuery.java @@ -30,6 +30,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; /** @@ -40,7 +41,15 @@ public class TermQuery extends LuceneQuery { @Override protected QueryBuilder doBuild(String fieldName, ExprType fieldType, ExprValue literal) { fieldName = convertTextToKeyword(fieldName, fieldType); - return QueryBuilders.termQuery(fieldName, literal.value()); + return QueryBuilders.termQuery(fieldName, value(literal)); + } + + private Object value(ExprValue literal) { + if (literal.type().equals(ExprCoreType.TIMESTAMP)) { + return literal.timestampValue().toEpochMilli(); + } else { + return literal.value(); + } } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index ffbbb5feda..5a9f758665 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -31,27 +31,46 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; +import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; +import static org.opensearch.sql.data.type.ExprCoreType.BYTE; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; +import static org.opensearch.sql.data.type.ExprCoreType.DATETIME; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.FLOAT; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.data.type.ExprCoreType.SHORT; import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.expression.DSL.literal; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD; import com.google.common.collect.ImmutableMap; import java.util.Map; +import java.util.stream.Stream; import org.json.JSONObject; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprDateValue; +import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprTimeValue; +import org.opensearch.sql.data.model.ExprTimestampValue; +import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -61,6 +80,21 @@ class FilterQueryBuilderTest { private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private static Stream numericCastSource() { + return Stream.of(literal((byte) 1), literal((short) 1), literal( + 1), literal(1L), literal(1F), literal(1D), literal(true), literal("1")); + } + + private static Stream booleanTrueCastSource() { + return Stream.of(literal((byte) 1), literal((short) 1), literal( + 1), literal(1L), literal(1F), literal(1D), literal(true), literal("true")); + } + + private static Stream booleanFalseCastSource() { + return Stream.of(literal((byte) 0), literal((short) 0), literal( + 0), literal(0L), literal(0F), literal(0D), literal(false), literal("false")); + } + @Mock private ExpressionSerializer serializer; @@ -352,6 +386,316 @@ void match_invalid_parameter() { "Parameter invalid_parameter is invalid for match function."); } + @Test + void cast_to_string_in_filter() { + String json = "{\n" + + " \"term\" : {\n" + + " \"string_value\" : {\n" + + " \"value\" : \"1\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals(json, buildQuery( + dsl.equal(ref("string_value", STRING), dsl.castString(literal(1))))); + assertJsonEquals(json, buildQuery( + dsl.equal(ref("string_value", STRING), dsl.castString(literal("1"))))); + } + + @ParameterizedTest(name = "castByte({0})") + @MethodSource({"numericCastSource"}) + void cast_to_byte_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"byte_value\" : {\n" + + " \"value\" : 1,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("byte_value", BYTE), dsl.castByte(expr)))); + } + + @ParameterizedTest(name = "castShort({0})") + @MethodSource({"numericCastSource"}) + void cast_to_short_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"short_value\" : {\n" + + " \"value\" : 1,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("short_value", SHORT), dsl.castShort(expr)))); + } + + @ParameterizedTest(name = "castInt({0})") + @MethodSource({"numericCastSource"}) + void cast_to_int_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"integer_value\" : {\n" + + " \"value\" : 1,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("integer_value", INTEGER), dsl.castInt(expr)))); + } + + @ParameterizedTest(name = "castLong({0})") + @MethodSource({"numericCastSource"}) + void cast_to_long_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"long_value\" : {\n" + + " \"value\" : 1,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("long_value", LONG), dsl.castLong(expr)))); + } + + @ParameterizedTest(name = "castFloat({0})") + @MethodSource({"numericCastSource"}) + void cast_to_float_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"float_value\" : {\n" + + " \"value\" : 1.0,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("float_value", FLOAT), dsl.castFloat(expr)))); + } + + @ParameterizedTest(name = "castDouble({0})") + @MethodSource({"numericCastSource"}) + void cast_to_double_in_filter(LiteralExpression expr) { + assertJsonEquals( + "{\n" + + " \"term\" : {\n" + + " \"double_value\" : {\n" + + " \"value\" : 1.0,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("double_value", DOUBLE), dsl.castDouble(expr)))); + } + + @ParameterizedTest(name = "castBooleanTrue({0})") + @MethodSource({"booleanTrueCastSource"}) + void cast_to_boolean_true_in_filter(LiteralExpression expr) { + String json = "{\n" + + " \"term\" : {\n" + + " \"boolean_value\" : {\n" + + " \"value\" : true,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals( + json, buildQuery(dsl.equal(ref("boolean_value", BOOLEAN), dsl.castBoolean(expr)))); + } + + @ParameterizedTest(name = "castBooleanFalse({0})") + @MethodSource({"booleanFalseCastSource"}) + void cast_to_boolean_false_in_filter(LiteralExpression expr) { + String json = "{\n" + + " \"term\" : {\n" + + " \"boolean_value\" : {\n" + + " \"value\" : false,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals( + json, buildQuery(dsl.equal(ref("boolean_value", BOOLEAN), dsl.castBoolean(expr)))); + } + + @Test + void cast_from_boolean() { + Expression booleanExpr = literal(false); + String json = "{\n" + + " \"term\" : {\n" + + " \"my_value\" : {\n" + + " \"value\" : 0,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", BYTE), dsl.castByte(booleanExpr)))); + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", SHORT), dsl.castShort(booleanExpr)))); + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", INTEGER), dsl.castInt(booleanExpr)))); + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", LONG), dsl.castLong(booleanExpr)))); + + json = "{\n" + + " \"term\" : {\n" + + " \"my_value\" : {\n" + + " \"value\" : 0.0,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", FLOAT), dsl.castFloat(booleanExpr)))); + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", DOUBLE), dsl.castDouble(booleanExpr)))); + + json = "{\n" + + " \"term\" : {\n" + + " \"my_value\" : {\n" + + " \"value\" : \"false\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + assertJsonEquals(json, buildQuery( + dsl.equal(ref("my_value", STRING), dsl.castString(booleanExpr)))); + } + + @Test + void cast_to_date_in_filter() { + String json = "{\n" + + " \"term\" : {\n" + + " \"date_value\" : {\n" + + " \"value\" : \"2021-11-08\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals(json, buildQuery(dsl.equal( + ref("date_value", DATE), dsl.castDate(literal("2021-11-08"))))); + assertJsonEquals(json, buildQuery(dsl.equal( + ref("date_value", DATE), dsl.castDate(literal(new ExprDateValue("2021-11-08")))))); + assertJsonEquals(json, buildQuery(dsl.equal(ref( + "date_value", DATE), dsl.castDate(literal(new ExprDatetimeValue("2021-11-08 17:00:00")))))); + } + + @Test + void cast_to_time_in_filter() { + String json = "{\n" + + " \"term\" : {\n" + + " \"time_value\" : {\n" + + " \"value\" : \"17:00:00\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals(json, buildQuery(dsl.equal( + ref("time_value", TIME), dsl.castTime(literal("17:00:00"))))); + assertJsonEquals(json, buildQuery(dsl.equal( + ref("time_value", TIME), dsl.castTime(literal(new ExprTimeValue("17:00:00")))))); + assertJsonEquals(json, buildQuery(dsl.equal(ref("time_value", TIME), dsl + .castTime(literal(new ExprTimestampValue("2021-11-08 17:00:00")))))); + } + + @Test + void cast_to_datetime_in_filter() { + String json = "{\n" + + " \"term\" : {\n" + + " \"datetime_value\" : {\n" + + " \"value\" : \"2021-11-08 17:00:00\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals(json, buildQuery(dsl.equal(ref("datetime_value", DATETIME), dsl + .castDatetime(literal("2021-11-08 17:00:00"))))); + assertJsonEquals(json, buildQuery(dsl.equal(ref("datetime_value", DATETIME), dsl + .castDatetime(literal(new ExprTimestampValue("2021-11-08 17:00:00")))))); + } + + @Test + void cast_to_timestamp_in_filter() { + String json = "{\n" + + " \"term\" : {\n" + + " \"timestamp_value\" : {\n" + + " \"value\" : 1636390800000,\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}"; + + assertJsonEquals(json, buildQuery(dsl.equal(ref("timestamp_value", TIMESTAMP), dsl + .castTimestamp(literal("2021-11-08 17:00:00"))))); + assertJsonEquals(json, buildQuery(dsl.equal(ref("timestamp_value", TIMESTAMP), dsl + .castTimestamp(literal(new ExprTimestampValue("2021-11-08 17:00:00")))))); + } + + @Test + void cast_in_range_query() { + assertJsonEquals( + "{\n" + + " \"range\" : {\n" + + " \"timestamp_value\" : {\n" + + " \"from\" : 1636390800000,\n" + + " \"to\" : null," + + " \"include_lower\" : false," + + " \"include_upper\" : true," + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(dsl.greater(ref("timestamp_value", TIMESTAMP), dsl + .castTimestamp(literal("2021-11-08 17:00:00"))))); + } + + @Test + void non_literal_in_cast_should_build_script() { + mockToStringSerializer(); + assertJsonEquals( + "{\n" + + " \"script\" : {\n" + + " \"script\" : {\n" + + " \"source\" : \"=(string_value, cast_to_string(+(1, 0)))\",\n" + + " \"lang\" : \"opensearch_query_expression\"\n" + + " },\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("string_value", STRING), dsl.castString(dsl + .add(literal(1), literal(0))))) + ); + } + + @Test + void non_cast_nested_function_should_build_script() { + mockToStringSerializer(); + assertJsonEquals( + "{\n" + + " \"script\" : {\n" + + " \"script\" : {\n" + + " \"source\" : \"=(integer_value, abs(+(1, 0)))\",\n" + + " \"lang\" : \"opensearch_query_expression\"\n" + + " },\n" + + " \"boost\" : 1.0\n" + + " }\n" + + "}", + buildQuery(dsl.equal(ref("integer_value", INTEGER), dsl.abs(dsl + .add(literal(1), literal(0))))) + ); + } + private static void assertJsonEquals(String expected, String actual) { assertTrue(new JSONObject(expected).similar(new JSONObject(actual)), StringUtils.format("Expected: %s, actual: %s", expected, actual));