From cd1b8cd4ad3d5fea4a78f964156ef1f40f6b1327 Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 11:18:34 +0200 Subject: [PATCH 01/12] Skip can_match phase for nodes that support batched query execution No need to do the can_match phase for those nodes that support the batched query phase. We can get the equivalent degree of pre-filtering by simply running can_match for sort based queries. --- .../search/CanMatchPreFilterSearchPhase.java | 104 ++++++++++-------- .../SearchQueryThenFetchAsyncAction.java | 100 +++++++++++++---- .../TransportOpenPointInTimeAction.java | 4 +- .../action/search/TransportSearchAction.java | 4 +- .../search/TransportSearchShardsAction.java | 12 +- .../CanMatchPreFilterSearchPhaseTests.java | 25 +++-- 6 files changed, 155 insertions(+), 94 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index e42f8127c5e97..caca901fbb0b5 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -39,6 +39,7 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; import java.util.concurrent.atomic.AtomicReferenceArray; import java.util.function.BiFunction; @@ -75,7 +76,7 @@ final class CanMatchPreFilterSearchPhase { private final FixedBitSet possibleMatches; private final MinAndMax[] minAndMaxes; private int numPossibleMatches; - private final CoordinatorRewriteContextProvider coordinatorRewriteContextProvider; + private final boolean batchQueryPhase; private CanMatchPreFilterSearchPhase( Logger logger, @@ -89,7 +90,7 @@ private CanMatchPreFilterSearchPhase( TransportSearchAction.SearchTimeProvider timeProvider, SearchTask task, boolean requireAtLeastOneMatch, - 
CoordinatorRewriteContextProvider coordinatorRewriteContextProvider, + boolean batchQueryPhase, ActionListener> listener ) { this.logger = logger; @@ -103,7 +104,6 @@ private CanMatchPreFilterSearchPhase( this.aliasFilter = aliasFilter; this.task = task; this.requireAtLeastOneMatch = requireAtLeastOneMatch; - this.coordinatorRewriteContextProvider = coordinatorRewriteContextProvider; this.executor = executor; final int size = shardsIts.size(); possibleMatches = new FixedBitSet(size); @@ -122,6 +122,7 @@ private CanMatchPreFilterSearchPhase( shardItIndexMap.put(naturalOrder[j], j); } this.shardItIndexMap = shardItIndexMap; + this.batchQueryPhase = batchQueryPhase; } public static SubscribableListener> execute( @@ -130,17 +131,19 @@ public static SubscribableListener> execute( BiFunction nodeIdToConnection, Map aliasFilter, Map concreteIndexBoosts, - Executor executor, SearchRequest request, List shardsIts, TransportSearchAction.SearchTimeProvider timeProvider, SearchTask task, boolean requireAtLeastOneMatch, - CoordinatorRewriteContextProvider coordinatorRewriteContextProvider + boolean batchQueryPhase, + SearchService searchService ) { + if (shardsIts.isEmpty()) { return SubscribableListener.newSucceeded(List.of()); } + ExecutorService executor = searchTransportService.transportService().getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION); final SubscribableListener> listener = new SubscribableListener<>(); // Note that the search is failed when this task is rejected by the executor executor.execute(new AbstractRunnable() { @@ -167,9 +170,9 @@ protected void doRun() { timeProvider, task, requireAtLeastOneMatch, - coordinatorRewriteContextProvider, + batchQueryPhase && searchService.batchQueryPhase(), listener - ).runCoordinatorRewritePhase(); + ).runCoordinatorRewritePhase(searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis)); } }); return listener; @@ -181,7 +184,7 @@ private static boolean 
assertSearchCoordinationThread() { // tries to pre-filter shards based on information that's available to the coordinator // without having to reach out to the actual shards - private void runCoordinatorRewritePhase() { + private void runCoordinatorRewritePhase(CoordinatorRewriteContextProvider coordinatorRewriteContextProvider) { // TODO: the index filter (i.e, `_index:patten`) should be prefiltered on the coordinator assert assertSearchCoordinationThread(); final List matchedShardLevelRequests = new ArrayList<>(); @@ -296,50 +299,60 @@ protected void doRun() { if (entry.getKey().nodeId == null) { // no target node: just mark the requests as failed - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), null); - } + failAll(shardLevelRequests, null); continue; } var sendingTarget = entry.getKey(); try { - searchTransportService.sendCanMatch( - nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId), - canMatchNodeRequest, - task, - new ActionListener<>() { - @Override - public void onResponse(CanMatchNodeResponse canMatchNodeResponse) { - assert canMatchNodeResponse.getResponses().size() == canMatchNodeRequest.getShardLevelRequests().size(); - for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) { - CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i); - if (response.getResponse() != null) { - CanMatchShardResponse shardResponse = response.getResponse(); - shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex()); - onOperation(shardResponse.getShardIndex(), shardResponse); - } else { - Exception failure = response.getException(); - assert failure != null; - onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure); - } - } - } - - @Override - public void onFailure(Exception e) { - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), e); - 
} - } - } - ); + var connection = nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId); + if (batchQueryPhase && SearchQueryThenFetchAsyncAction.connectionSupportsBatchedExecution(connection)) { + failAll(shardLevelRequests, null); + } else { + bwcSendCanMatchRequest(connection, canMatchNodeRequest, shardLevelRequests); + } } catch (Exception e) { + failAll(shardLevelRequests, e); + } + } + } + + private void failAll(List shardLevelRequests, Exception e) { + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + onOperationFailed(shard.getShardRequestIndex(), e); + } + } + + private void bwcSendCanMatchRequest( + Transport.Connection connection, + CanMatchNodeRequest canMatchNodeRequest, + List shardLevelRequests + ) { + searchTransportService.sendCanMatch(connection, canMatchNodeRequest, task, new ActionListener<>() { + @Override + public void onResponse(CanMatchNodeResponse canMatchNodeResponse) { + assert canMatchNodeResponse.getResponses().size() == shardLevelRequests.size(); + for (int i = 0; i < canMatchNodeResponse.getResponses().size(); i++) { + CanMatchNodeResponse.ResponseOrFailure response = canMatchNodeResponse.getResponses().get(i); + if (response.getResponse() != null) { + CanMatchShardResponse shardResponse = response.getResponse(); + shardResponse.setShardIndex(shardLevelRequests.get(i).getShardRequestIndex()); + onOperation(shardResponse.getShardIndex(), shardResponse); + } else { + Exception failure = response.getException(); + assert failure != null; + onOperationFailed(shardLevelRequests.get(i).getShardRequestIndex(), failure); + } + } + } + + @Override + public void onFailure(Exception e) { for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { onOperationFailed(shard.getShardRequestIndex(), e); } } - } + }); } private void onOperation(int idx, CanMatchShardResponse response) { @@ -461,11 +474,10 @@ private synchronized List getIterator(List sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) 
{ + public static > List sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) { int bound = shardsIts.size(); List toSort = new ArrayList<>(bound); for (int i = 0; i < bound; i++) { @@ -479,7 +491,7 @@ private static List sortShards(List sh } return shardsIts.get(idx1).compareTo(shardsIts.get(idx2)); }); - List list = new ArrayList<>(bound); + List list = new ArrayList<>(bound); for (Integer integer : toSort) { list.add(shardsIts.get(integer)); } diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 39e1c30f658d8..274935ccfecdc 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -46,6 +46,8 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.query.QuerySearchResult; +import org.elasticsearch.search.sort.FieldSortBuilder; +import org.elasticsearch.search.sort.MinAndMax; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; @@ -78,6 +80,7 @@ import java.util.function.BiFunction; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; +import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { @@ -345,6 +348,7 @@ public IndicesOptions indicesOptions() { private record ShardToQuery(float boost, String[] originalIndices, int shardIndex, ShardId shardId, ShardSearchContextId contextId) implements + Comparable, Writeable { static ShardToQuery readFrom(StreamInput in) throws IOException { @@ -365,6 +369,11 @@ public void writeTo(StreamOutput out) 
throws IOException { shardId.writeTo(out); out.writeOptionalWriteable(contextId); } + + @Override + public int compareTo(ShardToQuery o) { + return shardId.compareTo(o.shardId); + } } /** @@ -462,8 +471,7 @@ protected void doRun(Map shardIndexMap) { return; } // must check both node and transport versions to correctly deal with BwC on proxy connections - if (connection.getTransportVersion().before(TransportVersions.BATCHED_QUERY_PHASE_VERSION) - || connection.getNode().getVersionInformation().nodeVersion().before(Version.V_9_1_0)) { + if (connectionSupportsBatchedExecution(connection) == false) { executeWithoutBatching(routing, request); return; } @@ -525,6 +533,11 @@ public void handleException(TransportException e) { }); } + public static boolean connectionSupportsBatchedExecution(Transport.Connection connection) { + return connection.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) + || connection.getNode().getVersionInformation().nodeVersion().onOrAfter(Version.V_9_1_0); + } + private void executeWithoutBatching(CanMatchPreFilterSearchPhase.SendingTarget targetNode, NodeQueryRequest request) { for (ShardToQuery shard : request.shards) { executeAsSingleRequest(targetNode, shard); @@ -562,11 +575,45 @@ static void registerNodeSearchAction( final int searchPoolMax = threadPool.info(ThreadPool.Names.SEARCH).getMax(); transportService.registerRequestHandler( NODE_SEARCH_ACTION_NAME, - EsExecutors.DIRECT_EXECUTOR_SERVICE, + threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), NodeQueryRequest::new, (request, channel, task) -> { + final SearchRequest searchRequest = request.searchRequest; + final List shards; + if (hasPrimaryFieldSort(searchRequest.source())) { + var pitBuilder = searchRequest.pointInTimeBuilder(); + @SuppressWarnings("rawtypes") + final MinAndMax[] minAndMax = new MinAndMax[request.shards.size()]; + for (int i = 0; i < minAndMax.length; i++) { + var shardToQuery = request.shards.get(i); + var shardId = 
shardToQuery.shardId; + var r = buildShardSearchRequest( + shardId, + request.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, request.indicesOptions()), + request.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + request.totalShards, + request.absoluteStartMillis, + false + ); + var res = searchService.canMatch(r); + minAndMax[i] = res.estimatedMinAndMax(); + } + shards = CanMatchPreFilterSearchPhase.sortShards( + request.shards, + minAndMax, + FieldSortBuilder.getPrimaryFieldSortOrNull(searchRequest.source()).order() + ); + } else { + shards = request.shards; + } final CancellableTask cancellableTask = (CancellableTask) task; - final int shardCount = request.shards.size(); + final int shardCount = shards.size(); int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); final var state = new QueryPerNodeState( new QueryPhaseResultConsumer( @@ -580,6 +627,7 @@ static void registerNodeSearchAction( e -> logger.error("failed to merge on data node", e) ), request, + shards, cancellableTask, channel, dependencies @@ -593,12 +641,12 @@ static void registerNodeSearchAction( TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new); } - private static void releaseLocalContext(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result) { + private static void releaseLocalContext(SearchService searchService, SearchRequest searchRequest, SearchPhaseResult result) { var phaseResult = result.queryResult() != null ? 
result.queryResult() : result.rankFeatureResult(); if (phaseResult != null && phaseResult.hasSearchContext() - && request.searchRequest.scroll() == null - && isPartOfPIT(request.searchRequest, phaseResult.getContextId()) == false) { + && searchRequest.scroll() == null + && isPartOfPIT(searchRequest, phaseResult.getContextId()) == false) { searchService.freeReaderContext(phaseResult.getContextId()); } } @@ -646,15 +694,15 @@ private static ShardSearchRequest buildShardSearchRequest( private static void executeShardTasks(QueryPerNodeState state) { int idx; - final int totalShardCount = state.searchRequest.shards.size(); + final int totalShardCount = state.shardsToQuery.size(); while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) { final int dataNodeLocalIdx = idx; final ListenableFuture doneFuture = new ListenableFuture<>(); try { - final NodeQueryRequest nodeQueryRequest = state.searchRequest; + final NodeQueryRequest nodeQueryRequest = state.nodeQueryRequest; final SearchRequest searchRequest = nodeQueryRequest.searchRequest; var pitBuilder = searchRequest.pointInTimeBuilder(); - var shardToQuery = nodeQueryRequest.shards.get(dataNodeLocalIdx); + var shardToQuery = state.shardsToQuery.get(dataNodeLocalIdx); final var shardId = shardToQuery.shardId; state.dependencies.searchService.executeQueryPhase( tryRewriteWithUpdatedSortValue( @@ -732,7 +780,8 @@ private static final class QueryPerNodeState { private final AtomicInteger currentShardIndex = new AtomicInteger(); private final QueryPhaseResultConsumer queryPhaseResultConsumer; - private final NodeQueryRequest searchRequest; + private final NodeQueryRequest nodeQueryRequest; + private final List shardsToQuery; private final CancellableTask task; private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); private final Dependencies dependencies; @@ -745,15 +794,17 @@ private static final class QueryPerNodeState { private QueryPerNodeState( QueryPhaseResultConsumer 
queryPhaseResultConsumer, - NodeQueryRequest searchRequest, + NodeQueryRequest nodeQueryRequest, + List shardsToQuery, CancellableTask task, TransportChannel channel, Dependencies dependencies ) { this.queryPhaseResultConsumer = queryPhaseResultConsumer; - this.searchRequest = searchRequest; - this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo(); - this.topDocsSize = getTopDocsSize(searchRequest.searchRequest); + this.nodeQueryRequest = nodeQueryRequest; + this.shardsToQuery = shardsToQuery; + this.trackTotalHitsUpTo = nodeQueryRequest.searchRequest.resolveTrackTotalHitsUpTo(); + this.topDocsSize = getTopDocsSize(nodeQueryRequest.searchRequest); this.task = task; this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); this.channel = channel; @@ -786,11 +837,11 @@ void onShardDone() { // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments, // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other // indices without a roundtrip to the coordinating node - final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size()); + final BitSet relevantShardIndices = new BitSet(shardsToQuery.size()); if (mergeResult.reducedTopDocs() != null) { for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) { final int localIndex = scoreDoc.shardIndex; - scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex; + scoreDoc.shardIndex = shardsToQuery.get(localIndex).shardIndex; relevantShardIndices.set(localIndex); } } @@ -800,9 +851,10 @@ void onShardDone() { try { out.writeVInt(resultCount); for (int i = 0; i < resultCount; i++) { - var result = queryPhaseResultConsumer.results.get(i); + int idx = shardsToQuery.indexOf(nodeQueryRequest.shards.get(i)); + var result = queryPhaseResultConsumer.results.get(idx); if (result == null) { - NodeQueryResponse.writePerShardException(out, 
failures.remove(i)); + NodeQueryResponse.writePerShardException(out, failures.remove(idx)); } else { // free context id and remove it from the result right away in case we don't need it anymore maybeFreeContext(result, relevantShardIndices); @@ -829,8 +881,8 @@ private void maybeFreeContext(SearchPhaseResult result, BitSet relevantShardIndi && relevantShardIndices.get(q.getShardIndex()) == false && q.hasSuggestHits() == false && q.getRankShardResult() == null - && searchRequest.searchRequest.scroll() == null - && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { + && nodeQueryRequest.searchRequest.scroll() == null + && isPartOfPIT(nodeQueryRequest.searchRequest, q.getContextId()) == false) { if (dependencies.searchService.freeReaderContext(q.getContextId())) { q.clearContextId(); } @@ -839,7 +891,9 @@ && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { private void handleMergeFailure(Exception e, ChannelActionListener channelListener) { queryPhaseResultConsumer.getSuccessfulResults() - .forEach(searchPhaseResult -> releaseLocalContext(dependencies.searchService, searchRequest, searchPhaseResult)); + .forEach( + searchPhaseResult -> releaseLocalContext(dependencies.searchService, nodeQueryRequest.searchRequest, searchPhaseResult) + ); channelListener.onFailure(e); } @@ -849,7 +903,7 @@ void consumeResult(QuerySearchResult queryResult) { // TODO: dry up the bottom sort collector with the coordinator side logic in the top-level class here if (queryResult.isNull() == false // disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard) - && searchRequest.searchRequest.scroll() == null + && nodeQueryRequest.searchRequest.scroll() == null // top docs are already consumed if the query was cancelled or in error. 
&& queryResult.hasConsumedTopDocs() == false && queryResult.topDocs() != null diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java index ac23731c38b84..54e355ffe5069 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportOpenPointInTimeAction.java @@ -167,13 +167,13 @@ public void runNewSearchPhase( connectionLookup, aliasFilter, concreteIndexBoosts, - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIterators, timeProvider, task, false, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) + false, + searchService ) .addListener( listener.delegateFailureAndWrap( diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 2b13ac7bd2ae0..4231e801b0582 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -1490,13 +1490,13 @@ public void runNewSearchPhase( connectionLookup, aliasFilter, concreteIndexBoosts, - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIterators, timeProvider, task, requireAtLeastOneMatch, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) + false, + searchService ) .addListener( listener.delegateFailureAndWrap( diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java index d12847ec8bf7f..207ce7843526a 100644 --- 
a/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchShardsAction.java @@ -156,17 +156,7 @@ public void searchShards(Task task, SearchShardsRequest searchShardsRequest, Act CanMatchPreFilterSearchPhase.execute(logger, searchTransportService, (clusterAlias, node) -> { assert Objects.equals(clusterAlias, searchShardsRequest.clusterAlias()); return transportService.getConnection(project.cluster().nodes().get(node)); - }, - aliasFilters, - Map.of(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), - searchRequest, - shardIts, - timeProvider, - (SearchTask) task, - false, - searchService.getCoordinatorRewriteContextProvider(timeProvider::absoluteStartMillis) - ) + }, aliasFilters, Map.of(), searchRequest, shardIts, timeProvider, (SearchTask) task, false, false, searchService) .addListener( delegate.map( its -> new SearchShardsResponse(toGroups(its), project.cluster().nodes().getAllNodes(), aliasFilters) diff --git a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 1c3a6cd47a3b7..6e65b76c09734 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -44,6 +44,7 @@ import org.elasticsearch.index.shard.ShardLongFieldRange; import org.elasticsearch.indices.DateFieldRangeInfo; import org.elasticsearch.search.CanMatchShardResponse; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.bucket.terms.SignificantTermsAggregationBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; @@ -81,7 +82,9 @@ import static org.hamcrest.Matchers.greaterThan; import 
static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; public class CanMatchPreFilterSearchPhaseTests extends ESTestCase { @@ -155,13 +158,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mock(SearchService.class) ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -250,13 +253,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mock(SearchService.class) ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -341,13 +344,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, true, - EMPTY_CONTEXT_PROVIDER + false, + mock(SearchService.class) ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -440,13 +443,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), Collections.singletonMap("_na_", AliasFilter.EMPTY), Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardsIter, timeProvider, null, shardsIter.size() > shardToSkip.size(), - EMPTY_CONTEXT_PROVIDER + false, + 
mock(SearchService.class) ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -1405,6 +1408,8 @@ public void sendCanMatch( System::nanoTime ); + var searchService = mock(SearchService.class); + when(searchService.getCoordinatorRewriteContextProvider(any())).thenReturn(contextProvider); return new Tuple<>( CanMatchPreFilterSearchPhase.execute( logger, @@ -1412,13 +1417,13 @@ public void sendCanMatch( (clusterAlias, node) -> lookup.get(node), aliasFilters, Collections.emptyMap(), - threadPool.executor(ThreadPool.Names.SEARCH_COORDINATION), searchRequest, shardIters, timeProvider, null, true, - contextProvider + false, + searchService ), requests ); From a10684731c75eb71d2fa9b03292f0ef2a9d7f716 Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 11:55:02 +0200 Subject: [PATCH 02/12] cleanup --- .../CanMatchPreFilterSearchPhaseTests.java | 47 +++++++++++++++++-- 1 file changed, 42 insertions(+), 5 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java index 6e65b76c09734..f8651c2b76d4c 100644 --- a/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhaseTests.java @@ -56,8 +56,8 @@ import org.elasticsearch.search.suggest.SuggestBuilder; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.TestThreadPool; -import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transport; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xcontent.XContentParserConfiguration; import java.util.ArrayList; @@ -137,6 +137,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService 
transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -164,7 +169,7 @@ public void sendCanMatch( null, true, false, - mock(SearchService.class) + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -188,6 +193,12 @@ public void sendCanMatch( } } + private SearchService mockSearchService() { + var searchService = mock(SearchService.class); + when(searchService.getCoordinatorRewriteContextProvider(any())).thenReturn(EMPTY_CONTEXT_PROVIDER); + return searchService; + } + public void testFilterWithFailure() throws InterruptedException { final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( 0, @@ -231,6 +242,11 @@ public void sendCanMatch( } }).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -259,7 +275,7 @@ public void sendCanMatch( null, true, false, - mock(SearchService.class) + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -322,6 +338,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); @@ -350,7 +371,7 @@ public void sendCanMatch( null, true, false, - mock(SearchService.class) + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -421,6 +442,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; AtomicReference> result = new AtomicReference<>(); 
@@ -449,7 +475,7 @@ public void sendCanMatch( null, shardsIter.size() > shardToSkip.size(), false, - mock(SearchService.class) + mockSearchService() ).addListener(ActionTestUtils.assertNoFailureListener(iter -> { result.set(iter); latch.countDown(); @@ -1400,6 +1426,11 @@ public void sendCanMatch( new Thread(() -> listener.onResponse(new CanMatchNodeResponse(responses))).start(); } + + @Override + public TransportService transportService() { + return mockTransportService(); + } }; final TransportSearchAction.SearchTimeProvider timeProvider = new TransportSearchAction.SearchTimeProvider( @@ -1429,6 +1460,12 @@ public void sendCanMatch( ); } + private TransportService mockTransportService() { + var transportService = mock(TransportService.class); + when(transportService.getThreadPool()).thenReturn(threadPool); + return transportService; + } + static class StaticCoordinatorRewriteContextProviderBuilder { private ClusterState clusterState = ClusterState.EMPTY_STATE; private final Map fields = new HashMap<>(); From 163cf13e593ca1640c2bc097abca01ff4d2b60bb Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 13:03:18 +0200 Subject: [PATCH 03/12] works better --- .../search/CanMatchPreFilterSearchPhase.java | 17 ++-- .../SearchQueryThenFetchAsyncAction.java | 85 +++++++++++-------- 2 files changed, 60 insertions(+), 42 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index caca901fbb0b5..2a77f28ef14a2 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -474,10 +474,15 @@ private synchronized List getIterator(List list = new ArrayList<>(indexTranslation.length); + for (int in : indexTranslation) { + list.add(shardsIts.get(in)); + } + return list; } - public static > List 
sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) { + public static > int[] sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) { int bound = shardsIts.size(); List toSort = new ArrayList<>(bound); for (int i = 0; i < bound; i++) { @@ -491,11 +496,11 @@ public static > List sortShards(List shardsIts, Mi } return shardsIts.get(idx1).compareTo(shardsIts.get(idx2)); }); - List list = new ArrayList<>(bound); - for (Integer integer : toSort) { - list.add(shardsIts.get(integer)); + int[] result = new int[bound]; + for (int i = 0; i < bound; i++) { + result[i] = toSort.get(i); } - return list; + return result; } private static boolean shouldSortShards(MinAndMax[] minAndMaxes) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 274935ccfecdc..27674deee623e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -32,6 +32,7 @@ import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ListenableFuture; +import org.elasticsearch.core.Nullable; import org.elasticsearch.core.RefCounted; import org.elasticsearch.core.SimpleRefCounted; import org.elasticsearch.core.TimeValue; @@ -78,6 +79,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BiFunction; +import java.util.function.IntUnaryOperator; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; @@ -579,8 +581,10 @@ static void registerNodeSearchAction( NodeQueryRequest::new, (request, channel, task) -> { 
final SearchRequest searchRequest = request.searchRequest; - final List shards; + final IntUnaryOperator shards; + final ShardSearchRequest[] shardSearchRequests; if (hasPrimaryFieldSort(searchRequest.source())) { + shardSearchRequests = new ShardSearchRequest[request.shards.size()]; var pitBuilder = searchRequest.pointInTimeBuilder(); @SuppressWarnings("rawtypes") final MinAndMax[] minAndMax = new MinAndMax[request.shards.size()]; @@ -601,19 +605,21 @@ static void registerNodeSearchAction( request.absoluteStartMillis, false ); - var res = searchService.canMatch(r); - minAndMax[i] = res.estimatedMinAndMax(); + shardSearchRequests[i] = r; + minAndMax[i] = searchService.canMatch(r).estimatedMinAndMax(); } - shards = CanMatchPreFilterSearchPhase.sortShards( + int[] indexes = CanMatchPreFilterSearchPhase.sortShards( request.shards, minAndMax, FieldSortBuilder.getPrimaryFieldSortOrNull(searchRequest.source()).order() ); + shards = pos -> indexes[pos]; } else { - shards = request.shards; + shardSearchRequests = null; + shards = IntUnaryOperator.identity(); } final CancellableTask cancellableTask = (CancellableTask) task; - final int shardCount = shards.size(); + final int shardCount = request.shards.size(); int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); final var state = new QueryPerNodeState( new QueryPhaseResultConsumer( @@ -630,7 +636,8 @@ static void registerNodeSearchAction( shards, cancellableTask, channel, - dependencies + dependencies, + shardSearchRequests ); // TODO: log activating or otherwise limiting parallelism might be helpful here for (int i = 0; i < workers; i++) { @@ -694,35 +701,39 @@ private static ShardSearchRequest buildShardSearchRequest( private static void executeShardTasks(QueryPerNodeState state) { int idx; - final int totalShardCount = state.shardsToQuery.size(); + final NodeQueryRequest nodeQueryRequest = state.nodeQueryRequest; + var shards = nodeQueryRequest.shards; + 
final int totalShardCount = shards.size(); while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) { final int dataNodeLocalIdx = idx; final ListenableFuture doneFuture = new ListenableFuture<>(); try { - final NodeQueryRequest nodeQueryRequest = state.nodeQueryRequest; final SearchRequest searchRequest = nodeQueryRequest.searchRequest; var pitBuilder = searchRequest.pointInTimeBuilder(); - var shardToQuery = state.shardsToQuery.get(dataNodeLocalIdx); + int translatedIndex = state.shardsToQuery.applyAsInt(dataNodeLocalIdx); + var shardToQuery = shards.get(translatedIndex); final var shardId = shardToQuery.shardId; + ShardSearchRequest r = state.shardSearchRequests == null ? null : state.shardSearchRequests[translatedIndex]; + if (r == null) { + r = buildShardSearchRequest( + shardId, + nodeQueryRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()), + nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + nodeQueryRequest.totalShards, + nodeQueryRequest.absoluteStartMillis, + state.hasResponse.getAcquire() + ); + } else { + state.shardSearchRequests[translatedIndex] = null; + } state.dependencies.searchService.executeQueryPhase( - tryRewriteWithUpdatedSortValue( - state.bottomSortCollector, - state.trackTotalHitsUpTo, - buildShardSearchRequest( - shardId, - nodeQueryRequest.localClusterAlias, - shardToQuery.shardIndex, - shardToQuery.contextId, - new OriginalIndices(shardToQuery.originalIndices, nodeQueryRequest.indicesOptions()), - nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), - pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), - shardToQuery.boost, - searchRequest, - nodeQueryRequest.totalShards, - nodeQueryRequest.absoluteStartMillis, - state.hasResponse.getAcquire() - ) - ), + tryRewriteWithUpdatedSortValue(state.bottomSortCollector, state.trackTotalHitsUpTo, r), state.task, new SearchActionListener<>( new SearchShardTarget(null, shardToQuery.shardId, nodeQueryRequest.localClusterAlias), @@ -781,7 +792,7 @@ private static final class QueryPerNodeState { private final AtomicInteger currentShardIndex = new AtomicInteger(); private final QueryPhaseResultConsumer queryPhaseResultConsumer; private final NodeQueryRequest nodeQueryRequest; - private final List shardsToQuery; + private final IntUnaryOperator shardsToQuery; private final CancellableTask task; private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); private final Dependencies dependencies; @@ -789,16 +800,18 @@ private static final class QueryPerNodeState { private final int trackTotalHitsUpTo; private final int topDocsSize; private final CountDown countDown; + private final @Nullable ShardSearchRequest[] shardSearchRequests; private final TransportChannel channel; private volatile BottomSortValuesCollector bottomSortCollector; private QueryPerNodeState( QueryPhaseResultConsumer queryPhaseResultConsumer, NodeQueryRequest nodeQueryRequest, - List shardsToQuery, + IntUnaryOperator shardsToQuery, CancellableTask task, TransportChannel channel, - Dependencies dependencies + Dependencies dependencies, + @Nullable ShardSearchRequest[] shardSearchRequests ) { this.queryPhaseResultConsumer = queryPhaseResultConsumer; this.nodeQueryRequest = nodeQueryRequest; @@ -809,6 +822,7 @@ private QueryPerNodeState( this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); this.channel = channel; this.dependencies = dependencies; + this.shardSearchRequests = shardSearchRequests; } void onShardDone() { @@ -837,11 +851,11 @@ void onShardDone() { // translate shard indices to those on the 
coordinator so that it can interpret the merge result without adjustments, // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other // indices without a roundtrip to the coordinating node - final BitSet relevantShardIndices = new BitSet(shardsToQuery.size()); + final BitSet relevantShardIndices = new BitSet(nodeQueryRequest.shards.size()); if (mergeResult.reducedTopDocs() != null) { for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) { final int localIndex = scoreDoc.shardIndex; - scoreDoc.shardIndex = shardsToQuery.get(localIndex).shardIndex; + scoreDoc.shardIndex = nodeQueryRequest.shards.get(localIndex).shardIndex; relevantShardIndices.set(localIndex); } } @@ -851,10 +865,9 @@ void onShardDone() { try { out.writeVInt(resultCount); for (int i = 0; i < resultCount; i++) { - int idx = shardsToQuery.indexOf(nodeQueryRequest.shards.get(i)); - var result = queryPhaseResultConsumer.results.get(idx); + var result = queryPhaseResultConsumer.results.get(i); if (result == null) { - NodeQueryResponse.writePerShardException(out, failures.remove(idx)); + NodeQueryResponse.writePerShardException(out, failures.remove(i)); } else { // free context id and remove it from the result right away in case we don't need it anymore maybeFreeContext(result, relevantShardIndices); From 2da2b0cf3b9752b2c516cd234442cd3c8d06b44c Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 16:45:29 +0200 Subject: [PATCH 04/12] WIP: Skip can_match phase on nodes that support batched query execution There's no reason whatsoever to run can_match (except the coordinator rewrite part of it) when batched query execution is used. On a per-node level we can still run it to order the shards but should probably remove its use from the query phase completely if there's no sort in the query. 
--- .../SearchQueryThenFetchAsyncAction.java | 116 ++++++++++++++---- .../action/search/TransportSearchAction.java | 2 +- .../SearchQueryThenFetchAsyncActionTests.java | 3 +- 3 files changed, 92 insertions(+), 29 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 27674deee623e..44412c0a49395 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -29,6 +29,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.common.util.concurrent.ListenableFuture; @@ -82,6 +83,7 @@ import java.util.function.IntUnaryOperator; import static org.elasticsearch.action.search.SearchPhaseController.getTopDocsSize; +import static org.elasticsearch.search.sort.FieldSortBuilder.NAME; import static org.elasticsearch.search.sort.FieldSortBuilder.hasPrimaryFieldSort; public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction { @@ -96,6 +98,7 @@ public class SearchQueryThenFetchAsyncAction extends AbstractSearchAsyncAction shardIndexMap) { } AbstractSearchAsyncAction.doCheckNoMissingShards(getName(), request, shardsIts); final Map perNodeQueries = new HashMap<>(); - final String localNodeId = searchTransportService.transportService().getLocalNode().getId(); + final var transportService = searchTransportService.transportService(); + final String localNodeId = transportService.getLocalNode().getId(); final int numberOfShardsTotal = shardsIts.size(); 
for (int i = 0; i < numberOfShardsTotal; i++) { final SearchShardIterator shardRoutings = shardsIts.get(i); @@ -436,30 +441,82 @@ protected void doRun(Map shardIndexMap) { } else { final String nodeId = routing.getNodeId(); // local requests don't need batching as there's no network latency - if (localNodeId.equals(nodeId)) { - performPhaseOnShard(shardIndex, shardRoutings, routing); - } else { - var perNodeRequest = perNodeQueries.computeIfAbsent( - new CanMatchPreFilterSearchPhase.SendingTarget(routing.getClusterAlias(), nodeId), - t -> new NodeQueryRequest(request, numberOfShardsTotal, timeProvider.absoluteStartMillis(), t.clusterAlias()) - ); - final String indexUUID = routing.getShardId().getIndex().getUUID(); - perNodeRequest.shards.add( - new ShardToQuery( - concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), - getOriginalIndices(shardIndex).indices(), - shardIndex, - routing.getShardId(), - shardRoutings.getSearchContextId() - ) - ); - var filterForAlias = aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY); - if (filterForAlias != AliasFilter.EMPTY) { - perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias); - } + var perNodeRequest = perNodeQueries.computeIfAbsent( + new CanMatchPreFilterSearchPhase.SendingTarget(routing.getClusterAlias(), nodeId), + t -> new NodeQueryRequest(request, numberOfShardsTotal, timeProvider.absoluteStartMillis(), t.clusterAlias()) + ); + final String indexUUID = routing.getShardId().getIndex().getUUID(); + perNodeRequest.shards.add( + new ShardToQuery( + concreteIndexBoosts.getOrDefault(indexUUID, DEFAULT_INDEX_BOOST), + getOriginalIndices(shardIndex).indices(), + shardIndex, + routing.getShardId(), + shardRoutings.getSearchContextId() + ) + ); + var filterForAlias = aliasFilter.getOrDefault(indexUUID, AliasFilter.EMPTY); + if (filterForAlias != AliasFilter.EMPTY) { + perNodeRequest.aliasFilters.putIfAbsent(indexUUID, filterForAlias); } } } + final var localTarget = new 
CanMatchPreFilterSearchPhase.SendingTarget(request.getLocalClusterAlias(), localNodeId); + var localNodeRequest = perNodeQueries.remove(localTarget); + if (localNodeRequest != null) { + transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION).execute(new AbstractRunnable() { + @Override + protected void doRun() { + if (hasPrimaryFieldSort(request.source())) { + var pitBuilder = request.pointInTimeBuilder(); + @SuppressWarnings("rawtypes") + final MinAndMax[] minAndMax = new MinAndMax[localNodeRequest.shards.size()]; + for (int i = 0; i < minAndMax.length; i++) { + var shardToQuery = localNodeRequest.shards.get(i); + var shardId = shardToQuery.shardId; + var r = buildShardSearchRequest( + shardId, + localNodeRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, request.indicesOptions()), + localNodeRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + request, + localNodeRequest.totalShards, + localNodeRequest.absoluteStartMillis, + false + ); + minAndMax[i] = searchService.canMatch(r).estimatedMinAndMax(); + } + try { + int[] indexes = CanMatchPreFilterSearchPhase.sortShards( + localNodeRequest.shards, + minAndMax, + FieldSortBuilder.getPrimaryFieldSortOrNull(request.source()).order() + ); + for (int i = 0; i < indexes.length; i++) { + ShardToQuery shardToQuery = localNodeRequest.shards.get(i); + shardToQuery = localNodeRequest.shards.set(i, shardToQuery); + localNodeRequest.shards.set(i, shardToQuery); + } + } catch (Exception e) { + // ignored, field type conflicts will be dealt with in upstream logic + // TODO: we should fail the query here, we're already seeing a field type conflict on the sort field, + // no need to actually execute the queries and go through a lot of work before we inevitably have to + // fail the search + } + } + executeWithoutBatching(localTarget, localNodeRequest); + } + + @Override + public void onFailure(Exception e) { + SearchQueryThenFetchAsyncAction.this.onPhaseFailure(NAME, "", e); + } + }); + } perNodeQueries.forEach((routing, request) -> { if (request.shards.size() == 1) { executeAsSingleRequest(routing, request.shards.getFirst()); @@ -477,8 +534,12 @@ protected void doRun(Map shardIndexMap) { executeWithoutBatching(routing, request); return; } - searchTransportService.transportService() - .sendChildRequest(connection, NODE_SEARCH_ACTION_NAME, request, task, new TransportResponseHandler() { + transportService.sendChildRequest( + connection, + NODE_SEARCH_ACTION_NAME, + request, + task, + new TransportResponseHandler() { @Override public NodeQueryResponse read(StreamInput in) throws IOException { return new NodeQueryResponse(in); @@ -531,7 +592,8 @@ public void handleException(TransportException e) { onPhaseFailure(getName(), "", cause); } } - }); + } + ); }); } diff --git 
a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 4231e801b0582..8a8a07f4c2b8d 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -1576,7 +1576,7 @@ public void runNewSearchPhase( task, clusters, client, - searchService.batchQueryPhase() + searchService ); } success = true; diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java index d7348833c757a..a239680a8e6c4 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncActionTests.java @@ -28,6 +28,7 @@ import org.elasticsearch.lucene.grouping.TopFieldGroups; import org.elasticsearch.search.DocValueFormat; import org.elasticsearch.search.SearchPhaseResult; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilder; @@ -207,7 +208,7 @@ public void sendExecuteQuery( task, SearchResponse.Clusters.EMPTY, null, - false + mock(SearchService.class) ) { @Override protected SearchPhase getNextPhase() { From 2ab1d8cc881ada9739ddccc66d4d561766bfc0a7 Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 17:18:15 +0200 Subject: [PATCH 05/12] drier --- .../SearchQueryThenFetchAsyncAction.java | 59 ++++++++----------- 1 file changed, 24 insertions(+), 35 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java 
b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index 44412c0a49395..c56f2bb1da415 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -468,27 +468,12 @@ protected void doRun(Map shardIndexMap) { @Override protected void doRun() { if (hasPrimaryFieldSort(request.source())) { - var pitBuilder = request.pointInTimeBuilder(); @SuppressWarnings("rawtypes") final MinAndMax[] minAndMax = new MinAndMax[localNodeRequest.shards.size()]; for (int i = 0; i < minAndMax.length; i++) { - var shardToQuery = localNodeRequest.shards.get(i); - var shardId = shardToQuery.shardId; - var r = buildShardSearchRequest( - shardId, - localNodeRequest.localClusterAlias, - shardToQuery.shardIndex, - shardToQuery.contextId, - new OriginalIndices(shardToQuery.originalIndices, request.indicesOptions()), - localNodeRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), - pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), - shardToQuery.boost, - request, - localNodeRequest.totalShards, - localNodeRequest.absoluteStartMillis, - false - ); - minAndMax[i] = searchService.canMatch(r).estimatedMinAndMax(); + minAndMax[i] = searchService.canMatch( + buildShardSearchRequestForLocal(localNodeRequest, localNodeRequest.shards.get(i)) + ).estimatedMinAndMax(); } try { int[] indexes = CanMatchPreFilterSearchPhase.sortShards( @@ -597,6 +582,26 @@ public void handleException(TransportException e) { }); } + private static ShardSearchRequest buildShardSearchRequestForLocal(NodeQueryRequest nodeQueryRequest, ShardToQuery shardToQuery) { + var shardId = shardToQuery.shardId; + var searchRequest = nodeQueryRequest.searchRequest; + var pitBuilder = searchRequest.pointInTimeBuilder(); + return buildShardSearchRequest( + shardId, + nodeQueryRequest.localClusterAlias, + shardToQuery.shardIndex, + shardToQuery.contextId, + new OriginalIndices(shardToQuery.originalIndices, searchRequest.indicesOptions()), + nodeQueryRequest.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), + pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), + shardToQuery.boost, + searchRequest, + nodeQueryRequest.totalShards, + nodeQueryRequest.absoluteStartMillis, + false + ); + } + public static boolean connectionSupportsBatchedExecution(Transport.Connection connection) { return connection.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) || connection.getNode().getVersionInformation().nodeVersion().onOrAfter(Version.V_9_1_0); @@ -647,26 +652,10 @@ static void registerNodeSearchAction( final ShardSearchRequest[] shardSearchRequests; if (hasPrimaryFieldSort(searchRequest.source())) { shardSearchRequests = new ShardSearchRequest[request.shards.size()]; - var pitBuilder = searchRequest.pointInTimeBuilder(); @SuppressWarnings("rawtypes") final MinAndMax[] minAndMax = new MinAndMax[request.shards.size()]; for (int i = 0; i < minAndMax.length; i++) { - var shardToQuery = request.shards.get(i); - var shardId = shardToQuery.shardId; - var r = buildShardSearchRequest( - shardId, - request.localClusterAlias, - shardToQuery.shardIndex, - shardToQuery.contextId, - new OriginalIndices(shardToQuery.originalIndices, request.indicesOptions()), - request.aliasFilters.getOrDefault(shardId.getIndex().getUUID(), AliasFilter.EMPTY), - pitBuilder == null ? 
null : pitBuilder.getKeepAlive(), - shardToQuery.boost, - searchRequest, - request.totalShards, - request.absoluteStartMillis, - false - ); + ShardSearchRequest r = buildShardSearchRequestForLocal(request, request.shards.get(i)); shardSearchRequests[i] = r; minAndMax[i] = searchService.canMatch(r).estimatedMinAndMax(); } From 7f9b9c568a6deb65c98d87d9d2def7ccdb396abd Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 18:10:01 +0200 Subject: [PATCH 06/12] -noise --- .../SearchQueryThenFetchAsyncAction.java | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index c56f2bb1da415..f78cf4d513626 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -699,12 +699,12 @@ static void registerNodeSearchAction( TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new); } - private static void releaseLocalContext(SearchService searchService, SearchRequest searchRequest, SearchPhaseResult result) { + private static void releaseLocalContext(SearchService searchService, NodeQueryRequest request, SearchPhaseResult result) { var phaseResult = result.queryResult() != null ? 
result.queryResult() : result.rankFeatureResult(); if (phaseResult != null && phaseResult.hasSearchContext() - && searchRequest.scroll() == null - && isPartOfPIT(searchRequest, phaseResult.getContextId()) == false) { + && request.searchRequest.scroll() == null + && isPartOfPIT(request.searchRequest, phaseResult.getContextId()) == false) { searchService.freeReaderContext(phaseResult.getContextId()); } } @@ -752,7 +752,7 @@ private static ShardSearchRequest buildShardSearchRequest( private static void executeShardTasks(QueryPerNodeState state) { int idx; - final NodeQueryRequest nodeQueryRequest = state.nodeQueryRequest; + final NodeQueryRequest nodeQueryRequest = state.searchRequest; var shards = nodeQueryRequest.shards; final int totalShardCount = shards.size(); while ((idx = state.currentShardIndex.getAndIncrement()) < totalShardCount) { @@ -842,7 +842,7 @@ private static final class QueryPerNodeState { private final AtomicInteger currentShardIndex = new AtomicInteger(); private final QueryPhaseResultConsumer queryPhaseResultConsumer; - private final NodeQueryRequest nodeQueryRequest; + private final NodeQueryRequest searchRequest; private final IntUnaryOperator shardsToQuery; private final CancellableTask task; private final ConcurrentHashMap failures = new ConcurrentHashMap<>(); @@ -857,7 +857,7 @@ private static final class QueryPerNodeState { private QueryPerNodeState( QueryPhaseResultConsumer queryPhaseResultConsumer, - NodeQueryRequest nodeQueryRequest, + NodeQueryRequest searchRequest, IntUnaryOperator shardsToQuery, CancellableTask task, TransportChannel channel, @@ -865,10 +865,10 @@ private QueryPerNodeState( @Nullable ShardSearchRequest[] shardSearchRequests ) { this.queryPhaseResultConsumer = queryPhaseResultConsumer; - this.nodeQueryRequest = nodeQueryRequest; + this.searchRequest = searchRequest; this.shardsToQuery = shardsToQuery; - this.trackTotalHitsUpTo = nodeQueryRequest.searchRequest.resolveTrackTotalHitsUpTo(); - this.topDocsSize = 
getTopDocsSize(nodeQueryRequest.searchRequest); + this.trackTotalHitsUpTo = searchRequest.searchRequest.resolveTrackTotalHitsUpTo(); + this.topDocsSize = getTopDocsSize(searchRequest.searchRequest); this.task = task; this.countDown = new CountDown(queryPhaseResultConsumer.getNumShards()); this.channel = channel; @@ -902,11 +902,11 @@ void onShardDone() { // translate shard indices to those on the coordinator so that it can interpret the merge result without adjustments, // also collect the set of indices that may be part of a subsequent fetch operation here so that we can release all other // indices without a roundtrip to the coordinating node - final BitSet relevantShardIndices = new BitSet(nodeQueryRequest.shards.size()); + final BitSet relevantShardIndices = new BitSet(searchRequest.shards.size()); if (mergeResult.reducedTopDocs() != null) { for (ScoreDoc scoreDoc : mergeResult.reducedTopDocs().scoreDocs) { final int localIndex = scoreDoc.shardIndex; - scoreDoc.shardIndex = nodeQueryRequest.shards.get(localIndex).shardIndex; + scoreDoc.shardIndex = searchRequest.shards.get(localIndex).shardIndex; relevantShardIndices.set(localIndex); } } @@ -945,8 +945,8 @@ private void maybeFreeContext(SearchPhaseResult result, BitSet relevantShardIndi && relevantShardIndices.get(q.getShardIndex()) == false && q.hasSuggestHits() == false && q.getRankShardResult() == null - && nodeQueryRequest.searchRequest.scroll() == null - && isPartOfPIT(nodeQueryRequest.searchRequest, q.getContextId()) == false) { + && searchRequest.searchRequest.scroll() == null + && isPartOfPIT(searchRequest.searchRequest, q.getContextId()) == false) { if (dependencies.searchService.freeReaderContext(q.getContextId())) { q.clearContextId(); } @@ -955,9 +955,7 @@ && isPartOfPIT(nodeQueryRequest.searchRequest, q.getContextId()) == false) { private void handleMergeFailure(Exception e, ChannelActionListener channelListener) { queryPhaseResultConsumer.getSuccessfulResults() - .forEach( - searchPhaseResult -> 
releaseLocalContext(dependencies.searchService, nodeQueryRequest.searchRequest, searchPhaseResult) - ); + .forEach(searchPhaseResult -> releaseLocalContext(dependencies.searchService, searchRequest, searchPhaseResult)); channelListener.onFailure(e); } @@ -967,7 +965,7 @@ void consumeResult(QuerySearchResult queryResult) { // TODO: dry up the bottom sort collector with the coordinator side logic in the top-level class here if (queryResult.isNull() == false // disable sort optims for scroll requests because they keep track of the last bottom doc locally (per shard) - && nodeQueryRequest.searchRequest.scroll() == null + && searchRequest.searchRequest.scroll() == null // top docs are already consumed if the query was cancelled or in error. && queryResult.hasConsumedTopDocs() == false && queryResult.topDocs() != null From a7bdf1f247b95d10a2d348a4c45764aacbb1d81b Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 19:09:32 +0200 Subject: [PATCH 07/12] -noise --- .../SearchQueryThenFetchAsyncAction.java | 44 ++++++++++++------- .../elasticsearch/search/SearchService.java | 2 +- .../search/internal/ShardSearchRequest.java | 10 +++++ 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index f78cf4d513626..dc551d44f9616 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -401,11 +401,13 @@ private static ShardSearchRequest tryRewriteWithUpdatedSortValue( // disable tracking total hits if we already reached the required estimation. 
if (trackTotalHitsUpTo != SearchContext.TRACK_TOTAL_HITS_ACCURATE && bottomSortCollector.getTotalHits() > trackTotalHitsUpTo) { request.source(request.source().shallowCopy().trackTotalHits(false)); + request.setRunCanMatchInQueryPhase(true); } // set the current best bottom field doc if (bottomSortCollector.getBottomSortValues() != null) { request.setBottomSortValues(bottomSortCollector.getBottomSortValues()); + request.setRunCanMatchInQueryPhase(true); } return request; } @@ -467,30 +469,32 @@ protected void doRun(Map shardIndexMap) { transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION).execute(new AbstractRunnable() { @Override protected void doRun() { - if (hasPrimaryFieldSort(request.source())) { + var shards = localNodeRequest.shards; + if (shards.size() > 1 && hasPrimaryFieldSort(request.source())) { @SuppressWarnings("rawtypes") - final MinAndMax[] minAndMax = new MinAndMax[localNodeRequest.shards.size()]; + final MinAndMax[] minAndMax = new MinAndMax[shards.size()]; for (int i = 0; i < minAndMax.length; i++) { - minAndMax[i] = searchService.canMatch( - buildShardSearchRequestForLocal(localNodeRequest, localNodeRequest.shards.get(i)) - ).estimatedMinAndMax(); + // TODO: refactor to avoid building the search request twice, here and then when actually executing the query + minAndMax[i] = searchService.canMatch(buildShardSearchRequestForLocal(localNodeRequest, shards.get(i))) + .estimatedMinAndMax(); } + try { - int[] indexes = CanMatchPreFilterSearchPhase.sortShards( - localNodeRequest.shards, + final int[] indexes = CanMatchPreFilterSearchPhase.sortShards( + shards, minAndMax, FieldSortBuilder.getPrimaryFieldSortOrNull(request.source()).order() ); + final ShardToQuery[] orig = shards.toArray(new ShardToQuery[0]); for (int i = 0; i < indexes.length; i++) { - ShardToQuery shardToQuery = localNodeRequest.shards.get(i); - shardToQuery = localNodeRequest.shards.set(i, shardToQuery); - localNodeRequest.shards.set(i, shardToQuery); + 
shards.set(i, orig[indexes[i]]); } } catch (Exception e) { // ignored, field type conflicts will be dealt with in upstream logic // TODO: we should fail the query here, we're already seeing a field type conflict on the sort field, // no need to actually execute the queries and go through a lot of work before we inevitably have to // fail the search + } } executeWithoutBatching(localTarget, localNodeRequest); @@ -650,14 +654,21 @@ static void registerNodeSearchAction( final SearchRequest searchRequest = request.searchRequest; final IntUnaryOperator shards; final ShardSearchRequest[] shardSearchRequests; - if (hasPrimaryFieldSort(searchRequest.source())) { - shardSearchRequests = new ShardSearchRequest[request.shards.size()]; + final int shardCount = request.shards.size(); + if (shardCount > 1 && hasPrimaryFieldSort(searchRequest.source())) { + shardSearchRequests = new ShardSearchRequest[shardCount]; @SuppressWarnings("rawtypes") - final MinAndMax[] minAndMax = new MinAndMax[request.shards.size()]; + final MinAndMax[] minAndMax = new MinAndMax[shardCount]; for (int i = 0; i < minAndMax.length; i++) { ShardSearchRequest r = buildShardSearchRequestForLocal(request, request.shards.get(i)); shardSearchRequests[i] = r; - minAndMax[i] = searchService.canMatch(r).estimatedMinAndMax(); + var canMatch = searchService.canMatch(r); + if (canMatch.canMatch()) { + r.setRunCanMatchInQueryPhase(false); + minAndMax[i] = canMatch.estimatedMinAndMax(); + } else { + assert false; + } } int[] indexes = CanMatchPreFilterSearchPhase.sortShards( request.shards, @@ -670,11 +681,10 @@ static void registerNodeSearchAction( shards = IntUnaryOperator.identity(); } final CancellableTask cancellableTask = (CancellableTask) task; - final int shardCount = request.shards.size(); - int workers = Math.min(request.searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); + int workers = Math.min(searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, 
searchPoolMax)); final var state = new QueryPerNodeState( new QueryPhaseResultConsumer( - request.searchRequest, + searchRequest, dependencies.executor, searchService.getCircuitBreaker(), searchPhaseController, diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index 9b1f0b3f2dd0b..07c237f106b5a 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -685,7 +685,7 @@ public void executeQueryPhase(ShardSearchRequest request, CancellableTask task, threadPool ).delegateFailure((l, orig) -> { // check if we can shortcut the query phase entirely. - if (orig.canReturnNullResponseIfMatchNoDocs()) { + if (orig.canReturnNullResponseIfMatchNoDocs() && orig.runCanMatchInQueryPhase()) { assert orig.scroll() == null; ShardSearchRequest clone = new ShardSearchRequest(orig); CanMatchContext canMatchContext = new CanMatchContext( diff --git a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java index 10d2fb0e23b3b..9bce3d1163d22 100644 --- a/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java +++ b/server/src/main/java/org/elasticsearch/search/internal/ShardSearchRequest.java @@ -101,6 +101,8 @@ public class ShardSearchRequest extends AbstractTransportRequest implements Indi */ private final boolean forceSyntheticSource; + private transient boolean runCanMatchInQueryPhase = true; + public ShardSearchRequest( OriginalIndices originalIndices, SearchRequest searchRequest, @@ -349,6 +351,14 @@ public void writeTo(StreamOutput out) throws IOException { OriginalIndices.writeOriginalIndices(originalIndices, out); } + public void setRunCanMatchInQueryPhase(boolean runCanMatchInQueryPhase) { + this.runCanMatchInQueryPhase = runCanMatchInQueryPhase; + } + + public boolean 
runCanMatchInQueryPhase() { + return runCanMatchInQueryPhase; + } + protected final void innerWriteTo(StreamOutput out, boolean asKey) throws IOException { shardId.writeTo(out); out.writeByte(searchType.id()); From 9e840c7a628ec00d0616a1eca556a7bdc0ebf095 Mon Sep 17 00:00:00 2001 From: Armin Date: Mon, 28 Apr 2025 19:13:25 +0200 Subject: [PATCH 08/12] -noise --- .../action/search/SearchQueryThenFetchAsyncAction.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index dc551d44f9616..fcf138ee01622 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -666,8 +666,6 @@ static void registerNodeSearchAction( if (canMatch.canMatch()) { r.setRunCanMatchInQueryPhase(false); minAndMax[i] = canMatch.estimatedMinAndMax(); - } else { - assert false; } } int[] indexes = CanMatchPreFilterSearchPhase.sortShards( From 853b755a8db2c178e08e34eea621a1f947fd0124 Mon Sep 17 00:00:00 2001 From: Armin Date: Tue, 29 Apr 2025 10:58:41 +0200 Subject: [PATCH 09/12] fix test --- .../search/CanMatchPreFilterSearchPhase.java | 10 ++-- .../SearchQueryThenFetchAsyncAction.java | 47 ++++++++----------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index 2a77f28ef14a2..859be4bcc04b8 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -22,11 +22,11 @@ import org.elasticsearch.search.CanMatchShardResponse; import 
org.elasticsearch.search.SearchService; import org.elasticsearch.search.SearchShardTarget; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.ShardSearchRequest; import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.MinAndMax; -import org.elasticsearch.search.sort.SortOrder; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.Transport; @@ -474,7 +474,7 @@ private synchronized List getIterator(List list = new ArrayList<>(indexTranslation.length); for (int in : indexTranslation) { list.add(shardsIts.get(in)); @@ -482,13 +482,15 @@ private synchronized List getIterator(List> int[] sortShards(List shardsIts, MinAndMax[] minAndMaxes, SortOrder order) { + public static > int[] sortShards(List shardsIts, MinAndMax[] minAndMaxes, SearchSourceBuilder source) { int bound = shardsIts.size(); List toSort = new ArrayList<>(bound); for (int i = 0; i < bound; i++) { toSort.add(i); } - Comparator> keyComparator = forciblyCast(MinAndMax.getComparator(order)); + Comparator> keyComparator = forciblyCast( + MinAndMax.getComparator(FieldSortBuilder.getPrimaryFieldSortOrNull(source).order()) + ); toSort.sort((idx1, idx2) -> { int res = keyComparator.compare(minAndMaxes[idx1], minAndMaxes[idx2]); if (res != 0) { diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index fcf138ee01622..f3b201b1d6bb4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -48,7 +48,6 @@ import org.elasticsearch.search.internal.ShardSearchContextId; import org.elasticsearch.search.internal.ShardSearchRequest; import 
org.elasticsearch.search.query.QuerySearchResult; -import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.MinAndMax; import org.elasticsearch.tasks.CancellableTask; import org.elasticsearch.tasks.Task; @@ -480,11 +479,7 @@ protected void doRun() { } try { - final int[] indexes = CanMatchPreFilterSearchPhase.sortShards( - shards, - minAndMax, - FieldSortBuilder.getPrimaryFieldSortOrNull(request.source()).order() - ); + final int[] indexes = CanMatchPreFilterSearchPhase.sortShards(shards, minAndMax, request.source()); final ShardToQuery[] orig = shards.toArray(new ShardToQuery[0]); for (int i = 0; i < indexes.length; i++) { shards.set(i, orig[indexes[i]]); @@ -652,31 +647,29 @@ static void registerNodeSearchAction( NodeQueryRequest::new, (request, channel, task) -> { final SearchRequest searchRequest = request.searchRequest; - final IntUnaryOperator shards; - final ShardSearchRequest[] shardSearchRequests; + ShardSearchRequest[] shardSearchRequests = null; + IntUnaryOperator shards = IntUnaryOperator.identity(); final int shardCount = request.shards.size(); if (shardCount > 1 && hasPrimaryFieldSort(searchRequest.source())) { - shardSearchRequests = new ShardSearchRequest[shardCount]; - @SuppressWarnings("rawtypes") - final MinAndMax[] minAndMax = new MinAndMax[shardCount]; - for (int i = 0; i < minAndMax.length; i++) { - ShardSearchRequest r = buildShardSearchRequestForLocal(request, request.shards.get(i)); - shardSearchRequests[i] = r; - var canMatch = searchService.canMatch(r); - if (canMatch.canMatch()) { - r.setRunCanMatchInQueryPhase(false); - minAndMax[i] = canMatch.estimatedMinAndMax(); + try { + shardSearchRequests = new ShardSearchRequest[shardCount]; + @SuppressWarnings("rawtypes") + final MinAndMax[] minAndMax = new MinAndMax[shardCount]; + for (int i = 0; i < minAndMax.length; i++) { + ShardSearchRequest r = buildShardSearchRequestForLocal(request, request.shards.get(i)); + shardSearchRequests[i] = r; + var canMatch = 
searchService.canMatch(r); + if (canMatch.canMatch()) { + r.setRunCanMatchInQueryPhase(false); + minAndMax[i] = canMatch.estimatedMinAndMax(); + } } + int[] indexes = CanMatchPreFilterSearchPhase.sortShards(request.shards, minAndMax, searchRequest.source()); + shards = pos -> indexes[pos]; + } catch (Exception e) { + // TODO: ignored for now but we'll be guaranteed to fail the query phase at this point, fix things to fail here + // already } - int[] indexes = CanMatchPreFilterSearchPhase.sortShards( - request.shards, - minAndMax, - FieldSortBuilder.getPrimaryFieldSortOrNull(searchRequest.source()).order() - ); - shards = pos -> indexes[pos]; - } else { - shardSearchRequests = null; - shards = IntUnaryOperator.identity(); } final CancellableTask cancellableTask = (CancellableTask) task; int workers = Math.min(searchRequest.getMaxConcurrentShardRequests(), Math.min(shardCount, searchPoolMax)); From 83bbe58d9c40c89305563aaab4a49c5deeecf22e Mon Sep 17 00:00:00 2001 From: Armin Date: Thu, 1 May 2025 09:28:24 +0200 Subject: [PATCH 10/12] fix bwc --- .../action/search/SearchQueryThenFetchAsyncAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java index f3b201b1d6bb4..563bd26cdb0bf 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchQueryThenFetchAsyncAction.java @@ -603,7 +603,7 @@ private static ShardSearchRequest buildShardSearchRequestForLocal(NodeQueryReque public static boolean connectionSupportsBatchedExecution(Transport.Connection connection) { return connection.getTransportVersion().onOrAfter(TransportVersions.BATCHED_QUERY_PHASE_VERSION) - || connection.getNode().getVersionInformation().nodeVersion().onOrAfter(Version.V_9_1_0); + && 
connection.getNode().getVersionInformation().nodeVersion().onOrAfter(Version.V_9_1_0); } private void executeWithoutBatching(CanMatchPreFilterSearchPhase.SendingTarget targetNode, NodeQueryRequest request) { From 756edf57d6ad1b3189ccd14a59d45b7210c04f7e Mon Sep 17 00:00:00 2001 From: Armin Date: Thu, 1 May 2025 10:30:08 +0200 Subject: [PATCH 11/12] less change --- .../search/CanMatchPreFilterSearchPhase.java | 24 ++++++++++++------- .../search/query/QueryPhase.java | 1 + 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java index 859be4bcc04b8..68c8d0a5c5bf9 100644 --- a/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java +++ b/server/src/main/java/org/elasticsearch/action/search/CanMatchPreFilterSearchPhase.java @@ -76,6 +76,9 @@ final class CanMatchPreFilterSearchPhase { private final FixedBitSet possibleMatches; private final MinAndMax[] minAndMaxes; private int numPossibleMatches; + // True if the initiating action to this can_match run is doing batched query phase execution. + // If batched query phase execution is in use, then there is no need to physically send can_match requests to other nodes + // and only the coordinating coordinator can_match logic will run. 
private final boolean batchQueryPhase; private CanMatchPreFilterSearchPhase( @@ -299,7 +302,9 @@ protected void doRun() { if (entry.getKey().nodeId == null) { // no target node: just mark the requests as failed - failAll(shardLevelRequests, null); + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + onOperationFailed(shard.getShardRequestIndex(), null); + } continue; } @@ -307,22 +312,23 @@ protected void doRun() { try { var connection = nodeIdToConnection.apply(sendingTarget.clusterAlias, sendingTarget.nodeId); if (batchQueryPhase && SearchQueryThenFetchAsyncAction.connectionSupportsBatchedExecution(connection)) { - failAll(shardLevelRequests, null); + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + final int idx = shard.getShardRequestIndex(); + CanMatchShardResponse shardResponse = new CanMatchShardResponse(true, null); + shardResponse.setShardIndex(idx); + onOperation(idx, shardResponse); + } } else { bwcSendCanMatchRequest(connection, canMatchNodeRequest, shardLevelRequests); } } catch (Exception e) { - failAll(shardLevelRequests, e); + for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { + onOperationFailed(shard.getShardRequestIndex(), e); + } } } } - private void failAll(List shardLevelRequests, Exception e) { - for (CanMatchNodeRequest.Shard shard : shardLevelRequests) { - onOperationFailed(shard.getShardRequestIndex(), e); - } - } - private void bwcSendCanMatchRequest( Transport.Connection connection, CanMatchNodeRequest canMatchNodeRequest, diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 5fcfb2b9766cd..50b9ec1d33533 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -13,6 +13,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.index.IndexReader; import 
org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.MultiReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector; From 99a094d30e4702dfa6fed2e19e4a7d6d5c5cc4eb Mon Sep 17 00:00:00 2001 From: elasticsearchmachine Date: Thu, 1 May 2025 08:36:50 +0000 Subject: [PATCH 12/12] [CI] Auto commit changes from spotless --- .../src/main/java/org/elasticsearch/search/query/QueryPhase.java | 1 - 1 file changed, 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java index 50b9ec1d33533..5fcfb2b9766cd 100644 --- a/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java +++ b/server/src/main/java/org/elasticsearch/search/query/QueryPhase.java @@ -13,7 +13,6 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.LeafReaderContext; -import org.apache.lucene.index.MultiReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; import org.apache.lucene.search.Collector;