diff --git a/CHANGELOG.md b/CHANGELOG.md index aa31819ffae97..5fab88d418f10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -85,6 +85,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Update supported version for max_shard_size parameter in Shrink API ([#11439](https://github.com/opensearch-project/OpenSearch/pull/11439)) - Fix typo in API annotation check message ([11836](https://github.com/opensearch-project/OpenSearch/pull/11836)) - Update supported version for must_exist parameter in update aliases API ([#11872](https://github.com/opensearch-project/OpenSearch/pull/11872)) +- [Bug] Check phase name before SearchRequestOperationsListener onPhaseStart ([#12035](https://github.com/opensearch-project/OpenSearch/pull/12035)) ### Security diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 5b41c2a13b596..519b9592a1e0e 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -439,9 +439,11 @@ private void onPhaseEnd(SearchRequestContext searchRequestContext) { this.searchRequestContext.getSearchRequestOperationsListener().onPhaseEnd(this, searchRequestContext); } - private void onPhaseStart(SearchPhase phase) { + void onPhaseStart(SearchPhase phase) { setCurrentPhase(phase); - this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + if (SearchPhaseName.isValidName(phase.getName())) { + this.searchRequestContext.getSearchRequestOperationsListener().onPhaseStart(this); + } } private void onRequestEnd(SearchRequestContext searchRequestContext) { diff --git a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java index 8cf92934c8a52..c6f3d4c70632d 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java +++ b/server/src/main/java/org/opensearch/action/search/SearchPhaseName.java @@ -10,6 +10,9 @@ import org.opensearch.common.annotation.PublicApi; +import java.util.HashSet; +import java.util.Set; + /** * Enum for different Search Phases in OpenSearch * @@ -25,6 +28,12 @@ public enum SearchPhaseName { CAN_MATCH("can_match"); private final String name; + private static final Set PHASE_NAMES = new HashSet<>(); + static { + for (SearchPhaseName phaseName : SearchPhaseName.values()) { + PHASE_NAMES.add(phaseName.name); + } + } SearchPhaseName(final String name) { this.name = name; @@ -33,4 +42,8 @@ public enum SearchPhaseName { public String getName() { return name; } + + public static boolean isValidName(String phaseName) { + return PHASE_NAMES.contains(phaseName); + } } diff --git a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java index 842c10b700d24..79e599ec9387b 100644 --- a/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/opensearch/action/search/TransportSearchAction.java @@ -1238,7 +1238,7 @@ private AbstractSearchAsyncAction searchAsyncAction clusters, searchRequestContext ); - return new SearchPhase(action.getName()) { + return new SearchPhase("none") { @Override public void run() { action.start(); diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index 76129341fc9a2..fddd6b32446ea 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -358,8 +358,8 @@ public void run() { } }, "message", null); - assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); - assertEquals(0, testListener.getPhaseTotal(action.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseCurrent(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); + assertEquals(0, testListener.getPhaseTotal(searchDfsQueryThenFetchAsyncAction.getSearchPhaseName())); FetchSearchPhase fetchPhase = createFetchSearchPhase(); ShardId shardId = new ShardId(randomAlphaOfLengthBetween(5, 10), randomAlphaOfLength(10), randomInt()); @@ -403,6 +403,30 @@ public void run() { assertEquals(requestIds, releasedContexts); } + public void testOnPhaseStart() { + ClusterSettings clusterSettings = new ClusterSettings(Settings.EMPTY, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS); + SearchRequestStats testListener = new SearchRequestStats(clusterSettings); + + final List requestOperationListeners = new ArrayList<>(List.of(testListener)); + SearchQueryThenFetchAsyncAction action = createSearchQueryThenFetchAsyncAction(requestOperationListeners); + + action.onPhaseStart(new SearchPhase("test") { + @Override + public void run() {} + }); + action.onPhaseStart(new SearchPhase("none") { + @Override + public void run() {} + }); + assertEquals(0, testListener.getPhaseCurrent(action.getSearchPhaseName())); + + action.onPhaseStart(new SearchPhase(action.getName()) { + @Override + public void run() {} + }); + assertEquals(1, testListener.getPhaseCurrent(action.getSearchPhaseName())); + } + public void testShardNotAvailableWithDisallowPartialFailures() { SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(false); AtomicReference exception = new AtomicReference<>();