From c20458512ff4e68f359de32bfad91a6a7487f8a1 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Tue, 5 Dec 2023 17:21:27 +0000 Subject: [PATCH] [Search Pipelines] Add request-scoped state shared between processors (#9405) To handle cases where multiple search pipeline processors need to share information, we will allocate a context holder for the lifetime of the request and pass it to each processor to get/set values. To explain this behavior and benefit from it, this change also introduces three new processors: 1. The "oversample" request processor that increases "size", storing the original size in the context. 2. The "truncate" response processor that discards results after some number, by default using the original size before oversampling. 3. The "collapse" response processor offers similar behavior to a collapse query, discarding results that have a field value in common with a higher-scoring result. Signed-off-by: Michael Froh --- CHANGELOG.md | 2 + .../search/pipeline/common/BasicMap.java | 126 ++++++ .../common/CollapseResponseProcessor.java | 122 ++++++ .../common/OversampleRequestProcessor.java | 83 ++++ .../common/ScriptRequestProcessor.java | 39 +- .../SearchPipelineCommonModulePlugin.java | 13 +- .../pipeline/common/SearchRequestMap.java | 140 +++++++ .../SearchRequestMapProcessingException.java | 10 +- .../common/TruncateHitsResponseProcessor.java | 96 +++++ .../pipeline/common/helpers/ContextUtils.java | 38 ++ .../common/helpers/SearchRequestMap.java | 395 ------------------ .../common/helpers/SearchResponseUtil.java | 93 +++++ .../CollapseResponseProcessorTests.java | 86 ++++ .../OversampleRequestProcessorTests.java | 62 +++ .../common/ScriptRequestProcessorTests.java | 28 +- .../{helpers => }/SearchRequestMapTests.java | 2 +- .../TruncateHitsResponseProcessorTests.java | 91 ++++ .../60_oversample_truncate.yml | 105 +++++ .../search_pipeline/70_script_truncate.yml | 70 ++++ .../opensearch/search/pipeline/Pipeline.java | 27 +- .../pipeline/PipelineProcessingContext.java | 38 ++ .../search/pipeline/PipelinedRequest.java | 10 +- .../opensearch/search/pipeline/Processor.java | 7 - .../pipeline/SearchPhaseResultsProcessor.java | 16 + .../pipeline/SearchPipelineService.java | 3 +- .../pipeline/SearchRequestProcessor.java | 31 +- .../pipeline/SearchResponseProcessor.java | 33 +- .../StatefulSearchRequestProcessor.java | 25 ++ .../StatefulSearchResponseProcessor.java | 27 ++ .../pipeline/SearchPipelineServiceTests.java | 90 ++++ 30 files changed, 1444 insertions(+), 464 deletions(-) create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/BasicMap.java create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMap.java rename modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/{helpers => }/SearchRequestMapProcessingException.java (76%) create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java delete mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMap.java create mode 100644 modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java create mode 100644 modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java create mode 100644 modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java rename modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/{helpers => }/SearchRequestMapTests.java (99%) create mode 100644 modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java create mode 100644 modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/60_oversample_truncate.yml create mode 100644 modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml create mode 100644 server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java create mode 100644 server/src/main/java/org/opensearch/search/pipeline/StatefulSearchRequestProcessor.java create mode 100644 server/src/main/java/org/opensearch/search/pipeline/StatefulSearchResponseProcessor.java diff --git a/CHANGELOG.md b/CHANGELOG.md index cc66cc82f00af..449d7d813f764 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -94,6 +94,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Admission control] Add Resource usage collector service and resource usage tracker ([#9890](https://github.com/opensearch-project/OpenSearch/pull/9890)) - [Admission control] Add enhancements to FS stats to include read/write time, queue size and IO time ([#10541](https://github.com/opensearch-project/OpenSearch/pull/10541)) - [Remote cluster state] Change file names for remote cluster state ([#10557](https://github.com/opensearch-project/OpenSearch/pull/10557)) +- [Search Pipelines] Add request-scoped state shared between processors (and three new processors) ([#9405](https://github.com/opensearch-project/OpenSearch/pull/9405)) +- Per request phase latency ([#10351](https://github.com/opensearch-project/OpenSearch/issues/10351)) - [Remote Store] Add repository stats for remote store([#10567](https://github.com/opensearch-project/OpenSearch/pull/10567)) - [Remote cluster state] Upload global metadata in cluster state to remote store([#10404](https://github.com/opensearch-project/OpenSearch/pull/10404)) - [Remote cluster state] Download functionality of global metadata from remote store ([#10535](https://github.com/opensearch-project/OpenSearch/pull/10535)) diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/BasicMap.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/BasicMap.java new file mode 100644 index 0000000000000..6ddc22420416b --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/BasicMap.java @@ -0,0 +1,126 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import java.util.Collection; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Function; + +/** + * Helper for map abstractions passed to scripting processors. Throws {@link UnsupportedOperationException} for almost + * all methods. Subclasses just need to implement get and put. + */ +abstract class BasicMap implements Map { + + /** + * No-args constructor. + */ + protected BasicMap() {} + + private static final String UNSUPPORTED_OP_ERR = " Method not supported in Search pipeline script"; + + @Override + public boolean isEmpty() { + throw new UnsupportedOperationException("isEmpty" + UNSUPPORTED_OP_ERR); + } + + public int size() { + throw new UnsupportedOperationException("size" + UNSUPPORTED_OP_ERR); + } + + public boolean containsKey(Object key) { + return get(key) != null; + } + + public boolean containsValue(Object value) { + throw new UnsupportedOperationException("containsValue" + UNSUPPORTED_OP_ERR); + } + + public Object remove(Object key) { + throw new UnsupportedOperationException("remove" + UNSUPPORTED_OP_ERR); + } + + public void putAll(Map m) { + throw new UnsupportedOperationException("putAll" + UNSUPPORTED_OP_ERR); + } + + public void clear() { + throw new UnsupportedOperationException("clear" + UNSUPPORTED_OP_ERR); + } + + public Set keySet() { + throw new UnsupportedOperationException("keySet" + UNSUPPORTED_OP_ERR); + } + + public Collection values() { + throw new UnsupportedOperationException("values" + UNSUPPORTED_OP_ERR); + } + + public Set> entrySet() { + throw new UnsupportedOperationException("entrySet" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object getOrDefault(Object key, Object defaultValue) { + throw new UnsupportedOperationException("getOrDefault" + UNSUPPORTED_OP_ERR); + } + + @Override + public void forEach(BiConsumer action) { + throw new UnsupportedOperationException("forEach" + UNSUPPORTED_OP_ERR); + } + + @Override + public void replaceAll(BiFunction function) { + throw new UnsupportedOperationException("replaceAll" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object putIfAbsent(String key, Object value) { + throw new UnsupportedOperationException("putIfAbsent" + UNSUPPORTED_OP_ERR); + } + + @Override + public boolean remove(Object key, Object value) { + throw new UnsupportedOperationException("remove" + UNSUPPORTED_OP_ERR); + } + + @Override + public boolean replace(String key, Object oldValue, Object newValue) { + throw new UnsupportedOperationException("replace" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object replace(String key, Object value) { + throw new UnsupportedOperationException("replace" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object computeIfAbsent(String key, Function mappingFunction) { + throw new UnsupportedOperationException("computeIfAbsent" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object computeIfPresent(String key, BiFunction remappingFunction) { + throw new UnsupportedOperationException("computeIfPresent" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object compute(String key, BiFunction remappingFunction) { + throw new UnsupportedOperationException("compute" + UNSUPPORTED_OP_ERR); + } + + @Override + public Object merge(String key, Object value, BiFunction remappingFunction) { + throw new UnsupportedOperationException("merge" + UNSUPPORTED_OP_ERR); + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java new file mode 100644 index 0000000000000..3e6c4fef6a559 --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/CollapseResponseProcessor.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.document.DocumentField; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * A simple implementation of field collapsing on search responses. Note that this is not going to work as well as + * field collapsing at the shard level, as implemented with the "collapse" parameter in a search request. Mostly + * just using this to demo the oversample / truncate_hits processors. + */ +public class CollapseResponseProcessor extends AbstractProcessor implements SearchResponseProcessor { + /** + * Key to reference this processor type from a search pipeline. + */ + public static final String TYPE = "collapse"; + static final String COLLAPSE_FIELD = "field"; + private final String collapseField; + + private CollapseResponseProcessor(String tag, String description, boolean ignoreFailure, String collapseField) { + super(tag, description, ignoreFailure); + this.collapseField = Objects.requireNonNull(collapseField); + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response) { + + if (response.getHits() != null) { + if (response.getHits().getCollapseField() != null) { + throw new IllegalStateException( + "Cannot collapse on " + collapseField + ". Results already collapsed on " + response.getHits().getCollapseField() + ); + } + Map collapsedHits = new LinkedHashMap<>(); + List collapseValues = new ArrayList<>(); + for (SearchHit hit : response.getHits()) { + Object fieldValue = null; + DocumentField docField = hit.getFields().get(collapseField); + if (docField != null) { + if (docField.getValues().size() > 1) { + throw new IllegalStateException( + "Failed to collapse " + hit.getId() + ": doc has multiple values for field " + collapseField + ); + } + fieldValue = docField.getValues().get(0); + } else if (hit.getSourceAsMap() != null) { + fieldValue = hit.getSourceAsMap().get(collapseField); + } + String fieldValueString; + if (fieldValue == null) { + fieldValueString = "__missing__"; + } else { + fieldValueString = fieldValue.toString(); + } + + // Results are already sorted by sort criterion. Only keep the first hit for each field. + if (collapsedHits.containsKey(fieldValueString) == false) { + collapsedHits.put(fieldValueString, hit); + collapseValues.add(fieldValue); + } + } + SearchHit[] newHits = new SearchHit[collapsedHits.size()]; + int i = 0; + for (SearchHit collapsedHit : collapsedHits.values()) { + newHits[i++] = collapsedHit; + } + SearchHits searchHits = new SearchHits( + newHits, + response.getHits().getTotalHits(), + response.getHits().getMaxScore(), + response.getHits().getSortFields(), + collapseField, + collapseValues.toArray() + ); + return SearchResponseUtil.replaceHits(searchHits, response); + } + return response; + } + + static class Factory implements Processor.Factory { + + @Override + public CollapseResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) { + String collapseField = ConfigurationUtils.readStringProperty(TYPE, tag, config, COLLAPSE_FIELD); + return new CollapseResponseProcessor(tag, description, ignoreFailure, collapseField); + } + } + +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java new file mode 100644 index 0000000000000..182cf6ba79504 --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java @@ -0,0 +1,83 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchService; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.StatefulSearchRequestProcessor; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; + +import java.util.Map; + +import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix; + +/** + * Multiplies the "size" parameter on the {@link SearchRequest} by the given scaling factor, storing the original value + * in the request context as "original_size". + */ +public class OversampleRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { + + /** + * Key to reference this processor type from a search pipeline. + */ + public static final String TYPE = "oversample"; + static final String SAMPLE_FACTOR = "sample_factor"; + static final String ORIGINAL_SIZE = "original_size"; + private final double sampleFactor; + private final String contextPrefix; + + private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor, String contextPrefix) { + super(tag, description, ignoreFailure); + this.sampleFactor = sampleFactor; + this.contextPrefix = contextPrefix; + } + + @Override + public SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) { + if (request.source() != null) { + int originalSize = request.source().size(); + if (originalSize == -1) { + originalSize = SearchService.DEFAULT_SIZE; + } + requestContext.setAttribute(applyContextPrefix(contextPrefix, ORIGINAL_SIZE), originalSize); + int newSize = (int) Math.ceil(originalSize * sampleFactor); + request.source().size(newSize); + } + return request; + } + + @Override + public String getType() { + return TYPE; + } + + static class Factory implements Processor.Factory { + @Override + public OversampleRequestProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) { + double sampleFactor = ConfigurationUtils.readDoubleProperty(TYPE, tag, config, SAMPLE_FACTOR); + if (sampleFactor < 1.0) { + throw ConfigurationUtils.newConfigurationException(TYPE, tag, SAMPLE_FACTOR, "Value must be >= 1.0"); + } + String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER); + return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor, contextPrefix); + } + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java index 90f71fd1754e4..a4052d0892ee6 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java @@ -23,9 +23,10 @@ import org.opensearch.script.ScriptType; import org.opensearch.script.SearchScript; import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; -import org.opensearch.search.pipeline.common.helpers.SearchRequestMap; +import org.opensearch.search.pipeline.StatefulSearchRequestProcessor; import java.io.InputStream; import java.util.HashMap; @@ -38,7 +39,7 @@ * Processor that evaluates a script with a search request in its context * and then returns the modified search request. */ -public final class ScriptRequestProcessor extends AbstractProcessor implements SearchRequestProcessor { +public final class ScriptRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { /** * Key to reference this processor type from a search pipeline. */ @@ -72,15 +73,8 @@ public final class ScriptRequestProcessor extends AbstractProcessor implements S this.scriptService = scriptService; } - /** - * Executes the script with the search request in context. - * - * @param request The search request passed into the script context. - * @return The modified search request. - * @throws Exception if an error occurs while processing the request. - */ @Override - public SearchRequest processRequest(SearchRequest request) throws Exception { + public SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) throws Exception { // assert request is not null and source is not null if (request == null || request.source() == null) { throw new IllegalArgumentException("search request must not be null"); @@ -93,10 +87,33 @@ public SearchRequest processRequest(SearchRequest request) throws Exception { searchScript = precompiledSearchScript; } // execute the script with the search request in context - searchScript.execute(Map.of("_source", new SearchRequestMap(request))); + searchScript.execute(Map.of("_source", new SearchRequestMap(request), "request_context", new RequestContextMap(requestContext))); return request; } + private static class RequestContextMap extends BasicMap { + private final PipelineProcessingContext pipelinedRequestContext; + + private RequestContextMap(PipelineProcessingContext pipelinedRequestContext) { + this.pipelinedRequestContext = pipelinedRequestContext; + } + + @Override + public Object get(Object key) { + if (key instanceof String) { + return pipelinedRequestContext.getAttribute(key.toString()); + } + return null; + } + + @Override + public Object put(String key, Object value) { + Object originalValue = get(key); + pipelinedRequestContext.setAttribute(key, value); + return originalValue; + } + } + /** * Returns the type of the processor. * diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchPipelineCommonModulePlugin.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchPipelineCommonModulePlugin.java index 49681b80fdead..5378a6721efb2 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchPipelineCommonModulePlugin.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchPipelineCommonModulePlugin.java @@ -38,12 +38,21 @@ public Map> getRequestProcesso FilterQueryRequestProcessor.TYPE, new FilterQueryRequestProcessor.Factory(parameters.namedXContentRegistry), ScriptRequestProcessor.TYPE, - new ScriptRequestProcessor.Factory(parameters.scriptService) + new ScriptRequestProcessor.Factory(parameters.scriptService), + OversampleRequestProcessor.TYPE, + new OversampleRequestProcessor.Factory() ); } @Override public Map> getResponseProcessors(Parameters parameters) { - return Map.of(RenameFieldResponseProcessor.TYPE, new RenameFieldResponseProcessor.Factory()); + return Map.of( + RenameFieldResponseProcessor.TYPE, + new RenameFieldResponseProcessor.Factory(), + TruncateHitsResponseProcessor.TYPE, + new TruncateHitsResponseProcessor.Factory(), + CollapseResponseProcessor.TYPE, + new CollapseResponseProcessor.Factory() + ); } } diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMap.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMap.java new file mode 100644 index 0000000000000..c6430b96dcbed --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMap.java @@ -0,0 +1,140 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; + +import java.util.Map; + +/** + * A custom implementation of {@link Map} that provides access to the properties of a {@link SearchRequest}'s + * {@link SearchSourceBuilder}. The class allows retrieving and modifying specific properties of the search request. + */ +class SearchRequestMap extends BasicMap implements Map { + + private final SearchSourceBuilder source; + + /** + * Constructs a new instance of the {@link SearchRequestMap} with the provided {@link SearchRequest}. + * + * @param searchRequest The SearchRequest containing the SearchSourceBuilder to be accessed. + */ + public SearchRequestMap(SearchRequest searchRequest) { + source = searchRequest.source(); + } + + /** + * Checks if the SearchSourceBuilder is empty. + * + * @return {@code true} if the SearchSourceBuilder is empty, {@code false} otherwise. + */ + @Override + public boolean isEmpty() { + return source == null; + } + + /** + * Retrieves the value associated with the specified property from the SearchSourceBuilder. + * + * @param key The SearchSourceBuilder property whose value is to be retrieved. + * @return The value associated with the specified property or null if the property has not been initialized. + * @throws IllegalArgumentException if the property name is not a String. + * @throws SearchRequestMapProcessingException if the property is not supported. + */ + @Override + public Object get(Object key) { + if (!(key instanceof String)) { + throw new IllegalArgumentException("key must be a String"); + } + // This is the explicit implementation of fetch value from source + switch ((String) key) { + case "from": + return source.from(); + case "size": + return source.size(); + case "explain": + return source.explain(); + case "version": + return source.version(); + case "seq_no_primary_term": + return source.seqNoAndPrimaryTerm(); + case "track_scores": + return source.trackScores(); + case "track_total_hits": + return source.trackTotalHitsUpTo(); + case "min_score": + return source.minScore(); + case "terminate_after": + return source.terminateAfter(); + case "profile": + return source.profile(); + default: + throw new SearchRequestMapProcessingException("Unsupported key: " + key); + } + } + + /** + * Sets the value for the specified property in the SearchSourceBuilder. + * + * @param key The property whose value is to be set. + * @param value The value to be set for the specified property. + * @return The original value associated with the property, or null if none existed. + * @throws IllegalArgumentException if the property is not a String. + * @throws SearchRequestMapProcessingException if the property is not supported or an error occurs during the setting. + */ + @Override + public Object put(String key, Object value) { + Object originalValue = get(key); + try { + switch (key) { + case "from": + source.from((Integer) value); + break; + case "size": + source.size((Integer) value); + break; + case "explain": + source.explain((Boolean) value); + break; + case "version": + source.version((Boolean) value); + break; + case "seq_no_primary_term": + source.seqNoAndPrimaryTerm((Boolean) value); + break; + case "track_scores": + source.trackScores((Boolean) value); + break; + case "track_total_hits": + source.trackTotalHitsUpTo((Integer) value); + break; + case "min_score": + source.minScore((Float) value); + break; + case "terminate_after": + source.terminateAfter((Integer) value); + break; + case "profile": + source.profile((Boolean) value); + break; + case "stats": // Not modifying stats, sorts, docvalue_fields, etc. as they require more complex handling + case "sort": + case "timeout": + case "docvalue_fields": + case "indices_boost": + default: + throw new SearchRequestMapProcessingException("Unsupported SearchRequest source property: " + key); + } + } catch (Exception e) { + throw new SearchRequestMapProcessingException("Error while setting value for SearchRequest source property: " + key, e); + } + return originalValue; + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapProcessingException.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMapProcessingException.java similarity index 76% rename from modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapProcessingException.java rename to modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMapProcessingException.java index cb1e45a20b624..2f00d0f82c2f1 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapProcessingException.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/SearchRequestMapProcessingException.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.search.pipeline.common.helpers; +package org.opensearch.search.pipeline.common; import org.opensearch.OpenSearchException; import org.opensearch.OpenSearchWrapperException; @@ -14,12 +14,12 @@ /** * An exception that indicates an error occurred while processing a {@link SearchRequestMap}. */ -public class SearchRequestMapProcessingException extends OpenSearchException implements OpenSearchWrapperException { +class SearchRequestMapProcessingException extends OpenSearchException implements OpenSearchWrapperException { /** * Constructs a new SearchRequestMapProcessingException with the specified message. * - * @param msg The error message. + * @param msg The error message. * @param args Arguments to substitute in the error message. */ public SearchRequestMapProcessingException(String msg, Object... args) { @@ -29,9 +29,9 @@ public SearchRequestMapProcessingException(String msg, Object... args) { /** * Constructs a new SearchRequestMapProcessingException with the specified message and cause. * - * @param msg The error message. + * @param msg The error message. * @param cause The cause of the exception. - * @param args Arguments to substitute in the error message. + * @param args Arguments to substitute in the error message. */ public SearchRequestMapProcessingException(String msg, Throwable cause, Object... args) { super(msg, cause, args); diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java new file mode 100644 index 0000000000000..e3413bf41720f --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java @@ -0,0 +1,96 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.ingest.ConfigurationUtils; +import org.opensearch.search.SearchHit; +import org.opensearch.search.pipeline.AbstractProcessor; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.search.pipeline.StatefulSearchResponseProcessor; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; +import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; + +import java.util.Map; + +import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix; + +/** + * Truncates the returned search hits from the {@link SearchResponse}. If no target size is specified in the pipeline, then + * we try using the "original_size" value from the request context, which may have been set by {@link OversampleRequestProcessor}. + */ +public class TruncateHitsResponseProcessor extends AbstractProcessor implements StatefulSearchResponseProcessor { + /** + * Key to reference this processor type from a search pipeline. + */ + public static final String TYPE = "truncate_hits"; + static final String TARGET_SIZE = "target_size"; + private final int targetSize; + private final String contextPrefix; + + @Override + public String getType() { + return TYPE; + } + + private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize, String contextPrefix) { + super(tag, description, ignoreFailure); + this.targetSize = targetSize; + this.contextPrefix = contextPrefix; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) { + int size; + if (targetSize < 0) { // No value specified in processor config. Use context value instead. + String key = applyContextPrefix(contextPrefix, OversampleRequestProcessor.ORIGINAL_SIZE); + Object o = requestContext.getAttribute(key); + if (o == null) { + throw new IllegalStateException("Must specify " + TARGET_SIZE + " unless an earlier processor set " + key); + } + size = (int) o; + } else { + size = targetSize; + } + if (response.getHits() != null && response.getHits().getHits().length > size) { + SearchHit[] newHits = new SearchHit[size]; + System.arraycopy(response.getHits().getHits(), 0, newHits, 0, size); + return SearchResponseUtil.replaceHits(newHits, response); + } + return response; + } + + static class Factory implements Processor.Factory { + @Override + public TruncateHitsResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + PipelineContext pipelineContext + ) { + Integer targetSize = ConfigurationUtils.readIntProperty(TYPE, tag, config, TARGET_SIZE, null); + if (targetSize == null) { + // Use -1 as an "unset" marker to avoid repeated unboxing of an Integer. + targetSize = -1; + } else { + // Explicitly set values must be >= 0. + if (targetSize < 0) { + throw ConfigurationUtils.newConfigurationException(TYPE, tag, TARGET_SIZE, "Value must be >= 0"); + } + } + String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER); + return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize, contextPrefix); + } + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java new file mode 100644 index 0000000000000..9697da85dbecf --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common.helpers; + +/** + * Helpers for working with request-scoped context. + */ +public final class ContextUtils { + private ContextUtils() {} + + /** + * Parameter that can be passed to a stateful processor to avoid collisions between contextual variables by + * prefixing them with distinct qualifiers. + */ + public static final String CONTEXT_PREFIX_PARAMETER = "context_prefix"; + + /** + * Replaces a "global" variable name with one scoped to a given context prefix (unless prefix is null or empty). + * @param contextPrefix the prefix qualifier for the variable + * @param variableName the generic "global" form of the context variable + * @return the variableName prefixed with contextPrefix followed by ".", or just variableName if contextPrefix is null or empty + */ + public static String applyContextPrefix(String contextPrefix, String variableName) { + String contextVariable; + if (contextPrefix != null && contextPrefix.isEmpty() == false) { + contextVariable = contextPrefix + "." + variableName; + } else { + contextVariable = variableName; + } + return contextVariable; + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMap.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMap.java deleted file mode 100644 index 7af3ac66be146..0000000000000 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMap.java +++ /dev/null @@ -1,395 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ - -package org.opensearch.search.pipeline.common.helpers; - -import org.opensearch.action.search.SearchRequest; -import org.opensearch.search.builder.SearchSourceBuilder; - -import java.util.Collection; -import java.util.Map; -import java.util.Set; -import java.util.function.BiConsumer; -import java.util.function.BiFunction; -import java.util.function.Function; - -/** - * A custom implementation of {@link Map} that provides access to the properties of a {@link SearchRequest}'s - * {@link SearchSourceBuilder}. The class allows retrieving and modifying specific properties of the search request. - */ -public class SearchRequestMap implements Map { - private static final String UNSUPPORTED_OP_ERR = " Method not supported in Search pipeline script"; - - private final SearchSourceBuilder source; - - /** - * Constructs a new instance of the {@link SearchRequestMap} with the provided {@link SearchRequest}. - * - * @param searchRequest The SearchRequest containing the SearchSourceBuilder to be accessed. - */ - public SearchRequestMap(SearchRequest searchRequest) { - source = searchRequest.source(); - } - - /** - * Retrieves the number of properties in the SearchSourceBuilder. - * - * @return The number of properties in the SearchSourceBuilder. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public int size() { - throw new UnsupportedOperationException("size" + UNSUPPORTED_OP_ERR); - } - - /** - * Checks if the SearchSourceBuilder is empty. - * - * @return {@code true} if the SearchSourceBuilder is empty, {@code false} otherwise. - */ - @Override - public boolean isEmpty() { - return source == null; - } - - /** - * Checks if the SearchSourceBuilder contains the specified property. - * - * @param key The property to check for. - * @return {@code true} if the SearchSourceBuilder contains the specified property, {@code false} otherwise. - */ - @Override - public boolean containsKey(Object key) { - return get(key) != null; - } - - /** - * Checks if the SearchSourceBuilder contains the specified value. - * - * @param value The value to check for. - * @return {@code true} if the SearchSourceBuilder contains the specified value, {@code false} otherwise. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public boolean containsValue(Object value) { - throw new UnsupportedOperationException("containsValue" + UNSUPPORTED_OP_ERR); - } - - /** - * Retrieves the value associated with the specified property from the SearchSourceBuilder. - * - * @param key The SearchSourceBuilder property whose value is to be retrieved. - * @return The value associated with the specified property or null if the property has not been initialized. - * @throws IllegalArgumentException if the property name is not a String. - * @throws SearchRequestMapProcessingException if the property is not supported. - */ - @Override - public Object get(Object key) { - if (!(key instanceof String)) { - throw new IllegalArgumentException("key must be a String"); - } - // This is the explicit implementation of fetch value from source - switch ((String) key) { - case "from": - return source.from(); - case "size": - return source.size(); - case "explain": - return source.explain(); - case "version": - return source.version(); - case "seq_no_primary_term": - return source.seqNoAndPrimaryTerm(); - case "track_scores": - return source.trackScores(); - case "track_total_hits": - return source.trackTotalHitsUpTo(); - case "min_score": - return source.minScore(); - case "terminate_after": - return source.terminateAfter(); - case "profile": - return source.profile(); - default: - throw new SearchRequestMapProcessingException("Unsupported key: " + key); - } - } - - /** - * Sets the value for the specified property in the SearchSourceBuilder. - * - * @param key The property whose value is to be set. - * @param value The value to be set for the specified property. - * @return The original value associated with the property, or null if none existed. - * @throws IllegalArgumentException if the property is not a String. - * @throws SearchRequestMapProcessingException if the property is not supported or an error occurs during the setting. - */ - @Override - public Object put(String key, Object value) { - Object originalValue = get(key); - try { - switch (key) { - case "from": - source.from((Integer) value); - break; - case "size": - source.size((Integer) value); - break; - case "explain": - source.explain((Boolean) value); - break; - case "version": - source.version((Boolean) value); - break; - case "seq_no_primary_term": - source.seqNoAndPrimaryTerm((Boolean) value); - break; - case "track_scores": - source.trackScores((Boolean) value); - break; - case "track_total_hits": - source.trackTotalHitsUpTo((Integer) value); - break; - case "min_score": - source.minScore((Float) value); - break; - case "terminate_after": - source.terminateAfter((Integer) value); - break; - case "profile": - source.profile((Boolean) value); - break; - case "stats": // Not modifying stats, sorts, docvalue_fields, etc. as they require more complex handling - case "sort": - case "timeout": - case "docvalue_fields": - case "indices_boost": - default: - throw new SearchRequestMapProcessingException("Unsupported SearchRequest source property: " + key); - } - } catch (Exception e) { - throw new SearchRequestMapProcessingException("Error while setting value for SearchRequest source property: " + key, e); - } - return originalValue; - } - - /** - * Removes the specified property from the SearchSourceBuilder. - * - * @param key The name of the property that will be removed. - * @return The value associated with the property before it was removed, or null if the property was not found. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object remove(Object key) { - throw new UnsupportedOperationException("remove" + UNSUPPORTED_OP_ERR); - } - - /** - * Sets all the properties from the specified map to the SearchSourceBuilder. - * - * @param m The map containing the properties to be set. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public void putAll(Map m) { - throw new UnsupportedOperationException("putAll" + UNSUPPORTED_OP_ERR); - } - - /** - * Removes all properties from the SearchSourceBuilder. - * - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public void clear() { - throw new UnsupportedOperationException("clear" + UNSUPPORTED_OP_ERR); - } - - /** - * Returns a set view of the property names in the SearchSourceBuilder. - * - * @return A set view of the property names in the SearchSourceBuilder. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Set keySet() { - throw new UnsupportedOperationException("keySet" + UNSUPPORTED_OP_ERR); - } - - /** - * Returns a collection view of the property values in the SearchSourceBuilder. - * - * @return A collection view of the property values in the SearchSourceBuilder. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Collection values() { - throw new UnsupportedOperationException("values" + UNSUPPORTED_OP_ERR); - } - - /** - * Returns a set view of the properties in the SearchSourceBuilder. - * - * @return A set view of the properties in the SearchSourceBuilder. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Set> entrySet() { - throw new UnsupportedOperationException("entrySet" + UNSUPPORTED_OP_ERR); - } - - /** - * Returns the value to which the specified property has, or the defaultValue if the property is not present in the - * SearchSourceBuilder. - * - * @param key The property whose associated value is to be returned. - * @param defaultValue The default value to be returned if the property is not present. - * @return The value to which the specified property has, or the defaultValue if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object getOrDefault(Object key, Object defaultValue) { - throw new UnsupportedOperationException("getOrDefault" + UNSUPPORTED_OP_ERR); - } - - /** - * Performs the given action for each property in the SearchSourceBuilder until all properties have been processed or the - * action throws an exception - * - * @param action The action to be performed for each property. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public void forEach(BiConsumer action) { - throw new UnsupportedOperationException("forEach" + UNSUPPORTED_OP_ERR); - } - - /** - * Replaces each property's value with the result of invoking the given function on that property until all properties have - * been processed or the function throws an exception. - * - * @param function The function to apply to each property. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public void replaceAll(BiFunction function) { - throw new UnsupportedOperationException("replaceAll" + UNSUPPORTED_OP_ERR); - } - - /** - * If the specified property is not already associated with a value, associates it with the given value and returns null, - * else returns the current value. - * - * @param key The property whose value is to be set if absent. - * @param value The value to be associated with the specified property. - * @return The current value associated with the property, or null if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object putIfAbsent(String key, Object value) { - throw new UnsupportedOperationException("putIfAbsent" + UNSUPPORTED_OP_ERR); - } - - /** - * Removes the property only if it has the given value. - * - * @param key The property to be removed. - * @param value The value expected to be associated with the property. - * @return {@code true} if the entry was removed, {@code false} otherwise. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public boolean remove(Object key, Object value) { - throw new UnsupportedOperationException("remove" + UNSUPPORTED_OP_ERR); - } - - /** - * Replaces the specified property only if it has the given value. - * - * @param key The property to be replaced. - * @param oldValue The value expected to be associated with the property. - * @param newValue The value to be associated with the property. - * @return {@code true} if the property was replaced, {@code false} otherwise. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public boolean replace(String key, Object oldValue, Object newValue) { - throw new UnsupportedOperationException("replace" + UNSUPPORTED_OP_ERR); - } - - /** - * Replaces the specified property only if it has the given value. - * - * @param key The property to be replaced. - * @param value The value to be associated with the property. - * @return The previous value associated with the property, or null if the property was not found. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object replace(String key, Object value) { - throw new UnsupportedOperationException("replace" + UNSUPPORTED_OP_ERR); - } - - /** - * The computed value associated with the property, or null if the property is not present. - * - * @param key The property whose value is to be computed if absent. - * @param mappingFunction The function to compute a value based on the property. - * @return The computed value associated with the property, or null if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object computeIfAbsent(String key, Function mappingFunction) { - throw new UnsupportedOperationException("computeIfAbsent" + UNSUPPORTED_OP_ERR); - } - - /** - * If the value for the specified property is present, attempts to compute a new mapping given the property and its current - * mapped value. - * - * @param key The property for which the mapping is to be computed. - * @param remappingFunction The function to compute a new mapping. - * @return The new value associated with the property, or null if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object computeIfPresent(String key, BiFunction remappingFunction) { - throw new UnsupportedOperationException("computeIfPresent" + UNSUPPORTED_OP_ERR); - } - - /** - * If the value for the specified property is present, attempts to compute a new mapping given the property and its current - * mapped value, or removes the property if the computed value is null. - * - * @param key The property for which the mapping is to be computed. - * @param remappingFunction The function to compute a new mapping. - * @return The new value associated with the property, or null if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object compute(String key, BiFunction remappingFunction) { - throw new UnsupportedOperationException("compute" + UNSUPPORTED_OP_ERR); - } - - /** - * If the specified property is not already associated with a value or is associated with null, associates it with the - * given non-null value. Otherwise, replaces the associated value with the results of applying the given - * remapping function to the current and new values. - * - * @param key The property for which the mapping is to be merged. - * @param value The non-null value to be merged with the existing value. - * @param remappingFunction The function to merge the existing and new values. - * @return The new value associated with the property, or null if the property is not present. - * @throws UnsupportedOperationException always, as the method is not supported. - */ - @Override - public Object merge(String key, Object value, BiFunction remappingFunction) { - throw new UnsupportedOperationException("merge" + UNSUPPORTED_OP_ERR); - } -} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java new file mode 100644 index 0000000000000..0710548c6429f --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java @@ -0,0 +1,93 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common.helpers; + +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.SearchResponseSections; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.profile.SearchProfileShardResults; + +/** + * Helper methods for manipulating {@link SearchResponse}. + */ +public final class SearchResponseUtil { + private SearchResponseUtil() { + + } + + /** + * Construct a new {@link SearchResponse} based on an existing one, replacing just the {@link SearchHits}. + * @param newHits new {@link SearchHits} + * @param response the existing search response + * @return a new search response where the {@link SearchHits} has been replaced + */ + public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) { + SearchResponseSections searchResponseSections; + if (response.getAggregations() == null || response.getAggregations() instanceof InternalAggregations) { + // We either have no aggregations, or we have Writeable InternalAggregations. + // Either way, we can produce a Writeable InternalSearchResponse. + searchResponseSections = new InternalSearchResponse( + newHits, + (InternalAggregations) response.getAggregations(), + response.getSuggest(), + new SearchProfileShardResults(response.getProfileResults()), + response.isTimedOut(), + response.isTerminatedEarly(), + response.getNumReducePhases() + ); + } else { + // We have non-Writeable Aggregations, so the whole SearchResponseSections is non-Writeable. + searchResponseSections = new SearchResponseSections( + newHits, + response.getAggregations(), + response.getSuggest(), + response.isTimedOut(), + response.isTerminatedEarly(), + new SearchProfileShardResults(response.getProfileResults()), + response.getNumReducePhases() + ); + } + + return new SearchResponse( + searchResponseSections, + response.getScrollId(), + response.getTotalShards(), + response.getSuccessfulShards(), + response.getSkippedShards(), + response.getTook().millis(), + response.getShardFailures(), + response.getClusters(), + response.pointInTimeId() + ); + } + + /** + * Convenience method when only replacing the {@link SearchHit} array within the {@link SearchHits} in a {@link SearchResponse}. + * @param newHits the new array of {@link SearchHit} elements. + * @param response the search response to update + * @return a {@link SearchResponse} where the underlying array of {@link SearchHit} within the {@link SearchHits} has been replaced. + */ + public static SearchResponse replaceHits(SearchHit[] newHits, SearchResponse response) { + if (response.getHits() == null) { + throw new IllegalStateException("Response must have hits"); + } + SearchHits searchHits = new SearchHits( + newHits, + response.getHits().getTotalHits(), + response.getHits().getMaxScore(), + response.getHits().getSortFields(), + response.getHits().getCollapseField(), + response.getHits().getCollapseValues() + ); + return replaceHits(searchHits, response); + } +} diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java new file mode 100644 index 0000000000000..cda011f24fea1 --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/CollapseResponseProcessorTests.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.common.document.DocumentField; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class CollapseResponseProcessorTests extends OpenSearchTestCase { + public void testWithDocumentFields() { + testProcessor(true); + } + + public void testWithSourceField() { + testProcessor(false); + } + + private void testProcessor(boolean includeDocField) { + Map config = new HashMap<>(Map.of(CollapseResponseProcessor.COLLAPSE_FIELD, "groupid")); + CollapseResponseProcessor processor = new CollapseResponseProcessor.Factory().create( + Collections.emptyMap(), + null, + null, + false, + config, + null + ); + int numHits = randomIntBetween(1, 100); + SearchResponse inputResponse = generateResponse(numHits, includeDocField); + + SearchResponse processedResponse = processor.processResponse(new SearchRequest(), inputResponse); + if (numHits % 2 == 0) { + assertEquals(numHits / 2, processedResponse.getHits().getHits().length); + } else { + assertEquals(numHits / 2 + 1, processedResponse.getHits().getHits().length); + } + for (SearchHit collapsedHit : processedResponse.getHits()) { + assertEquals(0, collapsedHit.docId() % 2); + } + assertEquals("groupid", processedResponse.getHits().getCollapseField()); + assertEquals(processedResponse.getHits().getHits().length, processedResponse.getHits().getCollapseValues().length); + for (int i = 0; i < processedResponse.getHits().getHits().length; i++) { + assertEquals(i, processedResponse.getHits().getCollapseValues()[i]); + } + } + + private static SearchResponse generateResponse(int numHits, boolean includeDocField) { + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + Map docFields; + int groupValue = i / 2; + if (includeDocField) { + docFields = Map.of("groupid", new DocumentField("groupid", List.of(groupValue))); + } else { + docFields = Collections.emptyMap(); + } + SearchHit hit = new SearchHit(i, Integer.toString(i), docFields, Collections.emptyMap()); + hit.sourceRef(new BytesArray("{\"groupid\": " + groupValue + "}")); + hitsArray[i] = hit; + } + SearchHits searchHits = new SearchHits( + hitsArray, + new TotalHits(Math.max(numHits, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + 1.0f + ); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, false, 0); + return new SearchResponse(internalSearchResponse, null, 1, 1, 0, 10, null, null); + } +} diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java new file mode 100644 index 0000000000000..96e99dff9cc03 --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java @@ -0,0 +1,62 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class OversampleRequestProcessorTests extends OpenSearchTestCase { + + public void testEmptySource() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0)); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchRequest request = new SearchRequest(); + PipelineProcessingContext context = new PipelineProcessingContext(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(request, transformedRequest); + assertNull(context.getAttribute("original_size")); + } + + public void testBasicBehavior() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0)); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10); + SearchRequest request = new SearchRequest().source(sourceBuilder); + PipelineProcessingContext context = new PipelineProcessingContext(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(30, transformedRequest.source().size()); + assertEquals(10, context.getAttribute("original_size")); + } + + public void testContextPrefix() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>( + Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0, ContextUtils.CONTEXT_PREFIX_PARAMETER, "foo") + ); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10); + SearchRequest request = new SearchRequest().source(sourceBuilder); + PipelineProcessingContext context = new PipelineProcessingContext(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(30, transformedRequest.source().size()); + assertEquals(10, context.getAttribute("foo.original_size")); + } +} diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java index fde9757312e30..b372b220b71ac 100644 --- a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java @@ -18,7 +18,7 @@ import org.opensearch.script.ScriptType; import org.opensearch.script.SearchScript; import org.opensearch.search.builder.SearchSourceBuilder; -import org.opensearch.search.pipeline.common.helpers.SearchRequestMap; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.test.OpenSearchTestCase; import org.junit.Before; @@ -27,8 +27,6 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.hamcrest.core.Is.is; - public class ScriptRequestProcessorTests extends OpenSearchTestCase { private ScriptService scriptService; @@ -87,7 +85,7 @@ public void testScriptingWithoutPrecompiledScriptFactory() throws Exception { searchRequest.source(createSearchSourceBuilder()); assertNotNull(searchRequest); - processor.processRequest(searchRequest); + processor.processRequest(searchRequest, new PipelineProcessingContext()); assertSearchRequest(searchRequest); } @@ -104,7 +102,7 @@ public void testScriptingWithPrecompiledIngestScript() throws Exception { searchRequest.source(createSearchSourceBuilder()); assertNotNull(searchRequest); - processor.processRequest(searchRequest); + processor.processRequest(searchRequest, new PipelineProcessingContext()); assertSearchRequest(searchRequest); } @@ -124,15 +122,15 @@ private SearchSourceBuilder createSearchSourceBuilder() { } private void assertSearchRequest(SearchRequest searchRequest) { - assertThat(searchRequest.source().from(), is(20)); - assertThat(searchRequest.source().size(), is(30)); - assertThat(searchRequest.source().explain(), is(false)); - assertThat(searchRequest.source().version(), is(false)); - assertThat(searchRequest.source().seqNoAndPrimaryTerm(), is(false)); - assertThat(searchRequest.source().trackScores(), is(false)); - assertThat(searchRequest.source().trackTotalHitsUpTo(), is(4)); - assertThat(searchRequest.source().minScore(), is(2.0f)); - assertThat(searchRequest.source().timeout(), is(new TimeValue(60, TimeUnit.SECONDS))); - assertThat(searchRequest.source().terminateAfter(), is(6)); + assertEquals(20, searchRequest.source().from()); + assertEquals(30, searchRequest.source().size()); + assertFalse(searchRequest.source().explain()); + assertFalse(searchRequest.source().version()); + assertFalse(searchRequest.source().seqNoAndPrimaryTerm()); + assertFalse(searchRequest.source().trackScores()); + assertEquals(4, searchRequest.source().trackTotalHitsUpTo().intValue()); + assertEquals(2.0f, searchRequest.source().minScore(), 0.0001); + assertEquals(new TimeValue(60, TimeUnit.SECONDS), searchRequest.source().timeout()); + assertEquals(6, searchRequest.source().terminateAfter()); } } diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/SearchRequestMapTests.java similarity index 99% rename from modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapTests.java rename to modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/SearchRequestMapTests.java index 5572f28335e1c..c982ada7b5ea5 100644 --- a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/helpers/SearchRequestMapTests.java +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/SearchRequestMapTests.java @@ -5,7 +5,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. */ -package org.opensearch.search.pipeline.common.helpers; +package org.opensearch.search.pipeline.common; import org.opensearch.action.search.SearchRequest; import org.opensearch.search.builder.SearchSourceBuilder; diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java new file mode 100644 index 0000000000000..7615225c7f77e --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java @@ -0,0 +1,91 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class TruncateHitsResponseProcessorTests extends OpenSearchTestCase { + + public void testBasicBehavior() { + int targetSize = randomInt(50); + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + Map config = new HashMap<>(Map.of(TruncateHitsResponseProcessor.TARGET_SIZE, targetSize)); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + SearchResponse transformedResponse = processor.processResponse(new SearchRequest(), response, new PipelineProcessingContext()); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizePassedViaContext() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null); + + int targetSize = randomInt(50); + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + PipelineProcessingContext requestContext = new PipelineProcessingContext(); + requestContext.setAttribute("original_size", targetSize); + SearchResponse transformedResponse = processor.processResponse(new SearchRequest(), response, requestContext); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizePassedViaContextWithPrefix() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + Map config = new HashMap<>(Map.of(ContextUtils.CONTEXT_PREFIX_PARAMETER, "foo")); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + int targetSize = randomInt(50); + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + PipelineProcessingContext requestContext = new PipelineProcessingContext(); + requestContext.setAttribute("foo.original_size", targetSize); + SearchResponse transformedResponse = processor.processResponse(new SearchRequest(), response, requestContext); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizeMissing() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null); + + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + assertThrows( + IllegalStateException.class, + () -> processor.processResponse(new SearchRequest(), response, new PipelineProcessingContext()) + ); + } + + private static SearchResponse constructResponse(int numHits) { + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + hitsArray[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); + } + SearchHits searchHits = new SearchHits( + hitsArray, + new TotalHits(Math.max(numHits, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + 1.0f + ); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, false, 0); + return new SearchResponse(internalSearchResponse, null, 1, 1, 0, 10, null, null); + } +} diff --git a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/60_oversample_truncate.yml b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/60_oversample_truncate.yml new file mode 100644 index 0000000000000..1f9e95084322d --- /dev/null +++ b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/60_oversample_truncate.yml @@ -0,0 +1,105 @@ +--- +teardown: + - do: + search_pipeline.delete: + id: "my_pipeline" + ignore: 404 + +--- +"Test state propagating from oversample to truncate_hits processor": + - do: + search_pipeline.put: + id: "my_pipeline" + body: > + { + "description": "_description", + "request_processors": [ + { + "oversample" : { + "sample_factor" : 2 + } + } + ], + "response_processors": [ + { + "collapse" : { + "field" : "group_id" + } + }, + { + "truncate_hits" : {} + } + ] + } + - match: { acknowledged: true } + + - do: + index: + index: test + id: 1 + body: { + "group_id": "a", + "popularity" : 1 + } + - do: + index: + index: test + id: 2 + body: { + "group_id": "a", + "popularity" : 2 + } + - do: + index: + index: test + id: 3 + body: { + "group_id": "b", + "popularity" : 3 + } + - do: + index: + index: test + id: 4 + body: { + "group_id": "b", + "popularity" : 4 + } + - do: + indices.refresh: + index: test + + - do: + search: + body: { + "query" : { + "function_score" : { + "field_value_factor" : { + "field" : "popularity" + } + } + }, + "size" : 2 + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.1._id: "3" } + + - do: + search: + search_pipeline: my_pipeline + body: { + "query" : { + "function_score" : { + "field_value_factor" : { + "field" : "popularity" + } + } + }, + "size" : 2 + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 2 } + - match: { hits.hits.0._id: "4" } + - match: { hits.hits.1._id: "2" } diff --git a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml new file mode 100644 index 0000000000000..9c9f6747e9bdc --- /dev/null +++ b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml @@ -0,0 +1,70 @@ +--- +teardown: + - do: + search_pipeline.delete: + id: "my_pipeline" + ignore: 404 + +--- +"Test state propagating from script request to truncate_hits processor": + - do: + search_pipeline.put: + id: "my_pipeline" + body: > + { + "description": "_description", + "request_processors": [ + { + "script" : { + "source" : "ctx.request_context['foo.original_size'] = 2" + } + } + ], + "response_processors": [ + { + "truncate_hits" : { + "context_prefix" : "foo" + } + } + ] + } + - match: { acknowledged: true } + + - do: + index: + index: test + id: 1 + body: {} + - do: + index: + index: test + id: 2 + body: {} + - do: + index: + index: test + id: 3 + body: {} + - do: + index: + index: test + id: 4 + body: {} + - do: + indices.refresh: + index: test + + - do: + search: + body: { + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 4 } + + - do: + search: + search_pipeline: my_pipeline + body: { + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 2 } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java index 8bab961423f91..c88dfb2060393 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Pipeline.java @@ -119,7 +119,8 @@ protected void afterResponseProcessor(Processor processor, long timeInNanos) {} protected void onResponseProcessorFailed(Processor processor) {} - void transformRequest(SearchRequest request, ActionListener requestListener) throws SearchPipelineProcessingException { + void transformRequest(SearchRequest request, ActionListener requestListener, PipelineProcessingContext requestContext) + throws SearchPipelineProcessingException { if (searchRequestProcessors.isEmpty()) { requestListener.onResponse(request); return; @@ -137,7 +138,7 @@ void transformRequest(SearchRequest request, ActionListener reque return; } - ActionListener finalListener = getTerminalSearchRequestActionListener(requestListener); + ActionListener finalListener = getTerminalSearchRequestActionListener(requestListener, requestContext); // Chain listeners back-to-front ActionListener currentListener = finalListener; @@ -147,7 +148,7 @@ void transformRequest(SearchRequest request, ActionListener reque currentListener = ActionListener.wrap(r -> { long start = relativeTimeSupplier.getAsLong(); beforeRequestProcessor(processor); - processor.processRequestAsync(r, ActionListener.wrap(rr -> { + processor.processRequestAsync(r, requestContext, ActionListener.wrap(rr -> { long took = TimeUnit.NANOSECONDS.toMillis(relativeTimeSupplier.getAsLong() - start); afterRequestProcessor(processor, took); nextListener.onResponse(rr); @@ -176,13 +177,16 @@ void transformRequest(SearchRequest request, ActionListener reque currentListener.onResponse(request); } - private ActionListener getTerminalSearchRequestActionListener(ActionListener requestListener) { + private ActionListener getTerminalSearchRequestActionListener( + ActionListener requestListener, + PipelineProcessingContext requestContext + ) { final long pipelineStart = relativeTimeSupplier.getAsLong(); return ActionListener.wrap(r -> { long took = TimeUnit.NANOSECONDS.toMillis(relativeTimeSupplier.getAsLong() - pipelineStart); afterTransformRequest(took); - requestListener.onResponse(new PipelinedRequest(this, r)); + requestListener.onResponse(new PipelinedRequest(this, r, requestContext)); }, e -> { long took = TimeUnit.NANOSECONDS.toMillis(relativeTimeSupplier.getAsLong() - pipelineStart); afterTransformRequest(took); @@ -191,7 +195,11 @@ private ActionListener getTerminalSearchRequestActionListener(Act }); } - ActionListener transformResponseListener(SearchRequest request, ActionListener responseListener) { + ActionListener transformResponseListener( + SearchRequest request, + ActionListener responseListener, + PipelineProcessingContext requestContext + ) { if (searchResponseProcessors.isEmpty()) { // No response transformation necessary return responseListener; @@ -219,7 +227,7 @@ ActionListener transformResponseListener(SearchRequest request, responseListener = ActionListener.wrap(r -> { beforeResponseProcessor(processor); final long start = relativeTimeSupplier.getAsLong(); - processor.processResponseAsync(request, r, ActionListener.wrap(rr -> { + processor.processResponseAsync(request, r, requestContext, ActionListener.wrap(rr -> { long took = TimeUnit.NANOSECONDS.toMillis(relativeTimeSupplier.getAsLong() - start); afterResponseProcessor(processor, took); currentFinalListener.onResponse(rr); @@ -257,14 +265,15 @@ void runSearchPhaseResultsTransformer( SearchPhaseResults searchPhaseResult, SearchPhaseContext context, String currentPhase, - String nextPhase + String nextPhase, + PipelineProcessingContext requestContext ) throws SearchPipelineProcessingException { try { for (SearchPhaseResultsProcessor searchPhaseResultsProcessor : searchPhaseResultsProcessors) { if (currentPhase.equals(searchPhaseResultsProcessor.getBeforePhase().getName()) && nextPhase.equals(searchPhaseResultsProcessor.getAfterPhase().getName())) { try { - searchPhaseResultsProcessor.process(searchPhaseResult, context); + searchPhaseResultsProcessor.process(searchPhaseResult, context, requestContext); } catch (Exception e) { if (searchPhaseResultsProcessor.isIgnoreFailure()) { logger.warn( diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java b/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java new file mode 100644 index 0000000000000..a1f2b8b99d958 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelineProcessingContext.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline; + +import java.util.HashMap; +import java.util.Map; + +/** + * A holder for state that is passed through each processor in the pipeline. + */ +public class PipelineProcessingContext { + private final Map attributes = new HashMap<>(); + + /** + * Set a generic attribute in the state for this request. Overwrites any existing value. + * + * @param name the name of the attribute to set + * @param value the value to set on the attributen + */ + public void setAttribute(String name, Object value) { + attributes.put(name, value); + } + + /** + * Retrieves a generic attribute value from the state for this request. + * @param name the name of the attribute + * @return the value of the attribute if previously set (and null otherwise) + */ + public Object getAttribute(String name) { + return attributes.get(name); + } +} diff --git a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java index 77dfc6bcd4fc5..d550fbb768133 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java +++ b/server/src/main/java/org/opensearch/search/pipeline/PipelinedRequest.java @@ -22,18 +22,20 @@ */ public final class PipelinedRequest extends SearchRequest { private final Pipeline pipeline; + private final PipelineProcessingContext requestContext; - PipelinedRequest(Pipeline pipeline, SearchRequest transformedRequest) { + PipelinedRequest(Pipeline pipeline, SearchRequest transformedRequest, PipelineProcessingContext requestContext) { super(transformedRequest); this.pipeline = pipeline; + this.requestContext = requestContext; } public void transformRequest(ActionListener requestListener) { - pipeline.transformRequest(this, requestListener); + pipeline.transformRequest(this, requestListener, requestContext); } public ActionListener transformResponseListener(ActionListener responseListener) { - return pipeline.transformResponseListener(this, responseListener); + return pipeline.transformResponseListener(this, responseListener, requestContext); } public void transformSearchPhaseResults( @@ -42,7 +44,7 @@ public void transformSearchPhaseResults( final String currentPhase, final String nextPhase ) { - pipeline.runSearchPhaseResultsTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase); + pipeline.runSearchPhaseResultsTransformer(searchPhaseResult, searchPhaseContext, currentPhase, nextPhase, requestContext); } // Visible for testing diff --git a/server/src/main/java/org/opensearch/search/pipeline/Processor.java b/server/src/main/java/org/opensearch/search/pipeline/Processor.java index 0120d68ceb5aa..a06383fbe9cef 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Processor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Processor.java @@ -21,13 +21,6 @@ * @opensearch.internal */ public interface Processor { - /** - * Processor configuration key to let the factory know the context for pipeline creation. - *

- * See {@link PipelineSource}. - */ - String PIPELINE_SOURCE = "pipeline_source"; - /** * Gets the type of processor */ diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java index 772dc8758bace..a64266cfb2a2b 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPhaseResultsProcessor.java @@ -32,6 +32,22 @@ void process( final SearchPhaseContext searchPhaseContext ); + /** + * Processes the {@link SearchPhaseResults} obtained from a SearchPhase which will be returned to next + * SearchPhase. Receives the {@link PipelineProcessingContext} passed to other processors. + * @param searchPhaseResult {@link SearchPhaseResults} + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} + * @param {@link SearchPhaseResult} + */ + default void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + process(searchPhaseResult, searchPhaseContext); + } + /** * The phase which should have run before, this processor can start executing. * @return {@link SearchPhaseName} diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 580fe1b7c4216..2175b5d135394 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -408,7 +408,8 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest) { pipeline = pipelineHolder.pipeline; } } - return new PipelinedRequest(pipeline, searchRequest); + PipelineProcessingContext requestContext = new PipelineProcessingContext(); + return new PipelinedRequest(pipeline, searchRequest, requestContext); } Map> getRequestProcessorFactories() { diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchRequestProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchRequestProcessor.java index 427c9e4ab694c..30adc9b0afbe8 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchRequestProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchRequestProcessor.java @@ -15,18 +15,27 @@ * Interface for a search pipeline processor that modifies a search request. */ public interface SearchRequestProcessor extends Processor { - /** - * Transform a {@link SearchRequest}. Executed on the coordinator node before any {@link org.opensearch.action.search.SearchPhase} - * executes. - *

+ * Process a SearchRequest without receiving request-scoped state. * Implement this method if the processor makes no asynchronous calls. - * @param request the executed {@link SearchRequest} - * @return a new {@link SearchRequest} (or the input {@link SearchRequest} if no changes) - * @throws Exception if an error occurs during processing + * @param request the search request (which may have been modified by an earlier processor) + * @return the modified search request + * @throws Exception implementation-specific processing exception */ SearchRequest processRequest(SearchRequest request) throws Exception; + /** + * Process a SearchRequest, with request-scoped state shared across processors in the pipeline + * Implement this method if the processor makes no asynchronous calls. + * @param request the search request (which may have been modified by an earlier processor) + * @param requestContext request-scoped state shared across processors in the pipeline + * @return the modified search request + * @throws Exception implementation-specific processing exception + */ + default SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) throws Exception { + return processRequest(request); + } + /** * Transform a {@link SearchRequest}. Executed on the coordinator node before any {@link org.opensearch.action.search.SearchPhase} * executes. @@ -35,9 +44,13 @@ public interface SearchRequestProcessor extends Processor { * @param request the executed {@link SearchRequest} * @param requestListener callback to be invoked on successful processing or on failure */ - default void processRequestAsync(SearchRequest request, ActionListener requestListener) { + default void processRequestAsync( + SearchRequest request, + PipelineProcessingContext requestContext, + ActionListener requestListener + ) { try { - requestListener.onResponse(processRequest(request)); + requestListener.onResponse(processRequest(request, requestContext)); } catch (Exception e) { requestListener.onFailure(e); } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchResponseProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/SearchResponseProcessor.java index 21136ce208fee..98591ab9d0def 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchResponseProcessor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchResponseProcessor.java @@ -21,24 +21,47 @@ public interface SearchResponseProcessor extends Processor { * Transform a {@link SearchResponse}, possibly based on the executed {@link SearchRequest}. *

* Implement this method if the processor makes no asynchronous calls. - * @param request the executed {@link SearchRequest} + * + * @param request the executed {@link SearchRequest} * @param response the current {@link SearchResponse}, possibly modified by earlier processors * @return a modified {@link SearchResponse} (or the input {@link SearchResponse} if no changes) * @throws Exception if an error occurs during processing */ SearchResponse processResponse(SearchRequest request, SearchResponse response) throws Exception; + /** + * Process a SearchResponse, with request-scoped state shared across processors in the pipeline + *

+ * Implement this method if the processor makes no asynchronous calls. + * + * @param request the (maybe transformed) search request + * @param response the search response (which may have been modified by an earlier processor) + * @param requestContext request-scoped state shared across processors in the pipeline + * @return the modified search response + * @throws Exception implementation-specific processing exception + */ + default SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) + throws Exception { + return processResponse(request, response); + } + /** * Transform a {@link SearchResponse}, possibly based on the executed {@link SearchRequest}. *

* Expert method: Implement this if the processor needs to make asynchronous calls. Otherwise, implement processResponse. - * @param request the executed {@link SearchRequest} - * @param response the current {@link SearchResponse}, possibly modified by earlier processors + * + * @param request the executed {@link SearchRequest} + * @param response the current {@link SearchResponse}, possibly modified by earlier processors * @param responseListener callback to be invoked on successful processing or on failure */ - default void processResponseAsync(SearchRequest request, SearchResponse response, ActionListener responseListener) { + default void processResponseAsync( + SearchRequest request, + SearchResponse response, + PipelineProcessingContext requestContext, + ActionListener responseListener + ) { try { - responseListener.onResponse(processResponse(request, response)); + responseListener.onResponse(processResponse(request, response, requestContext)); } catch (Exception e) { responseListener.onFailure(e); } diff --git a/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchRequestProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchRequestProcessor.java new file mode 100644 index 0000000000000..67e1c1147cb87 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchRequestProcessor.java @@ -0,0 +1,25 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline; + +import org.opensearch.action.search.SearchRequest; + +/** + * A specialization of {@link SearchRequestProcessor} that makes use of the request-scoped processor state. + * Implementors must implement the processRequest method that accepts request-scoped processor state. + */ +public interface StatefulSearchRequestProcessor extends SearchRequestProcessor { + @Override + default SearchRequest processRequest(SearchRequest request) { + throw new UnsupportedOperationException(); + } + + @Override + SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) throws Exception; +} diff --git a/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchResponseProcessor.java b/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchResponseProcessor.java new file mode 100644 index 0000000000000..f0842d24e1b56 --- /dev/null +++ b/server/src/main/java/org/opensearch/search/pipeline/StatefulSearchResponseProcessor.java @@ -0,0 +1,27 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; + +/** + * A specialization of {@link SearchResponseProcessor} that makes use of the request-scoped processor state. + * Implementors must implement the processResponse method that accepts request-scoped processor state. + */ +public interface StatefulSearchResponseProcessor extends SearchResponseProcessor { + @Override + default SearchResponse processResponse(SearchRequest request, SearchResponse response) { + throw new UnsupportedOperationException(); + } + + @Override + SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) + throws Exception; +} diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 98d2a7e84d672..f5851e669a2da 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -41,6 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; @@ -68,6 +69,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import static org.mockito.ArgumentMatchers.anyString; @@ -1378,4 +1380,92 @@ public void testExtraParameterInProcessorConfig() { fail("Wrong exception type: " + e.getClass()); } } + + private static class FakeStatefulRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { + private final String type; + private final Consumer stateConsumer; + + public FakeStatefulRequestProcessor(String type, Consumer stateConsumer) { + super(null, null, false); + this.type = type; + this.stateConsumer = stateConsumer; + } + + @Override + public String getType() { + return type; + } + + @Override + public SearchRequest processRequest(SearchRequest request, PipelineProcessingContext requestContext) throws Exception { + stateConsumer.accept(requestContext); + return request; + } + } + + private static class FakeStatefulResponseProcessor extends AbstractProcessor implements StatefulSearchResponseProcessor { + private final String type; + private final Consumer stateConsumer; + + public FakeStatefulResponseProcessor(String type, Consumer stateConsumer) { + super(null, null, false); + this.type = type; + this.stateConsumer = stateConsumer; + } + + @Override + public String getType() { + return type; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response, PipelineProcessingContext requestContext) + throws Exception { + stateConsumer.accept(requestContext); + return response; + } + } + + public void testStatefulProcessors() throws Exception { + AtomicReference contextHolder = new AtomicReference<>(); + SearchPipelineService searchPipelineService = createWithProcessors( + Map.of( + "write_context", + (pf, t, d, igf, cfg, ctx) -> new FakeStatefulRequestProcessor("write_context", (c) -> c.setAttribute("a", "b")) + ), + Map.of( + "read_context", + (pf, t, d, igf, cfg, ctx) -> new FakeStatefulResponseProcessor( + "read_context", + (c) -> contextHolder.set((String) c.getAttribute("a")) + ) + ), + Collections.emptyMap() + ); + + SearchPipelineMetadata metadata = new SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray( + "{\"request_processors\" : [ { \"write_context\": {} } ], \"response_processors\": [ { \"read_context\": {} }] }" + ), + XContentType.JSON + ) + ) + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + + PipelinedRequest request = searchPipelineService.resolvePipeline(new SearchRequest().pipeline("p1")); + assertNull(contextHolder.get()); + syncExecutePipeline(request, new SearchResponse(null, null, 0, 0, 0, 0, null, null)); + assertNotNull(contextHolder.get()); + assertEquals("b", contextHolder.get()); + } }