diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java index 5f758e7d87..3344829859 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutor.java @@ -8,6 +8,7 @@ import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; import java.util.Map; +import java.util.Objects; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -132,13 +133,11 @@ private Protocol buildProtocolForDefaultQuery(Client client, DefaultQueryAction return protocol; } - private boolean isDefaultCursor(SearchResponse searchResponse, DefaultQueryAction queryAction) { + protected boolean isDefaultCursor(SearchResponse searchResponse, DefaultQueryAction queryAction) { if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { - if (searchResponse.getHits().getTotalHits().value < queryAction.getSqlRequest().fetchSize()) { - return false; - } else { - return true; - } + return queryAction.getSqlRequest().fetchSize() != 0 + && Objects.requireNonNull(searchResponse.getHits().getTotalHits()).value + >= queryAction.getSqlRequest().fetchSize(); } else { return !Strings.isNullOrEmpty(searchResponse.getScrollId()); } diff --git a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java index c589edcf50..e5011d1af8 100644 --- a/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java +++ b/legacy/src/main/java/org/opensearch/sql/legacy/executor/join/ElasticJoinExecutor.java @@ -99,6 +99,7 @@ public void run() throws IOException, SqlParseException { this.metaResults.setTookImMilli(joinTimeInMilli); } catch (Exception e) { LOG.error("Failed during join query run.", e); + throw new IllegalStateException("Error occurred during join query run", e); } finally { if (LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) { try { diff --git a/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutorTest.java b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutorTest.java new file mode 100644 index 0000000000..1387412d37 --- /dev/null +++ b/legacy/src/test/java/org/opensearch/sql/legacy/executor/format/PrettyFormatRestExecutorTest.java @@ -0,0 +1,89 @@ +package org.opensearch.sql.legacy.executor.format; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.common.setting.Settings.Key.SQL_PAGINATION_API_SEARCH_AFTER; + +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnitRunner; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.sql.legacy.esdomain.LocalClusterState; +import org.opensearch.sql.legacy.query.DefaultQueryAction; +import org.opensearch.sql.legacy.request.SqlRequest; +import org.opensearch.sql.opensearch.setting.OpenSearchSettings; + +@RunWith(MockitoJUnitRunner.class) +public class PrettyFormatRestExecutorTest { + + @Mock private SearchResponse searchResponse; + @Mock private SearchHits searchHits; + @Mock private SearchHit searchHit; + @Mock private DefaultQueryAction queryAction; + @Mock private SqlRequest sqlRequest; + private PrettyFormatRestExecutor executor; + + @Before + public void setUp() { + OpenSearchSettings settings = mock(OpenSearchSettings.class); + LocalClusterState.state().setPluginSettings(settings); + when(LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(true); + when(queryAction.getSqlRequest()).thenReturn(sqlRequest); + executor = new PrettyFormatRestExecutor("jdbc"); + } + + @Test + public void testIsDefaultCursor_fetchSizeZero() { + when(sqlRequest.fetchSize()).thenReturn(0); + + assertFalse(executor.isDefaultCursor(searchResponse, queryAction)); + } + + @Test + public void testIsDefaultCursor_totalHitsLessThanFetchSize() { + when(sqlRequest.fetchSize()).thenReturn(10); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 1.0F)); + + assertFalse(executor.isDefaultCursor(searchResponse, queryAction)); + } + + @Test + public void testIsDefaultCursor_totalHitsGreaterThanOrEqualToFetchSize() { + when(sqlRequest.fetchSize()).thenReturn(5); + when(searchResponse.getHits()) + .thenReturn( + new SearchHits( + new SearchHit[] {searchHit}, new TotalHits(5, TotalHits.Relation.EQUAL_TO), 1.0F)); + + assertTrue(executor.isDefaultCursor(searchResponse, queryAction)); + } + + @Test + public void testIsDefaultCursor_PaginationApiDisabled() { + when(LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(false); + when(searchResponse.getScrollId()).thenReturn("someScrollId"); + + assertTrue(executor.isDefaultCursor(searchResponse, queryAction)); + } + + @Test + public void testIsDefaultCursor_PaginationApiDisabled_NoScrollId() { + when(LocalClusterState.state().getSettingValue(SQL_PAGINATION_API_SEARCH_AFTER)) + .thenReturn(false); + when(searchResponse.getScrollId()).thenReturn(null); + + assertFalse(executor.isDefaultCursor(searchResponse, queryAction)); + } +}