From b869b6a4b1daf5acc44390def00e8914922f8f18 Mon Sep 17 00:00:00 2001 From: Max Ksyunz Date: Wed, 7 Sep 2022 11:38:40 -0700 Subject: [PATCH] Refactor relevance search functions (#746) - Update QueryStringTest to check for SyntaxCheckException. SyntaxCheckException is correct when incorrect # of parameters See https://github.com/opensearch-project/sql/pull/604#discussion_r877339888 for reference. - Introduce MultiFieldQuery and SingleFieldQuery base classes. - Extract FunctionResolver interface. FunctionResolver is now DefaultFunctionResolver. RelevanceFunctionResolver is a simplified function resolver for relevance search functions. - Removed tests from FilterQueryBuilderTest that verified exceptions thrown for invalid function calls. These scenarios are now handled by RelevanceQuery::build. Signed-off-by: MaxKsyunz Signed-off-by: MaxKsyunz --- .../aggregation/AggregatorFunction.java | 39 ++++----- .../expression/datetime/DateTimeFunction.java | 57 +++++++------ .../expression/datetime/IntervalClause.java | 4 +- .../function/BuiltinFunctionRepository.java | 4 +- .../function/DefaultFunctionResolver.java | 69 +++++++++++++++ .../sql/expression/function/FunctionDSL.java | 14 ++-- .../expression/function/FunctionResolver.java | 60 ++----------- .../function/OpenSearchFunctions.java | 57 ++++--------- .../function/RelevanceFunctionResolver.java | 67 +++++++++++++++ .../arthmetic/ArithmeticFunction.java | 12 +-- .../arthmetic/MathematicalFunction.java | 64 +++++++------- .../operator/convert/TypeCastOperator.java | 26 +++--- .../predicate/BinaryPredicateOperator.java | 26 +++--- .../predicate/UnaryPredicateOperator.java | 22 ++--- .../sql/expression/text/TextFunction.java | 36 ++++---- .../expression/window/WindowFunctions.java | 16 ++-- .../sql/analysis/ExpressionAnalyzerTest.java | 9 ++ .../BuiltinFunctionRepositoryTest.java | 4 +- ....java => DefaultFunctionResolverTest.java} | 8 +- .../RelevanceFunctionResolverTest.java | 64 ++++++++++++++ .../relevance/MatchBoolPrefixQuery.java | 9 +- .../relevance/MatchPhrasePrefixQuery.java | 9 +- .../lucene/relevance/MatchPhraseQuery.java | 9 +- .../filter/lucene/relevance/MatchQuery.java | 10 ++- .../lucene/relevance/MultiFieldQuery.java | 37 ++++++++ .../lucene/relevance/MultiMatchQuery.java | 48 ++--------- .../lucene/relevance/QueryStringQuery.java | 55 +++--------- .../lucene/relevance/RelevanceQuery.java | 34 +++++--- .../relevance/SimpleQueryStringQuery.java | 47 ++--------- .../lucene/relevance/SingleFieldQuery.java | 31 +++++++ .../script/filter/FilterQueryBuilderTest.java | 84 ------------------- .../lucene/MatchBoolPrefixQueryTest.java | 16 ++-- .../filter/lucene/MatchPhraseQueryTest.java | 41 ++++----- .../script/filter/lucene/MultiMatchTest.java | 15 ++-- .../script/filter/lucene/QueryStringTest.java | 9 +- .../filter/lucene/SimpleQueryStringTest.java | 9 +- .../lucene/relevance/MultiFieldQueryTest.java | 61 ++++++++++++++ .../relevance/RelevanceQueryBuildTest.java | 20 +++-- .../relevance/SingleFieldQueryTest.java | 51 +++++++++++ 39 files changed, 703 insertions(+), 550 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java create mode 100644 core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java rename core/src/test/java/org/opensearch/sql/expression/function/{FunctionResolverTest.java => DefaultFunctionResolverTest.java} (90%) create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 20e91aa6cd..172e1ee778 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -27,9 +27,9 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; /** @@ -44,6 +44,7 @@ public class AggregatorFunction { /** * Register Aggregation Function. + * * @param repository {@link BuiltinFunctionRepository}. */ public static void register(BuiltinFunctionRepository repository) { @@ -58,9 +59,9 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(stddevPop()); } - private static FunctionResolver avg() { + private static DefaultFunctionResolver avg() { FunctionName functionName = BuiltinFunctionName.AVG.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -69,18 +70,18 @@ private static FunctionResolver avg() { ); } - private static FunctionResolver count() { + private static DefaultFunctionResolver count() { FunctionName functionName = BuiltinFunctionName.COUNT.getName(); - FunctionResolver functionResolver = new FunctionResolver(functionName, + DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, ExprCoreType.coreTypes().stream().collect(Collectors.toMap( type -> new FunctionSignature(functionName, Collections.singletonList(type)), type -> arguments -> new CountAggregator(arguments, INTEGER)))); return functionResolver; } - private static FunctionResolver sum() { + private static DefaultFunctionResolver sum() { FunctionName functionName = BuiltinFunctionName.SUM.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -95,9 +96,9 @@ private static FunctionResolver sum() { ); } - private static FunctionResolver min() { + private static DefaultFunctionResolver min() { FunctionName functionName = BuiltinFunctionName.MIN.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -121,9 +122,9 @@ private static FunctionResolver min() { .build()); } - private static FunctionResolver max() { + private static DefaultFunctionResolver max() { FunctionName functionName = BuiltinFunctionName.MAX.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), @@ -148,9 +149,9 @@ private static FunctionResolver max() { ); } - private static FunctionResolver varSamp() { + private static DefaultFunctionResolver varSamp() { FunctionName functionName = BuiltinFunctionName.VARSAMP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -159,9 +160,9 @@ private static FunctionResolver varSamp() { ); } - private static FunctionResolver varPop() { + private static DefaultFunctionResolver varPop() { FunctionName functionName = BuiltinFunctionName.VARPOP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -170,9 +171,9 @@ private static FunctionResolver varPop() { ); } - private static FunctionResolver stddevSamp() { + private static DefaultFunctionResolver stddevSamp() { FunctionName functionName = BuiltinFunctionName.STDDEV_SAMP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), @@ -181,9 +182,9 @@ private static FunctionResolver stddevSamp() { ); } - private static FunctionResolver stddevPop() { + private static DefaultFunctionResolver stddevPop() { FunctionName functionName = BuiltinFunctionName.STDDEV_POP.getName(); - return new FunctionResolver( + return new DefaultFunctionResolver( functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java index 0fccacd136..469f7e2011 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java @@ -37,6 +37,7 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; import org.opensearch.sql.expression.function.FunctionResolver; @@ -94,7 +95,7 @@ public void register(BuiltinFunctionRepository repository) { * (STRING/DATETIME/TIMESTAMP, LONG) -> DATETIME */ - private FunctionResolver add_date(FunctionName functionName) { + private DefaultFunctionResolver add_date(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(DateTimeFunction::exprAddDateInterval), DATETIME, STRING, INTERVAL), @@ -110,7 +111,7 @@ private FunctionResolver add_date(FunctionName functionName) { ); } - private FunctionResolver adddate() { + private DefaultFunctionResolver adddate() { return add_date(BuiltinFunctionName.ADDDATE.getName()); } @@ -119,7 +120,7 @@ private FunctionResolver adddate() { * Also to construct a date type. The supported signatures: * STRING/DATE/DATETIME/TIMESTAMP -> DATE */ - private FunctionResolver date() { + private DefaultFunctionResolver date() { return define(BuiltinFunctionName.DATE.getName(), impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, STRING), impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, DATE), @@ -127,7 +128,7 @@ private FunctionResolver date() { impl(nullMissingHandling(DateTimeFunction::exprDate), DATE, TIMESTAMP)); } - private FunctionResolver date_add() { + private DefaultFunctionResolver date_add() { return add_date(BuiltinFunctionName.DATE_ADD.getName()); } @@ -138,7 +139,7 @@ private FunctionResolver date_add() { * (DATE, LONG) -> DATE * (STRING/DATETIME/TIMESTAMP, LONG) -> DATETIME */ - private FunctionResolver sub_date(FunctionName functionName) { + private DefaultFunctionResolver sub_date(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(DateTimeFunction::exprSubDateInterval), DATETIME, STRING, INTERVAL), @@ -154,14 +155,14 @@ private FunctionResolver sub_date(FunctionName functionName) { ); } - private FunctionResolver date_sub() { + private DefaultFunctionResolver date_sub() { return sub_date(BuiltinFunctionName.DATE_SUB.getName()); } /** * DAY(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). */ - private FunctionResolver day() { + private DefaultFunctionResolver day() { return define(BuiltinFunctionName.DAY.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), @@ -175,7 +176,7 @@ private FunctionResolver day() { * return the name of the weekday for date, including Monday, Tuesday, Wednesday, * Thursday, Friday, Saturday and Sunday. */ - private FunctionResolver dayName() { + private DefaultFunctionResolver dayName() { return define(BuiltinFunctionName.DAYNAME.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayName), STRING, DATETIME), @@ -187,7 +188,7 @@ private FunctionResolver dayName() { /** * DAYOFMONTH(STRING/DATE/DATETIME/TIMESTAMP). return the day of the month (1-31). */ - private FunctionResolver dayOfMonth() { + private DefaultFunctionResolver dayOfMonth() { return define(BuiltinFunctionName.DAYOFMONTH.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfMonth), INTEGER, DATETIME), @@ -200,7 +201,7 @@ private FunctionResolver dayOfMonth() { * DAYOFWEEK(STRING/DATE/DATETIME/TIMESTAMP). * return the weekday index for date (1 = Sunday, 2 = Monday, …, 7 = Saturday). */ - private FunctionResolver dayOfWeek() { + private DefaultFunctionResolver dayOfWeek() { return define(BuiltinFunctionName.DAYOFWEEK.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfWeek), INTEGER, DATETIME), @@ -213,7 +214,7 @@ private FunctionResolver dayOfWeek() { * DAYOFYEAR(STRING/DATE/DATETIME/TIMESTAMP). * return the day of the year for date (1-366). */ - private FunctionResolver dayOfYear() { + private DefaultFunctionResolver dayOfYear() { return define(BuiltinFunctionName.DAYOFYEAR.getName(), impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprDayOfYear), INTEGER, DATETIME), @@ -225,7 +226,7 @@ private FunctionResolver dayOfYear() { /** * FROM_DAYS(LONG). return the date value given the day number N. */ - private FunctionResolver from_days() { + private DefaultFunctionResolver from_days() { return define(BuiltinFunctionName.FROM_DAYS.getName(), impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); } @@ -233,7 +234,7 @@ private FunctionResolver from_days() { /** * HOUR(STRING/TIME/DATETIME/TIMESTAMP). return the hour value for time. */ - private FunctionResolver hour() { + private DefaultFunctionResolver hour() { return define(BuiltinFunctionName.HOUR.getName(), impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprHour), INTEGER, TIME), @@ -255,7 +256,7 @@ private FunctionResolver maketime() { /** * MICROSECOND(STRING/TIME/DATETIME/TIMESTAMP). return the microsecond value for time. */ - private FunctionResolver microsecond() { + private DefaultFunctionResolver microsecond() { return define(BuiltinFunctionName.MICROSECOND.getName(), impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprMicrosecond), INTEGER, TIME), @@ -267,7 +268,7 @@ private FunctionResolver microsecond() { /** * MINUTE(STRING/TIME/DATETIME/TIMESTAMP). return the minute value for time. */ - private FunctionResolver minute() { + private DefaultFunctionResolver minute() { return define(BuiltinFunctionName.MINUTE.getName(), impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprMinute), INTEGER, TIME), @@ -279,7 +280,7 @@ private FunctionResolver minute() { /** * MONTH(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-12). */ - private FunctionResolver month() { + private DefaultFunctionResolver month() { return define(BuiltinFunctionName.MONTH.getName(), impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprMonth), INTEGER, DATETIME), @@ -291,7 +292,7 @@ private FunctionResolver month() { /** * MONTHNAME(STRING/DATE/DATETIME/TIMESTAMP). return the full name of the month for date. */ - private FunctionResolver monthName() { + private DefaultFunctionResolver monthName() { return define(BuiltinFunctionName.MONTHNAME.getName(), impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATE), impl(nullMissingHandling(DateTimeFunction::exprMonthName), STRING, DATETIME), @@ -303,7 +304,7 @@ private FunctionResolver monthName() { /** * QUARTER(STRING/DATE/DATETIME/TIMESTAMP). return the month for date (1-4). */ - private FunctionResolver quarter() { + private DefaultFunctionResolver quarter() { return define(BuiltinFunctionName.QUARTER.getName(), impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprQuarter), INTEGER, DATETIME), @@ -315,7 +316,7 @@ private FunctionResolver quarter() { /** * SECOND(STRING/TIME/DATETIME/TIMESTAMP). return the second value for time. */ - private FunctionResolver second() { + private DefaultFunctionResolver second() { return define(BuiltinFunctionName.SECOND.getName(), impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, STRING), impl(nullMissingHandling(DateTimeFunction::exprSecond), INTEGER, TIME), @@ -324,7 +325,7 @@ private FunctionResolver second() { ); } - private FunctionResolver subdate() { + private DefaultFunctionResolver subdate() { return sub_date(BuiltinFunctionName.SUBDATE.getName()); } @@ -333,7 +334,7 @@ private FunctionResolver subdate() { * Also to construct a time type. The supported signatures: * STRING/DATE/DATETIME/TIME/TIMESTAMP -> TIME */ - private FunctionResolver time() { + private DefaultFunctionResolver time() { return define(BuiltinFunctionName.TIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, STRING), impl(nullMissingHandling(DateTimeFunction::exprTime), TIME, DATE), @@ -345,7 +346,7 @@ private FunctionResolver time() { /** * TIME_TO_SEC(STRING/TIME/DATETIME/TIMESTAMP). return the time argument, converted to seconds. */ - private FunctionResolver time_to_sec() { + private DefaultFunctionResolver time_to_sec() { return define(BuiltinFunctionName.TIME_TO_SEC.getName(), impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, STRING), impl(nullMissingHandling(DateTimeFunction::exprTimeToSec), LONG, TIME), @@ -359,7 +360,7 @@ private FunctionResolver time_to_sec() { * Also to construct a date type. The supported signatures: * STRING/DATE/DATETIME/TIMESTAMP -> DATE */ - private FunctionResolver timestamp() { + private DefaultFunctionResolver timestamp() { return define(BuiltinFunctionName.TIMESTAMP.getName(), impl(nullMissingHandling(DateTimeFunction::exprTimestamp), TIMESTAMP, STRING), impl(nullMissingHandling(DateTimeFunction::exprTimestamp), TIMESTAMP, DATE), @@ -370,7 +371,7 @@ private FunctionResolver timestamp() { /** * TO_DAYS(STRING/DATE/DATETIME/TIMESTAMP). return the day number of the given date. */ - private FunctionResolver to_days() { + private DefaultFunctionResolver to_days() { return define(BuiltinFunctionName.TO_DAYS.getName(), impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, STRING), impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, TIMESTAMP), @@ -381,7 +382,7 @@ private FunctionResolver to_days() { /** * WEEK(DATE[,mode]). return the week number for date. */ - private FunctionResolver week() { + private DefaultFunctionResolver week() { return define(BuiltinFunctionName.WEEK.getName(), impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprWeekWithoutMode), INTEGER, DATETIME), @@ -397,7 +398,7 @@ private FunctionResolver week() { /** * YEAR(STRING/DATE/DATETIME/TIMESTAMP). return the year for date (1000-9999). */ - private FunctionResolver year() { + private DefaultFunctionResolver year() { return define(BuiltinFunctionName.YEAR.getName(), impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATE), impl(nullMissingHandling(DateTimeFunction::exprYear), INTEGER, DATETIME), @@ -414,7 +415,7 @@ private FunctionResolver year() { * (DATETIME, STRING) -> STRING * (TIMESTAMP, STRING) -> STRING */ - private FunctionResolver date_format() { + private DefaultFunctionResolver date_format() { return define(BuiltinFunctionName.DATE_FORMAT.getName(), impl(nullMissingHandling(DateTimeFormatterUtil::getFormattedDate), STRING, STRING, STRING), @@ -711,6 +712,7 @@ private ExprValue exprToDays(ExprValue date) { /** * Week for date implementation for ExprValue. + * * @param date ExprValue of Date/Datetime/Timestamp/String type. * @param mode ExprValue of Integer type. */ @@ -722,6 +724,7 @@ private ExprValue exprWeek(ExprValue date, ExprValue mode) { /** * Week for date implementation for ExprValue. * When mode is not specified default value mode 0 is used for default_week_format. + * * @param date ExprValue of Date/Datetime/Timestamp/String type. * @return ExprValue. */ diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java index f4746ebe7a..c5076431cc 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/IntervalClause.java @@ -25,7 +25,7 @@ import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; -import org.opensearch.sql.expression.function.FunctionResolver; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; @UtilityClass public class IntervalClause { @@ -44,7 +44,7 @@ public void register(BuiltinFunctionRepository repository) { repository.register(interval()); } - private FunctionResolver interval() { + private DefaultFunctionResolver interval() { return define(BuiltinFunctionName.INTERVAL.getName(), impl(nullMissingHandling(IntervalClause::interval), INTERVAL, INTEGER, STRING), impl(nullMissingHandling(IntervalClause::interval), INTERVAL, LONG, STRING)); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 1f4c885723..545e710f65 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -29,9 +29,9 @@ public class BuiltinFunctionRepository { private final Map functionResolverMap; /** - * Register {@link FunctionResolver} to the Builtin Function Repository. + * Register {@link DefaultFunctionResolver} to the Builtin Function Repository. * - * @param resolver {@link FunctionResolver} to be registered + * @param resolver {@link DefaultFunctionResolver} to be registered */ public void register(FunctionResolver resolver) { functionResolverMap.put(resolver.getFunctionName(), resolver); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java new file mode 100644 index 0000000000..7081179162 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/DefaultFunctionResolver.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.AbstractMap; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.Builder; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.Singular; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.exception.ExpressionEvaluationException; + +/** + * The Function Resolver hold the overload {@link FunctionBuilder} implementation. + * is composed by {@link FunctionName} which identified the function name + * and a map of {@link FunctionSignature} and {@link FunctionBuilder} + * to represent the overloaded implementation + */ +@Builder +@RequiredArgsConstructor +public class DefaultFunctionResolver implements FunctionResolver { + @Getter + private final FunctionName functionName; + @Singular("functionBundle") + private final Map functionBundle; + + /** + * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. + * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. + * If applying the widening rule, found the most match one, return it. + * If nothing found, throw {@link ExpressionEvaluationException} + * + * @return function signature and its builder + */ + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + PriorityQueue> functionMatchQueue = new PriorityQueue<>( + Map.Entry.comparingByKey()); + + for (FunctionSignature functionSignature : functionBundle.keySet()) { + functionMatchQueue.add( + new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), + functionSignature)); + } + Map.Entry bestMatchEntry = functionMatchQueue.peek(); + if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { + throw new ExpressionEvaluationException( + String.format("%s function expected %s, but get %s", functionName, + formatFunctions(functionBundle.keySet()), + unresolvedSignature.formatTypes() + )); + } else { + FunctionSignature resolvedSignature = bestMatchEntry.getValue(); + return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); + } + } + + private String formatFunctions(Set functionSignatures) { + return functionSignatures.stream().map(FunctionSignature::formatTypes) + .collect(Collectors.joining(",", "{", "}")); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index dcd65d6b87..1fad333ead 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -32,9 +32,9 @@ public class FunctionDSL { * @param functions a list of function implementation. * @return FunctionResolver. */ - public static FunctionResolver define(FunctionName functionName, - SerializableFunction>... functions) { + public static DefaultFunctionResolver define(FunctionName functionName, + SerializableFunction>... functions) { return define(functionName, Arrays.asList(functions)); } @@ -45,11 +45,11 @@ public static FunctionResolver define(FunctionName functionName, * @param functions a list of function implementation. * @return FunctionResolver. */ - public static FunctionResolver define(FunctionName functionName, - List>> functions) { + public static DefaultFunctionResolver define(FunctionName functionName, List< + SerializableFunction>> functions) { - FunctionResolver.FunctionResolverBuilder builder = FunctionResolver.builder(); + DefaultFunctionResolver.DefaultFunctionResolverBuilder builder + = DefaultFunctionResolver.builder(); builder.functionName(functionName); for (Function> func : functions) { Pair functionBuilder = func.apply(functionName); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java index 06d0fb673c..1635b6f846 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionResolver.java @@ -5,64 +5,14 @@ package org.opensearch.sql.expression.function; -import java.util.AbstractMap; -import java.util.Map; -import java.util.PriorityQueue; -import java.util.Set; -import java.util.stream.Collectors; -import lombok.Builder; -import lombok.Getter; -import lombok.RequiredArgsConstructor; -import lombok.Singular; import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.exception.ExpressionEvaluationException; /** - * The Function Resolver hold the overload {@link FunctionBuilder} implementation. - * is composed by {@link FunctionName} which identified the function name - * and a map of {@link FunctionSignature} and {@link FunctionBuilder} - * to represent the overloaded implementation + * An interface for any class that can provide a {@ref FunctionBuilder} + * given a {@ref FunctionSignature}. */ -@Builder -@RequiredArgsConstructor -public class FunctionResolver { - @Getter - private final FunctionName functionName; - @Singular("functionBundle") - private final Map functionBundle; +public interface FunctionResolver { + Pair resolve(FunctionSignature unresolvedSignature); - /** - * Resolve the {@link FunctionBuilder} by using input {@link FunctionSignature}. - * If the {@link FunctionBuilder} exactly match the input {@link FunctionSignature}, return it. - * If applying the widening rule, found the most match one, return it. - * If nothing found, throw {@link ExpressionEvaluationException} - * - * @return function signature and its builder - */ - public Pair resolve(FunctionSignature unresolvedSignature) { - PriorityQueue> functionMatchQueue = new PriorityQueue<>( - Map.Entry.comparingByKey()); - - for (FunctionSignature functionSignature : functionBundle.keySet()) { - functionMatchQueue.add( - new AbstractMap.SimpleEntry<>(unresolvedSignature.match(functionSignature), - functionSignature)); - } - Map.Entry bestMatchEntry = functionMatchQueue.peek(); - if (FunctionSignature.NOT_MATCH.equals(bestMatchEntry.getKey())) { - throw new ExpressionEvaluationException( - String.format("%s function expected %s, but get %s", functionName, - formatFunctions(functionBundle.keySet()), - unresolvedSignature.formatTypes() - )); - } else { - FunctionSignature resolvedSignature = bestMatchEntry.getValue(); - return Pair.of(resolvedSignature, functionBundle.get(resolvedSignature)); - } - } - - private String formatFunctions(Set functionSignatures) { - return functionSignatures.stream().map(FunctionSignature::formatTypes) - .collect(Collectors.joining(",", "{", "}")); - } + FunctionName getFunctionName(); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index c3e5cc5594..bb3eb7008b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -9,13 +9,9 @@ import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; import com.google.common.collect.ImmutableMap; -import java.util.ArrayList; -import java.util.Collections; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; -import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; @@ -27,16 +23,6 @@ @UtilityClass public class OpenSearchFunctions { - - public static final int MATCH_MAX_NUM_PARAMETERS = 14; - public static final int MATCH_BOOL_PREFIX_MAX_NUM_PARAMETERS = 9; - public static final int MATCH_PHRASE_MAX_NUM_PARAMETERS = 5; - public static final int MIN_NUM_PARAMETERS = 2; - public static final int MULTI_MATCH_MAX_NUM_PARAMETERS = 17; - public static final int SIMPLE_QUERY_STRING_MAX_NUM_PARAMETERS = 14; - public static final int QUERY_STRING_MAX_NUM_PARAMETERS = 25; - public static final int MATCH_PHRASE_PREFIX_MAX_NUM_PARAMETERS = 7; - /** * Add functions specific to OpenSearch to repository. */ @@ -58,67 +44,54 @@ private static FunctionResolver highlight() { FunctionName functionName = BuiltinFunctionName.HIGHLIGHT.getName(); FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING)); FunctionBuilder functionBuilder = arguments -> new HighlightExpression(arguments.get(0)); - return new FunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); + return new DefaultFunctionResolver(functionName, + ImmutableMap.of(functionSignature, functionBuilder)); } private static FunctionResolver match_bool_prefix() { FunctionName name = BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(); - return getRelevanceFunctionResolver(name, MATCH_BOOL_PREFIX_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(name, STRING); } private static FunctionResolver match() { FunctionName funcName = BuiltinFunctionName.MATCH.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver match_phrase_prefix() { FunctionName funcName = BuiltinFunctionName.MATCH_PHRASE_PREFIX.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_PHRASE_PREFIX_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) { FunctionName funcName = matchPhrase.getName(); - return getRelevanceFunctionResolver(funcName, MATCH_PHRASE_MAX_NUM_PARAMETERS, STRING); + return new RelevanceFunctionResolver(funcName, STRING); } private static FunctionResolver multi_match() { FunctionName funcName = BuiltinFunctionName.MULTI_MATCH.getName(); - return getRelevanceFunctionResolver(funcName, MULTI_MATCH_MAX_NUM_PARAMETERS, STRUCT); + return new RelevanceFunctionResolver(funcName, STRUCT); } private static FunctionResolver simple_query_string() { FunctionName funcName = BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(); - return getRelevanceFunctionResolver(funcName, SIMPLE_QUERY_STRING_MAX_NUM_PARAMETERS, STRUCT); + return new RelevanceFunctionResolver(funcName, STRUCT); } private static FunctionResolver query_string() { FunctionName funcName = BuiltinFunctionName.QUERY_STRING.getName(); - return getRelevanceFunctionResolver(funcName, QUERY_STRING_MAX_NUM_PARAMETERS, STRUCT); - } - - private static FunctionResolver getRelevanceFunctionResolver( - FunctionName funcName, int maxNumParameters, ExprCoreType firstArgType) { - return new FunctionResolver(funcName, - getRelevanceFunctionSignatureMap(funcName, maxNumParameters, firstArgType)); - } - - private static Map getRelevanceFunctionSignatureMap( - FunctionName funcName, int maxNumParameters, ExprCoreType firstArgType) { - FunctionBuilder buildFunction = args -> new OpenSearchFunction(funcName, args); - var signatureMapBuilder = ImmutableMap.builder(); - for (int numParameters = MIN_NUM_PARAMETERS; - numParameters <= maxNumParameters; numParameters++) { - List args = new ArrayList<>(Collections.nCopies(numParameters - 1, STRING)); - args.add(0, firstArgType); - signatureMapBuilder.put(new FunctionSignature(funcName, args), buildFunction); - } - return signatureMapBuilder.build(); + return new RelevanceFunctionResolver(funcName, STRUCT); } - private static class OpenSearchFunction extends FunctionExpression { + public static class OpenSearchFunction extends FunctionExpression { private final FunctionName functionName; private final List arguments; + /** + * Required argument constructor. + * @param functionName name of the function + * @param arguments a list of expressions + */ public OpenSearchFunction(FunctionName functionName, List arguments) { super(functionName, arguments); this.functionName = functionName; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java new file mode 100644 index 0000000000..e781db8c84 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -0,0 +1,67 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; + +@RequiredArgsConstructor +public class RelevanceFunctionResolver + implements FunctionResolver { + + @Getter + private final FunctionName functionName; + + @Getter + private final ExprType declaredFirstParamType; + + @Override + public Pair resolve(FunctionSignature unresolvedSignature) { + if (!unresolvedSignature.getFunctionName().equals(functionName)) { + throw new SemanticCheckException(String.format("Expected '%s' but got '%s'", + functionName.getFunctionName(), unresolvedSignature.getFunctionName().getFunctionName())); + } + List paramTypes = unresolvedSignature.getParamTypeList(); + ExprType providedFirstParamType = paramTypes.get(0); + + // Check if the first parameter is of the specified type. + if (!declaredFirstParamType.equals(providedFirstParamType)) { + throw new SemanticCheckException( + getWrongParameterErrorMessage(0, providedFirstParamType, declaredFirstParamType)); + } + + // Check if all but the first parameter are of type STRING. + for (int i = 1; i < paramTypes.size(); i++) { + ExprType paramType = paramTypes.get(i); + if (!ExprCoreType.STRING.equals(paramType)) { + throw new SemanticCheckException( + getWrongParameterErrorMessage(i, paramType, ExprCoreType.STRING)); + } + } + + FunctionBuilder buildFunction = + args -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); + return Pair.of(unresolvedSignature, buildFunction); + } + + /** Returns a helpful error message when expected parameter type does not match the + * specified parameter type. + * + * @param i 0-based index of the parameter in a function signature. + * @param paramType the type of the ith parameter at run-time. + * @param expectedType the expected type of the ith parameter + * @return A user-friendly error message that informs of the type difference. + */ + private String getWrongParameterErrorMessage(int i, ExprType paramType, ExprType expectedType) { + return String.format("Expected type %s instead of %s for parameter #%d", + expectedType.typeName(), paramType.typeName(), i + 1); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java index 81356e789b..c4b106bbf4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/ArithmeticFunction.java @@ -23,8 +23,8 @@ import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; /** * The definition of arithmetic function @@ -49,7 +49,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(modules()); } - private static FunctionResolver add() { + private static DefaultFunctionResolver add() { return FunctionDSL.define(BuiltinFunctionName.ADD.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -79,7 +79,7 @@ private static FunctionResolver add() { ); } - private static FunctionResolver subtract() { + private static DefaultFunctionResolver subtract() { return FunctionDSL.define(BuiltinFunctionName.SUBTRACT.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -109,7 +109,7 @@ private static FunctionResolver subtract() { ); } - private static FunctionResolver multiply() { + private static DefaultFunctionResolver multiply() { return FunctionDSL.define(BuiltinFunctionName.MULTIPLY.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -139,7 +139,7 @@ private static FunctionResolver multiply() { ); } - private static FunctionResolver divide() { + private static DefaultFunctionResolver divide() { return FunctionDSL.define(BuiltinFunctionName.DIVIDE.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -174,7 +174,7 @@ private static FunctionResolver divide() { } - private static FunctionResolver modules() { + private static DefaultFunctionResolver modules() { return FunctionDSL.define(BuiltinFunctionName.MODULES.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java index d310b42904..0ce48af48c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/arthmetic/MathematicalFunction.java @@ -36,10 +36,10 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableFunction; @@ -88,7 +88,7 @@ public static void register(BuiltinFunctionRepository repository) { * Definition of abs() function. The supported signature of abs() function are INT -> INT LONG -> * LONG FLOAT -> FLOAT DOUBLE -> DOUBLE */ - private static FunctionResolver abs() { + private static DefaultFunctionResolver abs() { return FunctionDSL.define(BuiltinFunctionName.ABS.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprByteValue(Math.abs(v.byteValue()))), @@ -115,7 +115,7 @@ private static FunctionResolver abs() { * Definition of ceil(x)/ceiling(x) function. Calculate the next highest integer that x rounds up * to The supported signature of ceil/ceiling function is DOUBLE -> INTEGER */ - private static FunctionResolver ceil() { + private static DefaultFunctionResolver ceil() { return FunctionDSL.define(BuiltinFunctionName.CEIL.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.ceil(v.doubleValue()))), @@ -123,7 +123,7 @@ private static FunctionResolver ceil() { ); } - private static FunctionResolver ceiling() { + private static DefaultFunctionResolver ceiling() { return FunctionDSL.define(BuiltinFunctionName.CEILING.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.ceil(v.doubleValue()))), @@ -138,7 +138,7 @@ private static FunctionResolver ceiling() { * (STRING, INTEGER, INTEGER) -> STRING * (INTEGER, INTEGER, INTEGER) -> STRING */ - private static FunctionResolver conv() { + private static DefaultFunctionResolver conv() { return FunctionDSL.define(BuiltinFunctionName.CONV.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling((x, a, b) -> new ExprStringValue( @@ -161,7 +161,7 @@ private static FunctionResolver conv() { * The supported signature of crc32 function is * STRING -> LONG */ - private static FunctionResolver crc32() { + private static DefaultFunctionResolver crc32() { return FunctionDSL.define(BuiltinFunctionName.CRC32.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> { @@ -178,7 +178,7 @@ private static FunctionResolver crc32() { * Get the Euler's number. * () -> DOUBLE */ - private static FunctionResolver euler() { + private static DefaultFunctionResolver euler() { return FunctionDSL.define(BuiltinFunctionName.E.getName(), FunctionDSL.impl(() -> new ExprDoubleValue(Math.E), DOUBLE) ); @@ -188,7 +188,7 @@ private static FunctionResolver euler() { * Definition of exp(x) function. Calculate exponent function e to the x The supported signature * of exp function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver exp() { + private static DefaultFunctionResolver exp() { return FunctionDSL.define(BuiltinFunctionName.EXP.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -200,7 +200,7 @@ private static FunctionResolver exp() { * Definition of floor(x) function. Calculate the next nearest whole integer that x rounds down to * The supported signature of floor function is DOUBLE -> INTEGER */ - private static FunctionResolver floor() { + private static DefaultFunctionResolver floor() { return FunctionDSL.define(BuiltinFunctionName.FLOOR.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling(v -> new ExprIntegerValue(Math.floor(v.doubleValue()))), @@ -212,7 +212,7 @@ private static FunctionResolver floor() { * Definition of ln(x) function. Calculate the natural logarithm of x The supported signature of * ln function is INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver ln() { + private static DefaultFunctionResolver ln() { return FunctionDSL.define(BuiltinFunctionName.LN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -225,7 +225,7 @@ private static FunctionResolver ln() { * supported signature of log function is (b: INTEGER/LONG/FLOAT/DOUBLE, x: * INTEGER/LONG/FLOAT/DOUBLE]) -> DOUBLE */ - private static FunctionResolver log() { + private static DefaultFunctionResolver log() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -253,7 +253,7 @@ private static FunctionResolver log() { * Definition of log10(x) function. Calculate base-10 logarithm of x The supported signature of * log function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver log10() { + private static DefaultFunctionResolver log10() { return FunctionDSL.define(BuiltinFunctionName.LOG10.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -265,7 +265,7 @@ private static FunctionResolver log10() { * Definition of log2(x) function. Calculate base-2 logarithm of x The supported signature of log * function is SHORT/INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver log2() { + private static DefaultFunctionResolver log2() { return FunctionDSL.define(BuiltinFunctionName.LOG2.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -280,7 +280,7 @@ private static FunctionResolver log2() { * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) * -> wider type between types of x and y */ - private static FunctionResolver mod() { + private static DefaultFunctionResolver mod() { return FunctionDSL.define(BuiltinFunctionName.MOD.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -321,7 +321,7 @@ private static FunctionResolver mod() { * Get the value of pi. * () -> DOUBLE */ - private static FunctionResolver pi() { + private static DefaultFunctionResolver pi() { return FunctionDSL.define(BuiltinFunctionName.PI.getName(), FunctionDSL.impl(() -> new ExprDoubleValue(Math.PI), DOUBLE) ); @@ -336,11 +336,11 @@ private static FunctionResolver pi() { * (FLOAT, FLOAT) -> DOUBLE * (DOUBLE, DOUBLE) -> DOUBLE */ - private static FunctionResolver pow() { + private static DefaultFunctionResolver pow() { return FunctionDSL.define(BuiltinFunctionName.POW.getName(), powerFunctionImpl()); } - private static FunctionResolver power() { + private static DefaultFunctionResolver power() { return FunctionDSL.define(BuiltinFunctionName.POWER.getName(), powerFunctionImpl()); } @@ -378,7 +378,7 @@ FunctionBuilder>>> powerFunctionImpl() { * The supported signature of rand function is * ([INTEGER]) -> FLOAT */ - private static FunctionResolver rand() { + private static DefaultFunctionResolver rand() { return FunctionDSL.define(BuiltinFunctionName.RAND.getName(), FunctionDSL.impl(() -> new ExprFloatValue(new Random().nextFloat()), FLOAT), FunctionDSL.impl( @@ -396,7 +396,7 @@ private static FunctionResolver rand() { * (x: FLOAT [, y: INTEGER]) -> FLOAT * (x: DOUBLE [, y: INTEGER]) -> DOUBLE */ - private static FunctionResolver round() { + private static DefaultFunctionResolver round() { return FunctionDSL.define(BuiltinFunctionName.ROUND.getName(), // rand(x) FunctionDSL.impl( @@ -448,7 +448,7 @@ private static FunctionResolver round() { * The supported signature is * SHORT/INTEGER/LONG/FLOAT/DOUBLE -> INTEGER */ - private static FunctionResolver sign() { + private static DefaultFunctionResolver sign() { return FunctionDSL.define(BuiltinFunctionName.SIGN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -462,7 +462,7 @@ private static FunctionResolver sign() { * The supported signature is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver sqrt() { + private static DefaultFunctionResolver sqrt() { return FunctionDSL.define(BuiltinFunctionName.SQRT.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -480,7 +480,7 @@ private static FunctionResolver sqrt() { * (x: FLOAT, y: INTEGER) -> DOUBLE * (x: DOUBLE, y: INTEGER) -> DOUBLE */ - private static FunctionResolver truncate() { + private static DefaultFunctionResolver truncate() { return FunctionDSL.define(BuiltinFunctionName.TRUNCATE.getName(), FunctionDSL.impl( FunctionDSL.nullMissingHandling( @@ -515,7 +515,7 @@ private static FunctionResolver truncate() { * The supported signature of acos function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver acos() { + private static DefaultFunctionResolver acos() { return FunctionDSL.define(BuiltinFunctionName.ACOS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -531,7 +531,7 @@ private static FunctionResolver acos() { * The supported signature of asin function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver asin() { + private static DefaultFunctionResolver asin() { return FunctionDSL.define(BuiltinFunctionName.ASIN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -548,7 +548,7 @@ private static FunctionResolver asin() { * The supported signature of atan function is * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE */ - private static FunctionResolver atan() { + private static DefaultFunctionResolver atan() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -571,7 +571,7 @@ private static FunctionResolver atan() { * The supported signature of atan2 function is * (x: INTEGER/LONG/FLOAT/DOUBLE, y: INTEGER/LONG/FLOAT/DOUBLE) -> DOUBLE */ - private static FunctionResolver atan2() { + private static DefaultFunctionResolver atan2() { ImmutableList.Builder>> builder = new ImmutableList.Builder<>(); @@ -590,7 +590,7 @@ private static FunctionResolver atan2() { * The supported signature of cos function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver cos() { + private static DefaultFunctionResolver cos() { return FunctionDSL.define(BuiltinFunctionName.COS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -604,7 +604,7 @@ private static FunctionResolver cos() { * The supported signature of cot function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver cot() { + private static DefaultFunctionResolver cot() { return FunctionDSL.define(BuiltinFunctionName.COT.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -625,7 +625,7 @@ private static FunctionResolver cot() { * The supported signature of degrees function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver degrees() { + private static DefaultFunctionResolver degrees() { return FunctionDSL.define(BuiltinFunctionName.DEGREES.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -639,7 +639,7 @@ private static FunctionResolver degrees() { * The supported signature of radians function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver radians() { + private static DefaultFunctionResolver radians() { return FunctionDSL.define(BuiltinFunctionName.RADIANS.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -653,7 +653,7 @@ private static FunctionResolver radians() { * The supported signature of sin function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver sin() { + private static DefaultFunctionResolver sin() { return FunctionDSL.define(BuiltinFunctionName.SIN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( @@ -667,7 +667,7 @@ private static FunctionResolver sin() { * The supported signature of tan function is * INTEGER/LONG/FLOAT/DOUBLE -> DOUBLE */ - private static FunctionResolver tan() { + private static DefaultFunctionResolver tan() { return FunctionDSL.define(BuiltinFunctionName.TAN.getName(), ExprCoreType.numberTypes().stream() .map(type -> FunctionDSL.impl(FunctionDSL.nullMissingHandling( diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java index 171563e0a3..23508406ac 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/convert/TypeCastOperator.java @@ -39,8 +39,8 @@ import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; @UtilityClass public class TypeCastOperator { @@ -63,7 +63,7 @@ public static void register(BuiltinFunctionRepository repository) { } - private static FunctionResolver castToString() { + private static DefaultFunctionResolver castToString() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_STRING.getName(), Stream.concat( Arrays.asList(BYTE, SHORT, INTEGER, LONG, FLOAT, DOUBLE, BOOLEAN, TIME, DATE, @@ -76,7 +76,7 @@ private static FunctionResolver castToString() { ); } - private static FunctionResolver castToByte() { + private static DefaultFunctionResolver castToByte() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BYTE.getName(), impl(nullMissingHandling( (v) -> new ExprByteValue(Byte.valueOf(v.stringValue()))), BYTE, STRING), @@ -87,7 +87,7 @@ private static FunctionResolver castToByte() { ); } - private static FunctionResolver castToShort() { + private static DefaultFunctionResolver castToShort() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_SHORT.getName(), impl(nullMissingHandling( (v) -> new ExprShortValue(Short.valueOf(v.stringValue()))), SHORT, STRING), @@ -98,7 +98,7 @@ private static FunctionResolver castToShort() { ); } - private static FunctionResolver castToInt() { + private static DefaultFunctionResolver castToInt() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_INT.getName(), impl(nullMissingHandling( (v) -> new ExprIntegerValue(Integer.valueOf(v.stringValue()))), INTEGER, STRING), @@ -109,7 +109,7 @@ private static FunctionResolver castToInt() { ); } - private static FunctionResolver castToLong() { + private static DefaultFunctionResolver castToLong() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_LONG.getName(), impl(nullMissingHandling( (v) -> new ExprLongValue(Long.valueOf(v.stringValue()))), LONG, STRING), @@ -120,7 +120,7 @@ private static FunctionResolver castToLong() { ); } - private static FunctionResolver castToFloat() { + private static DefaultFunctionResolver castToFloat() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_FLOAT.getName(), impl(nullMissingHandling( (v) -> new ExprFloatValue(Float.valueOf(v.stringValue()))), FLOAT, STRING), @@ -131,7 +131,7 @@ private static FunctionResolver castToFloat() { ); } - private static FunctionResolver castToDouble() { + private static DefaultFunctionResolver castToDouble() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DOUBLE.getName(), impl(nullMissingHandling( (v) -> new ExprDoubleValue(Double.valueOf(v.stringValue()))), DOUBLE, STRING), @@ -142,7 +142,7 @@ private static FunctionResolver castToDouble() { ); } - private static FunctionResolver castToBoolean() { + private static DefaultFunctionResolver castToBoolean() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_BOOLEAN.getName(), impl(nullMissingHandling( (v) -> ExprBooleanValue.of(Boolean.valueOf(v.stringValue()))), BOOLEAN, STRING), @@ -152,7 +152,7 @@ private static FunctionResolver castToBoolean() { ); } - private static FunctionResolver castToDate() { + private static DefaultFunctionResolver castToDate() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATE.getName(), impl(nullMissingHandling( (v) -> new ExprDateValue(v.stringValue())), DATE, STRING), @@ -164,7 +164,7 @@ private static FunctionResolver castToDate() { ); } - private static FunctionResolver castToTime() { + private static DefaultFunctionResolver castToTime() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIME.getName(), impl(nullMissingHandling( (v) -> new ExprTimeValue(v.stringValue())), TIME, STRING), @@ -176,7 +176,7 @@ private static FunctionResolver castToTime() { ); } - private static FunctionResolver castToTimestamp() { + private static DefaultFunctionResolver castToTimestamp() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_TIMESTAMP.getName(), impl(nullMissingHandling( (v) -> new ExprTimestampValue(v.stringValue())), TIMESTAMP, STRING), @@ -186,7 +186,7 @@ private static FunctionResolver castToTimestamp() { ); } - private static FunctionResolver castToDatetime() { + private static DefaultFunctionResolver castToDatetime() { return FunctionDSL.define(BuiltinFunctionName.CAST_TO_DATETIME.getName(), impl(nullMissingHandling( (v) -> new ExprDatetimeValue(v.stringValue())), DATETIME, STRING), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java index 4caed12cae..99399249c2 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/BinaryPredicateOperator.java @@ -23,8 +23,8 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionDSL; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.utils.OperatorUtils; /** @@ -140,25 +140,25 @@ public static void register(BuiltinFunctionRepository repository) { .put(LITERAL_MISSING, LITERAL_MISSING, LITERAL_MISSING) .build(); - private static FunctionResolver and() { + private static DefaultFunctionResolver and() { return FunctionDSL.define(BuiltinFunctionName.AND.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, andTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver or() { + private static DefaultFunctionResolver or() { return FunctionDSL.define(BuiltinFunctionName.OR.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, orTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver xor() { + private static DefaultFunctionResolver xor() { return FunctionDSL.define(BuiltinFunctionName.XOR.getName(), FunctionDSL .impl((v1, v2) -> lookupTableFunction(v1, v2, xorTable), BOOLEAN, BOOLEAN, BOOLEAN)); } - private static FunctionResolver equal() { + private static DefaultFunctionResolver equal() { return FunctionDSL.define(BuiltinFunctionName.EQUAL.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL.impl( @@ -168,7 +168,7 @@ private static FunctionResolver equal() { Collectors.toList())); } - private static FunctionResolver notEqual() { + private static DefaultFunctionResolver notEqual() { return FunctionDSL .define(BuiltinFunctionName.NOTEQUAL.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -182,7 +182,7 @@ private static FunctionResolver notEqual() { Collectors.toList())); } - private static FunctionResolver less() { + private static DefaultFunctionResolver less() { return FunctionDSL .define(BuiltinFunctionName.LESS.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -194,7 +194,7 @@ private static FunctionResolver less() { Collectors.toList())); } - private static FunctionResolver lte() { + private static DefaultFunctionResolver lte() { return FunctionDSL .define(BuiltinFunctionName.LTE.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -208,7 +208,7 @@ private static FunctionResolver lte() { Collectors.toList())); } - private static FunctionResolver greater() { + private static DefaultFunctionResolver greater() { return FunctionDSL .define(BuiltinFunctionName.GREATER.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -219,7 +219,7 @@ private static FunctionResolver greater() { Collectors.toList())); } - private static FunctionResolver gte() { + private static DefaultFunctionResolver gte() { return FunctionDSL .define(BuiltinFunctionName.GTE.getName(), ExprCoreType.coreTypes().stream() .map(type -> FunctionDSL @@ -232,19 +232,19 @@ private static FunctionResolver gte() { Collectors.toList())); } - private static FunctionResolver like() { + private static DefaultFunctionResolver like() { return FunctionDSL.define(BuiltinFunctionName.LIKE.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling(OperatorUtils::matches), BOOLEAN, STRING, STRING)); } - private static FunctionResolver regexp() { + private static DefaultFunctionResolver regexp() { return FunctionDSL.define(BuiltinFunctionName.REGEXP.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling(OperatorUtils::matchesRegexp), INTEGER, STRING, STRING)); } - private static FunctionResolver notLike() { + private static DefaultFunctionResolver notLike() { return FunctionDSL.define(BuiltinFunctionName.NOT_LIKE.getName(), FunctionDSL .impl(FunctionDSL.nullMissingHandling( (v1, v2) -> UnaryPredicateOperator.not(OperatorUtils.matches(v1, v2))), diff --git a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java index ca228a6a7e..7d79d9d923 100644 --- a/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java +++ b/core/src/main/java/org/opensearch/sql/expression/operator/predicate/UnaryPredicateOperator.java @@ -20,10 +20,10 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.SerializableFunction; @@ -46,7 +46,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(ifFunction()); } - private static FunctionResolver not() { + private static DefaultFunctionResolver not() { return FunctionDSL.define(BuiltinFunctionName.NOT.getName(), FunctionDSL .impl(UnaryPredicateOperator::not, BOOLEAN, BOOLEAN)); } @@ -67,7 +67,7 @@ public ExprValue not(ExprValue v) { } } - private static FunctionResolver isNull(BuiltinFunctionName funcName) { + private static DefaultFunctionResolver isNull(BuiltinFunctionName funcName) { return FunctionDSL .define(funcName.getName(), Arrays.stream(ExprCoreType.values()) .map(type -> FunctionDSL @@ -76,7 +76,7 @@ private static FunctionResolver isNull(BuiltinFunctionName funcName) { Collectors.toList())); } - private static FunctionResolver isNotNull() { + private static DefaultFunctionResolver isNotNull() { return FunctionDSL .define(BuiltinFunctionName.IS_NOT_NULL.getName(), Arrays.stream(ExprCoreType.values()) .map(type -> FunctionDSL @@ -85,7 +85,7 @@ private static FunctionResolver isNotNull() { Collectors.toList())); } - private static FunctionResolver ifFunction() { + private static DefaultFunctionResolver ifFunction() { FunctionName functionName = BuiltinFunctionName.IF.getName(); List typeList = ExprCoreType.coreTypes(); @@ -94,11 +94,11 @@ private static FunctionResolver ifFunction() { impl((UnaryPredicateOperator::exprIf), v, BOOLEAN, v, v)) .collect(Collectors.toList()); - FunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); return functionResolver; } - private static FunctionResolver ifNull() { + private static DefaultFunctionResolver ifNull() { FunctionName functionName = BuiltinFunctionName.IFNULL.getName(); List typeList = ExprCoreType.coreTypes(); @@ -107,15 +107,15 @@ private static FunctionResolver ifNull() { impl((UnaryPredicateOperator::exprIfNull), v, v, v)) .collect(Collectors.toList()); - FunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, functionsOne); return functionResolver; } - private static FunctionResolver nullIf() { + private static DefaultFunctionResolver nullIf() { FunctionName functionName = BuiltinFunctionName.NULLIF.getName(); List typeList = ExprCoreType.coreTypes(); - FunctionResolver functionResolver = + DefaultFunctionResolver functionResolver = FunctionDSL.define(functionName, typeList.stream().map(v -> impl((UnaryPredicateOperator::exprNullIf), v, v, v)) @@ -124,6 +124,7 @@ private static FunctionResolver nullIf() { } /** v2 if v1 is null. + * * @param v1 varable 1 * @param v2 varable 2 * @return v2 if v1 is null @@ -133,6 +134,7 @@ public static ExprValue exprIfNull(ExprValue v1, ExprValue v2) { } /** return null if v1 equls to v2. + * * @param v1 varable 1 * @param v2 varable 2 * @return null if v1 equls to v2 diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 372540b4e9..8035728d19 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -18,8 +18,8 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.SerializableBiFunction; import org.opensearch.sql.expression.function.SerializableTriFunction; @@ -63,7 +63,7 @@ public void register(BuiltinFunctionRepository repository) { * Supports following signatures: * (STRING, INTEGER)/(STRING, INTEGER, INTEGER) -> STRING */ - private FunctionResolver substringSubstr(FunctionName functionName) { + private DefaultFunctionResolver substringSubstr(FunctionName functionName) { return define(functionName, impl(nullMissingHandling(TextFunction::exprSubstrStart), STRING, STRING, INTEGER), @@ -71,11 +71,11 @@ private FunctionResolver substringSubstr(FunctionName functionName) { STRING, STRING, INTEGER, INTEGER)); } - private FunctionResolver substring() { + private DefaultFunctionResolver substring() { return substringSubstr(BuiltinFunctionName.SUBSTRING.getName()); } - private FunctionResolver substr() { + private DefaultFunctionResolver substr() { return substringSubstr(BuiltinFunctionName.SUBSTR.getName()); } @@ -84,7 +84,7 @@ private FunctionResolver substr() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver ltrim() { + private DefaultFunctionResolver ltrim() { return define(BuiltinFunctionName.LTRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripLeading())), STRING, STRING)); @@ -95,7 +95,7 @@ private FunctionResolver ltrim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver rtrim() { + private DefaultFunctionResolver rtrim() { return define(BuiltinFunctionName.RTRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().stripTrailing())), STRING, STRING)); @@ -108,7 +108,7 @@ private FunctionResolver rtrim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver trim() { + private DefaultFunctionResolver trim() { return define(BuiltinFunctionName.TRIM.getName(), impl(nullMissingHandling((v) -> new ExprStringValue(v.stringValue().trim())), STRING, STRING)); @@ -119,7 +119,7 @@ private FunctionResolver trim() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver lower() { + private DefaultFunctionResolver lower() { return define(BuiltinFunctionName.LOWER.getName(), impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toLowerCase()))), STRING, STRING) @@ -131,7 +131,7 @@ private FunctionResolver lower() { * Supports following signatures: * STRING -> STRING */ - private FunctionResolver upper() { + private DefaultFunctionResolver upper() { return define(BuiltinFunctionName.UPPER.getName(), impl(nullMissingHandling((v) -> new ExprStringValue((v.stringValue().toUpperCase()))), STRING, STRING) @@ -145,7 +145,7 @@ private FunctionResolver upper() { * Supports following signatures: * (STRING, STRING) -> STRING */ - private FunctionResolver concat() { + private DefaultFunctionResolver concat() { return define(BuiltinFunctionName.CONCAT.getName(), impl(nullMissingHandling((str1, str2) -> new ExprStringValue(str1.stringValue() + str2.stringValue())), STRING, STRING, STRING)); @@ -158,7 +158,7 @@ private FunctionResolver concat() { * Supports following signatures: * (STRING, STRING, STRING) -> STRING */ - private FunctionResolver concat_ws() { + private DefaultFunctionResolver concat_ws() { return define(BuiltinFunctionName.CONCAT_WS.getName(), impl(nullMissingHandling((sep, str1, str2) -> new ExprStringValue(str1.stringValue() + sep.stringValue() + str2.stringValue())), @@ -170,7 +170,7 @@ private FunctionResolver concat_ws() { * Supports following signatures: * STRING -> INTEGER */ - private FunctionResolver length() { + private DefaultFunctionResolver length() { return define(BuiltinFunctionName.LENGTH.getName(), impl(nullMissingHandling((str) -> new ExprIntegerValue(str.stringValue().getBytes().length)), INTEGER, STRING)); @@ -181,7 +181,7 @@ private FunctionResolver length() { * Supports following signatures: * (STRING, STRING) -> INTEGER */ - private FunctionResolver strcmp() { + private DefaultFunctionResolver strcmp() { return define(BuiltinFunctionName.STRCMP.getName(), impl(nullMissingHandling((str1, str2) -> new ExprIntegerValue(Integer.compare( @@ -194,7 +194,7 @@ private FunctionResolver strcmp() { * Supports following signatures: * (STRING, INTEGER) -> STRING */ - private FunctionResolver right() { + private DefaultFunctionResolver right() { return define(BuiltinFunctionName.RIGHT.getName(), impl(nullMissingHandling(TextFunction::exprRight), STRING, STRING, INTEGER)); } @@ -204,7 +204,7 @@ private FunctionResolver right() { * Supports following signature: * (STRING, INTEGER) -> STRING */ - private FunctionResolver left() { + private DefaultFunctionResolver left() { return define(BuiltinFunctionName.LEFT.getName(), impl(nullMissingHandling(TextFunction::exprLeft), STRING, STRING, INTEGER)); } @@ -216,7 +216,7 @@ private FunctionResolver left() { * Supports following signature: * STRING -> INTEGER */ - private FunctionResolver ascii() { + private DefaultFunctionResolver ascii() { return define(BuiltinFunctionName.ASCII.getName(), impl(nullMissingHandling(TextFunction::exprAscii), INTEGER, STRING)); } @@ -231,7 +231,7 @@ private FunctionResolver ascii() { * (STRING, STRING) -> INTEGER * (STRING, STRING, INTEGER) -> INTEGER */ - private FunctionResolver locate() { + private DefaultFunctionResolver locate() { return define(BuiltinFunctionName.LOCATE.getName(), impl(nullMissingHandling( (SerializableBiFunction) @@ -248,7 +248,7 @@ private FunctionResolver locate() { * Supports following signature: * (STRING, STRING, STRING) -> STRING */ - private FunctionResolver replace() { + private DefaultFunctionResolver replace() { return define(BuiltinFunctionName.REPLACE.getName(), impl(nullMissingHandling(TextFunction::exprReplace), STRING, STRING, STRING, STRING)); } diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java index 2851dd9f6b..a3baf08ff3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java @@ -13,9 +13,9 @@ import lombok.experimental.UtilityClass; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.DefaultFunctionResolver; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.window.ranking.DenseRankFunction; import org.opensearch.sql.expression.window.ranking.RankFunction; @@ -30,6 +30,7 @@ public class WindowFunctions { /** * Register all window functions to function repository. + * * @param repository function repository */ public void register(BuiltinFunctionRepository repository) { @@ -38,23 +39,24 @@ public void register(BuiltinFunctionRepository repository) { repository.register(denseRank()); } - private FunctionResolver rowNumber() { + private DefaultFunctionResolver rowNumber() { return rankingFunction(BuiltinFunctionName.ROW_NUMBER.getName(), RowNumberFunction::new); } - private FunctionResolver rank() { + private DefaultFunctionResolver rank() { return rankingFunction(BuiltinFunctionName.RANK.getName(), RankFunction::new); } - private FunctionResolver denseRank() { + private DefaultFunctionResolver denseRank() { return rankingFunction(BuiltinFunctionName.DENSE_RANK.getName(), DenseRankFunction::new); } - private FunctionResolver rankingFunction(FunctionName functionName, - Supplier constructor) { + private DefaultFunctionResolver rankingFunction(FunctionName functionName, + Supplier constructor) { FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); FunctionBuilder functionBuilder = arguments -> constructor.get(); - return new FunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); + return new DefaultFunctionResolver(functionName, + ImmutableMap.of(functionSignature, functionBuilder)); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 72db402552..c8ce70c418 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.field; +import static org.opensearch.sql.ast.dsl.AstDSL.floatLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.function; import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; @@ -355,6 +356,14 @@ void match_bool_prefix_expression() { AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } + @Test + void match_bool_prefix_wrong_expression() { + assertThrows(SemanticCheckException.class, + () -> analyze(AstDSL.function("match_bool_prefix", + AstDSL.unresolvedArg("field", stringLiteral("fieldA")), + AstDSL.unresolvedArg("query", floatLiteral(1.2f))))); + } + @Test void visit_span() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index eca6408d17..61cc560670 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -47,7 +47,7 @@ @ExtendWith(MockitoExtension.class) class BuiltinFunctionRepositoryTest { @Mock - private FunctionResolver mockfunctionResolver; + private DefaultFunctionResolver mockfunctionResolver; @Mock private Map mockMap; @Mock @@ -182,7 +182,7 @@ private FunctionSignature registerFunctionResolver(FunctionName funcName, FunctionSignature resolvedSignature = new FunctionSignature( funcName, ImmutableList.of(targetType)); - FunctionResolver funcResolver = mock(FunctionResolver.class); + DefaultFunctionResolver funcResolver = mock(DefaultFunctionResolver.class); FunctionBuilder funcBuilder = mock(FunctionBuilder.class); when(mockMap.containsKey(eq(funcName))).thenReturn(true); diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java similarity index 90% rename from core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java rename to core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java index 141c1fbd54..baa299b60b 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/FunctionResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/DefaultFunctionResolverTest.java @@ -22,7 +22,7 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class FunctionResolverTest { +class DefaultFunctionResolverTest { @Mock private FunctionSignature exactlyMatchFS; @Mock @@ -47,7 +47,7 @@ class FunctionResolverTest { @Test void resolve_function_signature_exactly_match() { when(functionSignature.match(exactlyMatchFS)).thenReturn(WideningTypeRule.TYPE_EQUAL); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(exactlyMatchFS, exactlyMatchBuilder)); assertEquals(exactlyMatchBuilder, resolver.resolve(functionSignature).getValue()); @@ -57,7 +57,7 @@ void resolve_function_signature_exactly_match() { void resolve_function_signature_best_match() { when(functionSignature.match(bestMatchFS)).thenReturn(1); when(functionSignature.match(leastMatchFS)).thenReturn(2); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(bestMatchFS, bestMatchBuilder, leastMatchFS, leastMatchBuilder)); assertEquals(bestMatchBuilder, resolver.resolve(functionSignature).getValue()); @@ -68,7 +68,7 @@ void resolve_function_not_match() { when(functionSignature.match(notMatchFS)).thenReturn(WideningTypeRule.IMPOSSIBLE_WIDENING); when(notMatchFS.formatTypes()).thenReturn("[INTEGER,INTEGER]"); when(functionSignature.formatTypes()).thenReturn("[BOOLEAN,BOOLEAN]"); - FunctionResolver resolver = new FunctionResolver(functionName, + DefaultFunctionResolver resolver = new DefaultFunctionResolver(functionName, ImmutableMap.of(notMatchFS, notMatchBuilder)); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, diff --git a/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java new file mode 100644 index 0000000000..d8547057c4 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/RelevanceFunctionResolverTest.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.exception.SemanticCheckException; + +class RelevanceFunctionResolverTest { + private final FunctionName sampleFuncName = FunctionName.of("sample_function"); + private RelevanceFunctionResolver resolver; + + @BeforeEach + void setUp() { + resolver = new RelevanceFunctionResolver(sampleFuncName, STRING); + } + + @Test + void resolve_correct_name_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(STRING)); + Pair builderPair = resolver.resolve(sig); + assertEquals(sampleFuncName, builderPair.getKey().getFunctionName()); + } + + @Test + void resolve_invalid_name_test() { + var wrongFuncName = FunctionName.of("wrong_func"); + var sig = new FunctionSignature(wrongFuncName, List.of(STRING)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected 'sample_function' but got 'wrong_func'", + exception.getMessage()); + } + + @Test + void resolve_invalid_first_param_type_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(INTEGER)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected type STRING instead of INTEGER for parameter #1", + exception.getMessage()); + } + + @Test + void resolve_invalid_third_param_type_test() { + var sig = new FunctionSignature(sampleFuncName, List.of(STRING, STRING, INTEGER, STRING)); + Exception exception = assertThrows(SemanticCheckException.class, + () -> resolver.resolve(sig)); + assertEquals("Expected type STRING instead of INTEGER for parameter #3", + exception.getMessage()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java index 754a09259d..33e357afe3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchBoolPrefixQuery.java @@ -14,7 +14,7 @@ * Initializes MatchBoolPrefixQueryBuilder from a FunctionExpression. */ public class MatchBoolPrefixQuery - extends RelevanceQuery { + extends SingleFieldQuery { /** * Constructor for MatchBoolPrefixQuery to configure RelevanceQuery * with support of optional parameters. @@ -41,7 +41,12 @@ public MatchBoolPrefixQuery() { * @return Object of executed query */ @Override - protected MatchBoolPrefixQueryBuilder createQueryBuilder(String field, String query) { + protected MatchBoolPrefixQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchBoolPrefixQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchBoolPrefixQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java index b8d0d4f18d..6d181daa4c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhrasePrefixQuery.java @@ -12,7 +12,7 @@ /** * Lucene query that builds a match_phrase_prefix query. */ -public class MatchPhrasePrefixQuery extends RelevanceQuery { +public class MatchPhrasePrefixQuery extends SingleFieldQuery { /** * Default constructor for MatchPhrasePrefixQuery configures how RelevanceQuery.build() handles * named arguments. @@ -29,7 +29,12 @@ public MatchPhrasePrefixQuery() { } @Override - protected MatchPhrasePrefixQueryBuilder createQueryBuilder(String field, String query) { + protected MatchPhrasePrefixQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchPhrasePrefixQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchPhrasePrefixQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java index 333d8eff89..6a7694f629 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchPhraseQuery.java @@ -23,7 +23,7 @@ /** * Lucene query that builds a match_phrase query. */ -public class MatchPhraseQuery extends RelevanceQuery { +public class MatchPhraseQuery extends SingleFieldQuery { /** * Default constructor for MatchPhraseQuery configures how RelevanceQuery.build() handles * named arguments. @@ -39,7 +39,12 @@ public MatchPhraseQuery() { } @Override - protected MatchPhraseQueryBuilder createQueryBuilder(String field, String query) { + protected MatchPhraseQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchPhraseQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchPhraseQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java index 4095ffba4e..f6d88013e4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MatchQuery.java @@ -6,7 +6,6 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Map; import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.Operator; import org.opensearch.index.query.QueryBuilders; @@ -14,7 +13,7 @@ /** * Initializes MatchQueryBuilder from a FunctionExpression. */ -public class MatchQuery extends RelevanceQuery { +public class MatchQuery extends SingleFieldQuery { /** * Default constructor for MatchQuery configures how RelevanceQuery.build() handles * named arguments. @@ -40,7 +39,12 @@ public MatchQuery() { } @Override - protected MatchQueryBuilder createQueryBuilder(String field, String query) { + protected MatchQueryBuilder createBuilder(String field, String query) { return QueryBuilders.matchQuery(field, query); } + + @Override + protected String getQueryName() { + return MatchQueryBuilder.NAME; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java new file mode 100644 index 0000000000..b447f2ffe2 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQuery.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.expression.NamedArgumentExpression; + +/** + * Base class to represent relevance queries that search multiple fields. + * @param The builder class for the OpenSearch query. + */ +abstract class MultiFieldQuery extends RelevanceQuery { + + public MultiFieldQuery(Map> queryBuildActions) { + super(queryBuildActions); + } + + @Override + public T createQueryBuilder(NamedArgumentExpression fields, NamedArgumentExpression queryExpr) { + var fieldsAndWeights = fields + .getValue() + .valueOf(null) + .tupleValue() + .entrySet() + .stream() + .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); + var query = queryExpr.getValue().valueOf(null).stringValue(); + return createBuilder(fieldsAndWeights, query); + } + + protected abstract T createBuilder(ImmutableMap fields, String query); +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java index 524d42f0b6..549f58cb19 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiMatchQuery.java @@ -6,18 +6,11 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; import com.google.common.collect.ImmutableMap; -import java.util.Iterator; -import java.util.Objects; import org.opensearch.index.query.MultiMatchQueryBuilder; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -public class MultiMatchQuery extends RelevanceQuery { +public class MultiMatchQuery extends MultiFieldQuery { /** * Default constructor for MultiMatch configures how RelevanceQuery.build() handles * named arguments. @@ -46,43 +39,12 @@ public MultiMatchQuery() { } @Override - public QueryBuilder build(FunctionExpression func) { - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'multi_match' must have at least two arguments"); - } - Iterator iterator = func.getArguments().iterator(); - var fields = (NamedArgumentExpression) iterator.next(); - var query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - MultiMatchQueryBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected MultiMatchQueryBuilder createBuilder(ImmutableMap fields, String query) { + return QueryBuilders.multiMatchQuery(query).fields(fields); } @Override - protected MultiMatchQueryBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.multiMatchQuery(query); + protected String getQueryName() { + return MultiMatchQueryBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java index 54ffea6158..21eb3f8837 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/QueryStringQuery.java @@ -23,7 +23,7 @@ /** * Class for Lucene query that builds the query_string query. */ -public class QueryStringQuery extends RelevanceQuery { +public class QueryStringQuery extends MultiFieldQuery { /** * Default constructor for QueryString configures how RelevanceQuery.build() handles * named arguments. @@ -66,55 +66,22 @@ public QueryStringQuery() { .build()); } - /** - * Override base build function for multi-field query support. - * @param func function : 'query_string' function - * @return : QueryBuilder for query_string query - */ - @Override - public QueryBuilder build(FunctionExpression func) { - Iterator iterator = func.getArguments().iterator(); - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'query_string' must have at least two arguments"); - } - NamedArgumentExpression fields = (NamedArgumentExpression) iterator.next(); - NamedArgumentExpression query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - QueryStringQueryBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; - } /** * Builds QueryBuilder with query value and other default parameter values set. - * @param field : Field value in query_string query + * + * @param fields : A map of field names and their boost values * @param query : Query value for query_string query * @return : Builder for query_string query */ @Override - protected QueryStringQueryBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.queryStringQuery(query); + protected QueryStringQueryBuilder createBuilder(ImmutableMap fields, + String query) { + return QueryBuilders.queryStringQuery(query).fields(fields); + } + + @Override + protected String getQueryName() { + return QueryStringQueryBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java index fb997646f4..282c5478b4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQuery.java @@ -5,11 +5,14 @@ package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiFunction; +import lombok.RequiredArgsConstructor; import org.opensearch.index.query.QueryBuilder; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprValue; @@ -22,31 +25,33 @@ /** * Base class for query abstraction that builds a relevance query from function expression. */ +@RequiredArgsConstructor public abstract class RelevanceQuery extends LuceneQuery { - protected Map> queryBuildActions; - - protected RelevanceQuery(Map> actionMap) { - queryBuildActions = actionMap; - } + private final Map> queryBuildActions; @Override public QueryBuilder build(FunctionExpression func) { List arguments = func.getArguments(); if (arguments.size() < 2) { - String queryName = createQueryBuilder("dummy_field", "").getWriteableName(); throw new SyntaxCheckException( - String.format("%s requires at least two parameters", queryName)); + String.format("%s requires at least two parameters", getQueryName())); } NamedArgumentExpression field = (NamedArgumentExpression) arguments.get(0); NamedArgumentExpression query = (NamedArgumentExpression) arguments.get(1); - T queryBuilder = createQueryBuilder( - field.getValue().valueOf(null).stringValue(), - query.getValue().valueOf(null).stringValue()); + T queryBuilder = createQueryBuilder(field, query); Iterator iterator = arguments.listIterator(2); + Set visitedParms = new HashSet(); while (iterator.hasNext()) { NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); String argNormalized = arg.getArgName().toLowerCase(); + if (visitedParms.contains(argNormalized)) { + throw new SemanticCheckException(String.format("Parameter '%s' can only be specified once.", + argNormalized)); + } else { + visitedParms.add(argNormalized); + } + if (!queryBuildActions.containsKey(argNormalized)) { throw new SemanticCheckException( String.format("Parameter %s is invalid for %s function.", @@ -60,16 +65,19 @@ public QueryBuilder build(FunctionExpression func) { return queryBuilder; } - protected abstract T createQueryBuilder(String field, String query); + protected abstract T createQueryBuilder(NamedArgumentExpression field, + NamedArgumentExpression query); + + protected abstract String getQueryName(); /** * Convenience interface for a function that updates a QueryBuilder * based on ExprValue. + * * @param Concrete query builder */ - public interface QueryBuilderStep extends + protected interface QueryBuilderStep extends BiFunction { - } public static String valueOfToUpper(ExprValue v) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java index 45637e98a6..1b7c18cb2c 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SimpleQueryStringQuery.java @@ -10,16 +10,11 @@ import java.util.Iterator; import java.util.Objects; import org.opensearch.index.query.Operator; -import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; import org.opensearch.index.query.SimpleQueryStringBuilder; import org.opensearch.index.query.SimpleQueryStringFlag; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; -public class SimpleQueryStringQuery extends RelevanceQuery { +public class SimpleQueryStringQuery extends MultiFieldQuery { /** * Default constructor for SimpleQueryString configures how RelevanceQuery.build() handles * named arguments. @@ -48,43 +43,13 @@ public SimpleQueryStringQuery() { } @Override - public QueryBuilder build(FunctionExpression func) { - if (func.getArguments().size() < 2) { - throw new SemanticCheckException("'simple_query_string' must have at least two arguments"); - } - Iterator iterator = func.getArguments().iterator(); - var fields = (NamedArgumentExpression) iterator.next(); - var query = (NamedArgumentExpression) iterator.next(); - // Fields is a map already, but we need to convert types. - var fieldsAndWeights = fields - .getValue() - .valueOf(null) - .tupleValue() - .entrySet() - .stream() - .collect(ImmutableMap.toImmutableMap(e -> e.getKey(), e -> e.getValue().floatValue())); - - SimpleQueryStringBuilder queryBuilder = createQueryBuilder(null, - query.getValue().valueOf(null).stringValue()) - .fields(fieldsAndWeights); - while (iterator.hasNext()) { - NamedArgumentExpression arg = (NamedArgumentExpression) iterator.next(); - String argNormalized = arg.getArgName().toLowerCase(); - if (!queryBuildActions.containsKey(argNormalized)) { - throw new SemanticCheckException( - String.format("Parameter %s is invalid for %s function.", - argNormalized, queryBuilder.getWriteableName())); - } - (Objects.requireNonNull( - queryBuildActions - .get(argNormalized))) - .apply(queryBuilder, arg.getValue().valueOf(null)); - } - return queryBuilder; + protected SimpleQueryStringBuilder createBuilder(ImmutableMap fields, + String query) { + return QueryBuilders.simpleQueryStringQuery(query).fields(fields); } @Override - protected SimpleQueryStringBuilder createQueryBuilder(String field, String query) { - return QueryBuilders.simpleQueryStringQuery(query); + protected String getQueryName() { + return SimpleQueryStringBuilder.NAME; } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java new file mode 100644 index 0000000000..9876c62cce --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQuery.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import java.util.Map; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.expression.NamedArgumentExpression; + +/** + * Base class to represent builder class for relevance queries like match_query, match_bool_prefix, + * and match_phrase that search in a single field only. + * + * @param The builder class for the OpenSearch query class. + */ +abstract class SingleFieldQuery extends RelevanceQuery { + public SingleFieldQuery(Map> queryBuildActions) { + super(queryBuildActions); + } + + @Override + protected T createQueryBuilder(NamedArgumentExpression fields, NamedArgumentExpression query) { + return createBuilder( + fields.getValue().valueOf(null).stringValue(), + query.getValue().valueOf(null).stringValue()); + } + + protected abstract T createBuilder(String field, String query); +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index b1efe86d01..75ddd1dd93 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -855,41 +855,6 @@ void match_phrase_invalid_value_ztq() { msg); } - @Test - void match_phrase_missing_field() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("query", literal("search query")))).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void match_phrase_missing_query() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("field", literal("message")))).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void match_phrase_too_many_args() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.match_phrase( - dsl.namedArgument("one", literal("1")), - dsl.namedArgument("two", literal("2")), - dsl.namedArgument("three", literal("3")), - dsl.namedArgument("four", literal("4")), - dsl.namedArgument("fix", literal("5")), - dsl.namedArgument("six", literal("6")) - )).getMessage(); - assertEquals("match_phrase function expected {[STRING,STRING],[STRING,STRING,STRING]," - + "[STRING,STRING,STRING,STRING],[STRING,STRING,STRING,STRING,STRING]}, but get " - + "[STRING,STRING,STRING,STRING,STRING,STRING]", msg); - } @Test @@ -913,55 +878,6 @@ void should_build_match_bool_prefix_query_with_default_parameters() { dsl.namedArgument("query", literal("search query"))))); } - @Test - void multi_match_missing_fields() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.multi_match( - dsl.namedArgument("query", literal("search query")))).getMessage(); - assertEquals("multi_match function expected {[STRUCT,STRING],[STRUCT,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING]}, but get [STRING]", - msg); - } - - @Test - void multi_match_missing_query() { - var msg = assertThrows(ExpressionEvaluationException.class, () -> - dsl.multi_match( - dsl.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field1", ExprValueUtils.floatValue(1.F), - "field2", ExprValueUtils.floatValue(.3F)))))))).getMessage(); - assertEquals("multi_match function expected {[STRUCT,STRING],[STRUCT,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING]," - + "[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING],[STRUCT,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING],[STRUCT,STRING," - + "STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING,STRING," - + "STRING,STRING,STRING,STRING]}, but get [STRUCT]", - msg); - } - @Test void should_build_match_phrase_prefix_query_with_default_parameters() { assertJsonEquals( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java index 00cf3158c4..c30e06bc1a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java @@ -61,8 +61,8 @@ public void test_valid_arguments(List validArgs) { @Test public void test_valid_when_two_arguments() { List arguments = List.of( - namedArgument("field", "field_value"), - namedArgument("query", "query_value")); + dsl.namedArgument("field", "field_value"), + dsl.namedArgument("query", "query_value")); Assertions.assertNotNull(matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -75,7 +75,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(namedArgument("field", "field_value")); + List arguments = List.of(dsl.namedArgument("field", "field_value")); assertThrows(SyntaxCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -83,17 +83,13 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SemanticCheckException_when_invalid_argument() { List arguments = List.of( - namedArgument("field", "field_value"), - namedArgument("query", "query_value"), - namedArgument("unsupported", "unsupported_value")); + dsl.namedArgument("field", "field_value"), + dsl.namedArgument("query", "query_value"), + dsl.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - private class MatchExpression extends FunctionExpression { public MatchExpression(List arguments) { super(MatchBoolPrefixQueryTest.this.matchBoolPrefix, arguments); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java index 4e8895a12a..09e25fe569 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java @@ -20,7 +20,6 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.NamedArgumentExpression; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; @@ -33,10 +32,6 @@ public class MatchPhraseQueryTest { private final MatchPhraseQuery matchPhraseQuery = new MatchPhraseQuery(); private final FunctionName matchPhrase = FunctionName.of("match_phrase"); - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - @Test public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); @@ -46,7 +41,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(namedArgument("field", "test")); + List arguments = List.of(dsl.namedArgument("field", "test")); assertThrows(SyntaxCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -54,9 +49,9 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( - namedArgument("field", "test"), - namedArgument("query", "test2"), - namedArgument("unsupported", "3")); + dsl.namedArgument("field", "test"), + dsl.namedArgument("query", "test2"), + dsl.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -64,9 +59,9 @@ public void test_SyntaxCheckException_when_invalid_parameter() { @Test public void test_analyzer_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("analyzer", "standard") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -74,17 +69,17 @@ public void test_analyzer_parameter() { @Test public void build_succeeds_with_two_arguments() { List arguments = List.of( - namedArgument("field", "test"), - namedArgument("query", "test2")); + dsl.namedArgument("field", "test"), + dsl.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @Test public void test_slop_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("slop", "2") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -92,9 +87,9 @@ public void test_slop_parameter() { @Test public void test_zero_terms_query_parameter() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("zero_terms_query", "ALL") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -102,9 +97,9 @@ public void test_zero_terms_query_parameter() { @Test public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( - namedArgument("field", "t1"), - namedArgument("query", "t2"), - namedArgument("zero_terms_query", "all") + dsl.namedArgument("field", "t1"), + dsl.namedArgument("query", "t2"), + dsl.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java index 4a6e1d2ed9..261870ca17 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -137,16 +138,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } @@ -155,15 +156,11 @@ public void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - namedArgument("unsupported", "unsupported_value")); + dsl.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - private NamedArgumentExpression namedArgument(String name, LiteralExpression value) { return dsl.namedArgument(name, value); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java index fce835bf43..21b03abab0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java @@ -17,6 +17,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -88,16 +89,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java index 048f6e1cb9..8f06f48727 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.data.model.ExprTupleValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; @@ -161,16 +162,16 @@ public void test_valid_parameters(List validArgs) { } @Test - public void test_SemanticCheckException_when_no_arguments() { + public void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_one_argument() { + public void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); - assertThrows(SemanticCheckException.class, + assertThrows(SyntaxCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java new file mode 100644 index 0000000000..7e4c6ea011 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.config.ExpressionConfig; + +class MultiFieldQueryTest { + MultiFieldQuery query; + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private final String testQueryName = "test_query"; + private final Map actionMap + = ImmutableMap.of("paramA", (o, v) -> o); + + @BeforeEach + public void setUp() { + query = mock(MultiFieldQuery.class, + Mockito.withSettings().useConstructor(actionMap) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + when(query.getQueryName()).thenReturn(testQueryName); + } + + @Test + void createQueryBuilderTest() { + String sampleQuery = "sample query"; + String sampleField = "fieldA"; + float sampleValue = 34f; + + var fieldSpec = ImmutableMap.builder().put(sampleField, + ExprValueUtils.floatValue(sampleValue)).build(); + + query.createQueryBuilder(dsl.namedArgument("fields", + new LiteralExpression(ExprTupleValue.fromExprValueMap(fieldSpec))), + dsl.namedArgument("query", + new LiteralExpression(ExprValueUtils.stringValue(sampleQuery)))); + + verify(query).createBuilder(argThat( + (ArgumentMatcher>) map -> map.size() == 1 + && map.containsKey(sampleField) && map.containsValue(sampleValue)), + eq(sampleQuery)); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java index a67f0f34a7..fa6a43474a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/RelevanceQueryBuildTest.java @@ -30,7 +30,6 @@ import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; @@ -55,14 +54,20 @@ public void setUp() { .defaultAnswer(Mockito.CALLS_REAL_METHODS)); queryBuilder = mock(QueryBuilder.class); when(query.createQueryBuilder(any(), any())).thenReturn(queryBuilder); - when(queryBuilder.queryName()).thenReturn("mocked_query"); - when(queryBuilder.getWriteableName()).thenReturn("mock_query"); + String queryName = "mock_query"; + when(queryBuilder.queryName()).thenReturn(queryName); + when(queryBuilder.getWriteableName()).thenReturn(queryName); + when(query.getQueryName()).thenReturn(queryName); } @Test - void first_arg_field_second_arg_query_test() { - query.build(createCall(List.of(FIELD_ARG, QUERY_ARG))); - verify(query, times(1)).createQueryBuilder("field_A", "find me"); + void throws_SemanticCheckException_when_same_argument_twice() { + FunctionExpression expr = createCall(List.of(FIELD_ARG, QUERY_ARG, + namedArgument("boost", "2.3"), + namedArgument("boost", "2.4"))); + SemanticCheckException exception = + assertThrows(SemanticCheckException.class, () -> query.build(expr)); + assertEquals("Parameter 'boost' can only be specified once.", exception.getMessage()); } @Test @@ -72,7 +77,8 @@ void throws_SemanticCheckException_when_wrong_argument_name() { SemanticCheckException exception = assertThrows(SemanticCheckException.class, () -> query.build(expr)); - assertEquals("Parameter wrongarg is invalid for mock_query function.", exception.getMessage()); + assertEquals("Parameter wrongarg is invalid for mock_query function.", + exception.getMessage()); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java new file mode 100644 index 0000000000..5d35327116 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableMap; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.config.ExpressionConfig; + +class SingleFieldQueryTest { + SingleFieldQuery query; + private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + private final String testQueryName = "test_query"; + private final Map actionMap + = ImmutableMap.of("paramA", (o, v) -> o); + + @BeforeEach + void setUp() { + query = mock(SingleFieldQuery.class, + Mockito.withSettings().useConstructor(actionMap) + .defaultAnswer(Mockito.CALLS_REAL_METHODS)); + when(query.getQueryName()).thenReturn(testQueryName); + } + + @Test + void createQueryBuilderTest() { + String sampleQuery = "sample query"; + String sampleField = "fieldA"; + + query.createQueryBuilder(dsl.namedArgument("field", + new LiteralExpression(ExprValueUtils.stringValue(sampleField))), + dsl.namedArgument("query", + new LiteralExpression(ExprValueUtils.stringValue(sampleQuery)))); + + verify(query).createBuilder(eq(sampleField), + eq(sampleQuery)); + } +}