From bbcbd216200c2d71e32f4fe062c6e040d907659d Mon Sep 17 00:00:00 2001 From: panguixin Date: Fri, 10 Jan 2025 09:53:11 +0800 Subject: [PATCH] use the correct type to widen the sort fields when merging top docs (#16881) * use the correct type to widen the sort fields when merging top docs Signed-off-by: panguixin * fix Signed-off-by: panguixin * apply commments Signed-off-by: panguixin * changelog Signed-off-by: panguixin * add more tests Signed-off-by: panguixin --------- Signed-off-by: panguixin --- CHANGELOG.md | 1 + .../opensearch/search/sort/FieldSortIT.java | 99 +++++++++++++++++++ .../action/search/SearchPhaseController.java | 54 ++++++---- .../sort/SortedWiderNumericSortField.java | 21 +++- 4 files changed, 152 insertions(+), 23 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0cb11d1c45d38..9aabbbf75f00c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -66,6 +66,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Indexed IP field supports `terms_query` with more than 1025 IP masks [#16391](https://github.com/opensearch-project/OpenSearch/pull/16391) - Make entries for dependencies from server/build.gradle to gradle version catalog ([#16707](https://github.com/opensearch-project/OpenSearch/pull/16707)) - Allow extended plugins to be optional ([#16909](https://github.com/opensearch-project/OpenSearch/pull/16909)) +- Use the correct type to widen the sort fields when merging top docs ([#16881](https://github.com/opensearch-project/OpenSearch/pull/16881)) ### Deprecated - Performing update operation with default pipeline or final pipeline is deprecated ([#16712](https://github.com/opensearch-project/OpenSearch/pull/16712)) diff --git a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java index fdb12639c65be..cc837019d0b42 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/sort/FieldSortIT.java @@ -49,6 +49,7 @@ import org.opensearch.common.Numbers; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.core.xcontent.XContentBuilder; @@ -63,6 +64,7 @@ import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.test.InternalSettingsPlugin; +import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.ParameterizedDynamicSettingsOpenSearchIntegTestCase; import org.hamcrest.Matchers; @@ -82,7 +84,9 @@ import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; +import java.util.function.Supplier; import static org.opensearch.common.xcontent.XContentFactory.jsonBuilder; import static org.opensearch.index.query.QueryBuilders.functionScoreQuery; @@ -2609,4 +2613,99 @@ public void testSimpleSortsPoints() throws Exception { assertThat(searchResponse.toString(), not(containsString("error"))); } + + public void testSortMixedIntegerNumericFields() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(3); + AtomicInteger counter = new AtomicInteger(); + index("long", () -> Long.MAX_VALUE - counter.getAndIncrement()); + index("integer", () -> Integer.MAX_VALUE - counter.getAndIncrement()); + SearchResponse searchResponse = client().prepareSearch("long", "integer") + .setQuery(matchAllQuery()) + .setSize(10) + .addSort(SortBuilders.fieldSort("field").order(SortOrder.ASC).sortMode(SortMode.MAX)) + .get(); + assertNoFailures(searchResponse); + long[] sortValues = new long[10]; + for (int i = 0; i < 10; i++) { + sortValues[i] = ((Number) searchResponse.getHits().getAt(i).getSortValues()[0]).longValue(); + } + for (int i = 1; i < 10; i++) { + assertThat(Arrays.toString(sortValues), sortValues[i - 1], lessThan(sortValues[i])); + } + } + + public void testSortMixedFloatingNumericFields() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(3); + AtomicInteger counter = new AtomicInteger(); + index("double", () -> 100.5 - counter.getAndIncrement()); + counter.set(0); + index("float", () -> 200.5 - counter.getAndIncrement()); + counter.set(0); + index("half_float", () -> 300.5 - counter.getAndIncrement()); + SearchResponse searchResponse = client().prepareSearch("double", "float", "half_float") + .setQuery(matchAllQuery()) + .setSize(15) + .addSort(SortBuilders.fieldSort("field").order(SortOrder.ASC).sortMode(SortMode.MAX)) + .get(); + assertNoFailures(searchResponse); + double[] sortValues = new double[15]; + for (int i = 0; i < 15; i++) { + sortValues[i] = ((Number) searchResponse.getHits().getAt(i).getSortValues()[0]).doubleValue(); + } + for (int i = 1; i < 15; i++) { + assertThat(Arrays.toString(sortValues), sortValues[i - 1], lessThan(sortValues[i])); + } + } + + public void testSortMixedFloatingAndIntegerNumericFields() throws Exception { + internalCluster().ensureAtLeastNumDataNodes(3); + index("long", () -> randomLongBetween(0, (long) 2E53 - 1)); + index("integer", OpenSearchTestCase::randomInt); + index("double", OpenSearchTestCase::randomDouble); + index("float", () -> randomFloat()); + boolean asc = randomBoolean(); + SearchResponse searchResponse = client().prepareSearch("long", "integer", "double", "float") + .setQuery(matchAllQuery()) + .setSize(20) + .addSort(SortBuilders.fieldSort("field").order(asc ? SortOrder.ASC : SortOrder.DESC).sortMode(SortMode.MAX)) + .get(); + assertNoFailures(searchResponse); + double[] sortValues = new double[20]; + for (int i = 0; i < 20; i++) { + sortValues[i] = ((Number) searchResponse.getHits().getAt(i).getSortValues()[0]).doubleValue(); + } + if (asc) { + for (int i = 1; i < 20; i++) { + assertThat(Arrays.toString(sortValues), sortValues[i - 1], lessThanOrEqualTo(sortValues[i])); + } + } else { + for (int i = 1; i < 20; i++) { + assertThat(Arrays.toString(sortValues), sortValues[i - 1], greaterThanOrEqualTo(sortValues[i])); + } + } + } + + private void index(String type, Supplier valueSupplier) throws Exception { + assertAcked( + prepareCreate(type).setMapping( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject("field") + .field("type", type) + .endObject() + .endObject() + .endObject() + ).setSettings(Settings.builder().put("index.number_of_shards", 3).put("index.number_of_replicas", 0)) + ); + ensureGreen(type); + for (int i = 0; i < 5; i++) { + client().prepareIndex(type) + .setId(Integer.toString(i)) + .setSource("{\"field\" : " + valueSupplier.get() + " }", XContentType.JSON) + .get(); + } + client().admin().indices().prepareRefresh(type).get(); + } + } diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java index 161a103cdf36a..d63695447e365 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseController.java @@ -48,6 +48,7 @@ import org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -604,36 +605,51 @@ private static void validateMergeSortValueFormats(Collection comparator; + /** * Creates a sort, possibly in reverse, specifying how the sort value from the document's set is * selected. @@ -39,6 +43,15 @@ public class SortedWiderNumericSortField extends SortedNumericSortField { */ public SortedWiderNumericSortField(String field, Type type, boolean reverse) { super(field, type, reverse); + if (type == Type.LONG) { + byteCounts = Long.BYTES; + comparator = Comparator.comparingLong(Number::longValue); + } else if (type == Type.DOUBLE) { + byteCounts = Double.BYTES; + comparator = Comparator.comparingDouble(Number::doubleValue); + } else { + throw new IllegalArgumentException("Unsupported numeric type: " + type); + } } /** @@ -51,7 +64,7 @@ public SortedWiderNumericSortField(String field, Type type, boolean reverse) { */ @Override public FieldComparator getComparator(int numHits, Pruning pruning) { - return new NumericComparator(getField(), (Number) getMissingValue(), getReverse(), pruning, Double.BYTES) { + return new NumericComparator(getField(), (Number) getMissingValue(), getReverse(), pruning, byteCounts) { @Override public int compare(int slot1, int slot2) { throw new UnsupportedOperationException(); @@ -78,7 +91,7 @@ public int compareValues(Number first, Number second) { } else if (second == null) { return 1; } else { - return Double.compare(first.doubleValue(), second.doubleValue()); + return comparator.compare(first, second); } } };