From 524dbd003e908cc7a4a81d6bcc2387e2d398cd48 Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Mon, 9 Sep 2024 16:27:15 -0700 Subject: [PATCH] Fixing filtered vector search when quantization is applied (#2076) Signed-off-by: Navneet Verma --- .../knn/index/query/ExactSearcher.java | 86 +++++++++++++------ .../opensearch/knn/index/query/KNNWeight.java | 62 ++++++------- .../query/SegmentLevelQuantizationInfo.java | 46 ++++++++++ .../query/SegmentLevelQuantizationUtil.java | 60 +++++++++++++ .../filtered/FilteredIdsKNNIterator.java | 30 ++++++- .../NestedFilteredIdsKNNIterator.java | 17 +++- .../nativelib/NativeEngineKnnVectorQuery.java | 11 ++- .../knn/index/query/KNNWeightTests.java | 5 +- .../NativeEngineKNNVectorQueryTests.java | 6 +- 9 files changed, 249 insertions(+), 74 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java create mode 100644 src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java diff --git a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java index 249c66d03..5b6029766 100644 --- a/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java +++ b/src/main/java/org/opensearch/knn/index/query/ExactSearcher.java @@ -6,6 +6,8 @@ package org.opensearch.knn.index.query; import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Value; import lombok.extern.log4j.Log4j2; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReaderContext; @@ -42,27 +44,18 @@ public class ExactSearcher { /** * Execute an exact search on a subset of documents of a leaf * - * @param leafReaderContext LeafReaderContext to be searched over - * @param matchedDocs matched documents - * @param knnQuery KNN Query - * @param k number of results to return - * @param isParentHits whether the matchedDocs contains parent ids or child ids. 
This is relevant in the case of - * filtered nested search where the matchedDocs contain the parent ids and {@link NestedFilteredIdsKNNIterator} - * needs to be used. + * @param leafReaderContext {@link LeafReaderContext} + * @param exactSearcherContext {@link ExactSearcherContext} * @return Map of re-scored results + * @throws IOException exception during execution of exact search */ - public Map searchLeaf( - final LeafReaderContext leafReaderContext, - final BitSet matchedDocs, - final KNNQuery knnQuery, - int k, - boolean isParentHits - ) throws IOException { - KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, matchedDocs, knnQuery, isParentHits); - if (matchedDocs.cardinality() <= k) { + public Map searchLeaf(final LeafReaderContext leafReaderContext, final ExactSearcherContext exactSearcherContext) + throws IOException { + KNNIterator iterator = getMatchedKNNIterator(leafReaderContext, exactSearcherContext); + if (exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) { return scoreAllDocs(iterator); } - return searchTopK(iterator, k); + return searchTopK(iterator, exactSearcherContext.getK()); } private Map scoreAllDocs(KNNIterator iterator) throws IOException { @@ -105,17 +98,15 @@ private Map searchTopK(KNNIterator iterator, int k) throws IOExc return docToScore; } - private KNNIterator getMatchedKNNIterator( - final LeafReaderContext leafReaderContext, - final BitSet matchedDocs, - KNNQuery knnQuery, - boolean isParentHits - ) throws IOException { + private KNNIterator getMatchedKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) + throws IOException { + final KNNQuery knnQuery = exactSearcherContext.getKnnQuery(); + final BitSet matchedDocs = exactSearcherContext.getMatchedDocs(); final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader()); final FieldInfo fieldInfo = reader.getFieldInfos().fieldInfo(knnQuery.getField()); final SpaceType spaceType = 
FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); - boolean isNestedRequired = isParentHits && knnQuery.getParentsFilter() != null; + boolean isNestedRequired = exactSearcherContext.isParentHits() && knnQuery.getParentsFilter() != null; if (VectorDataType.BINARY == knnQuery.getVectorDataType() && isNestedRequired) { final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); @@ -137,6 +128,17 @@ private KNNIterator getMatchedKNNIterator( spaceType ); } + final byte[] quantizedQueryVector; + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; + if (exactSearcherContext.isUseQuantizedVectorsForSearch()) { + // Build Segment Level Quantization info. + segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build(reader, fieldInfo, knnQuery.getField()); + // Quantize the Query Vector Once. + quantizedQueryVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); + } else { + segmentLevelQuantizationInfo = null; + quantizedQueryVector = null; + } final KNNVectorValues vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader); if (isNestedRequired) { @@ -145,10 +147,42 @@ private KNNIterator getMatchedKNNIterator( knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType, - knnQuery.getParentsFilter().getBitSet(leafReaderContext) + knnQuery.getParentsFilter().getBitSet(leafReaderContext), + quantizedQueryVector, + segmentLevelQuantizationInfo ); } - return new FilteredIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNFloatVectorValues) vectorValues, spaceType); + return new FilteredIdsKNNIterator( + matchedDocs, + knnQuery.getQueryVector(), + (KNNFloatVectorValues) vectorValues, + spaceType, + quantizedQueryVector, + segmentLevelQuantizationInfo + ); + } + + /** + * Stores the context that is used to do the exact search. This class will help in reducing the explosion of attributes + * for doing exact search. 
+ */ + @Value + @Builder + public static class ExactSearcherContext { + /** + * controls whether we should use Quantized vectors during exact search or not. This is useful because when we do + * re-scoring we need to re-score using full precision vectors and not quantized vectors. + */ + boolean useQuantizedVectorsForSearch; + int k; + BitSet matchedDocs; + KNNQuery knnQuery; + /** + * whether the matchedDocs contains parent ids or child ids. This is relevant in the case of + * filtered nested search where the matchedDocs contain the parent ids and {@link NestedFilteredIdsKNNIterator} + * needs to be used. + */ + boolean isParentHits; } } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 1769328fe..b1ba9de59 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -27,7 +27,6 @@ import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; -import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; import org.opensearch.knn.index.memory.NativeMemoryAllocation; import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; @@ -39,8 +38,6 @@ import org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; -import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; -import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import java.io.IOException; import java.nio.file.Path; @@ -140,8 +137,17 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I * This improves the recall. 
*/ Map docIdsToScoreMap; + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + .k(k) + .isParentHits(true) + .matchedDocs(filterBitSet) + // setting to true, so that if quantization details are present we want to do search on the quantized + // vectors as this flow is used in first pass of search. + .useQuantizedVectorsForSearch(true) + .knnQuery(knnQuery) + .build(); if (filterWeight != null && canDoExactSearch(cardinality)) { - docIdsToScoreMap = exactSearch(context, filterBitSet, true, k); + docIdsToScoreMap = exactSearch(context, exactSearcherContext); } else { docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k); if (docIdsToScoreMap == null) { @@ -155,7 +161,7 @@ public Map searchLeaf(LeafReaderContext context, int k) throws I docIdsToScoreMap.size(), cardinality ); - docIdsToScoreMap = exactSearch(context, filterBitSet, true, k); + docIdsToScoreMap = exactSearch(context, exactSearcherContext); } } if (docIdsToScoreMap.isEmpty()) { @@ -258,10 +264,13 @@ private Map doANNSearch( ); } - QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); - + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo = SegmentLevelQuantizationInfo.build( + reader, + fieldInfo, + knnQuery.getField() + ); // TODO: Change type of vector once more quantization methods are supported - byte[] quantizedVector = getQuantizedVector(quantizationParams, reader, fieldInfo); + final byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(knnQuery.getQueryVector(), segmentLevelQuantizationInfo); List engineFiles = getEngineFiles(reader, knnEngine.getExtension()); if (engineFiles.isEmpty()) { @@ -285,7 +294,7 @@ private Map doANNSearch( knnEngine, knnQuery.getIndexName(), // TODO: In the future, more vector data types will be supported with quantization - quantizationParams == null ? vectorDataType : VectorDataType.BINARY + quantizedVector == null ? 
vectorDataType : VectorDataType.BINARY ), knnQuery.getIndexName(), modelId @@ -310,11 +319,11 @@ private Map doANNSearch( int[] parentIds = getParentIdsArray(context); if (k > 0) { if (knnQuery.getVectorDataType() == VectorDataType.BINARY - || quantizationParams != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) { + || quantizedVector != null && quantizationService.getVectorDataTypeForTransfer(fieldInfo) == VectorDataType.BINARY) { results = JNIService.queryBinaryIndex( indexAllocation.getMemoryAddress(), // TODO: In the future, quantizedVector can have other data types than byte - quantizationParams == null ? knnQuery.getByteQueryVector() : quantizedVector, + quantizedVector == null ? knnQuery.getByteQueryVector() : quantizedVector, k, knnQuery.getMethodParameters(), knnEngine, @@ -391,16 +400,14 @@ List getEngineFiles(SegmentReader reader, String extension) throws IOExc /** * Execute exact search for the given matched doc ids and return the results as a map of docId to score. * - * @param leafReaderContext The leaf reader context for the current segment. - * @param matchSet The filterIds to search for. - * @param isParentHits Whether the matchedDocs contains parent ids or child ids. - * @param k The number of results to return. * @return Map of docId to score for the exact search results. * @throws IOException If an error occurs during the search. 
*/ - public Map exactSearch(final LeafReaderContext leafReaderContext, final BitSet matchSet, boolean isParentHits, int k) - throws IOException { - return exactSearcher.searchLeaf(leafReaderContext, matchSet, knnQuery, k, isParentHits); + public Map exactSearch( + final LeafReaderContext leafReaderContext, + final ExactSearcher.ExactSearcherContext exactSearcherContext + ) throws IOException { + return exactSearcher.searchLeaf(leafReaderContext, exactSearcherContext); } @Override @@ -462,23 +469,4 @@ private boolean isExactSearchThresholdSettingSet(int filterThresholdValue) { private boolean canDoExactSearchAfterANNSearch(final int filterIdsCount, final int annResultCount) { return filterWeight != null && filterIdsCount >= knnQuery.getK() && knnQuery.getK() > annResultCount; } - - // TODO: this will eventually return more types than just byte - private byte[] getQuantizedVector(QuantizationParams quantizationParams, SegmentReader reader, FieldInfo fieldInfo) throws IOException { - if (quantizationParams != null) { - QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); - reader.searchNearestVectors(knnQuery.getField(), new float[0], tempCollector, null); - if (tempCollector.getQuantizationState() == null) { - throw new IllegalStateException(String.format("No quantization state found for field %s", fieldInfo.getName())); - } - QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(quantizationParams); - // TODO: In the future, byte array will not be the only output type from this method - return (byte[]) quantizationService.quantize( - tempCollector.getQuantizationState(), - knnQuery.getQueryVector(), - quantizationOutput - ); - } - return null; - } } diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java new file mode 100644 index 000000000..d25774cdc --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java @@ -0,0 +1,46 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.AccessLevel; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.LeafReader; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; + +/** + * This class encapsulates the necessary details to do the quantization of the vectors present in a Lucene segment. + */ +@Getter +@RequiredArgsConstructor(access = AccessLevel.PRIVATE) +public class SegmentLevelQuantizationInfo { + private final QuantizationParams quantizationParams; + private final QuantizationState quantizationState; + + /** + * A builder-like function to build the {@link SegmentLevelQuantizationInfo} + * @param leafReader {@link LeafReader} + * @param fieldInfo {@link FieldInfo} + * @param fieldName {@link String} + * @return {@link SegmentLevelQuantizationInfo}, or {@code null} when the field has no quantization params + * @throws IOException exception while creating the {@link SegmentLevelQuantizationInfo} object. 
+ */ + public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName) + throws IOException { + final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); + if (quantizationParams == null) { + return null; + } + final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName); + return new SegmentLevelQuantizationInfo(quantizationParams, quantizationState); + } + +} diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java new file mode 100644 index 000000000..46db8bb6b --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationUtil.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.query; + +import lombok.experimental.UtilityClass; +import org.apache.lucene.index.LeafReader; +import org.opensearch.knn.index.codec.KNN990Codec.QuantizationConfigKNNCollector; +import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.Locale; + +/** + * A utility class for quantization-related operations at the segment level. These helpers could live in {@link SegmentLevelQuantizationInfo}, + * but they are kept separate so that {@link SegmentLevelQuantizationInfo} stays free of utility functions, which reduces + * the responsibilities of the {@link SegmentLevelQuantizationInfo} class. + */ +@UtilityClass +public class SegmentLevelQuantizationUtil { + + /** + * A simple function to convert a vector to a quantized vector for a segment. 
+ * @param vector array of float + * @return array of byte + */ + @SuppressWarnings("unchecked") + public static byte[] quantizeVector(final float[] vector, final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo) { + if (segmentLevelQuantizationInfo == null) { + return null; + } + final QuantizationService quantizationService = QuantizationService.getInstance(); + // TODO: We are converting the output of Quantize to byte array for now. But this needs to be fixed when + // other types of quantized outputs are returned like float[]. + return (byte[]) quantizationService.quantize( + segmentLevelQuantizationInfo.getQuantizationState(), + vector, + quantizationService.createQuantizationOutput(segmentLevelQuantizationInfo.getQuantizationParams()) + ); + } + + /** + * A utility function to get {@link QuantizationState} for a given segment and field. + * @param leafReader {@link LeafReader} + * @param fieldName {@link String} + * @return {@link QuantizationState} + * @throws IOException exception during reading the {@link QuantizationState} + */ + static QuantizationState getQuantizationState(final LeafReader leafReader, String fieldName) throws IOException { + final QuantizationConfigKNNCollector tempCollector = new QuantizationConfigKNNCollector(); + leafReader.searchNearestVectors(fieldName, new float[0], tempCollector, null); + if (tempCollector.getQuantizationState() == null) { + throw new IllegalStateException(String.format(Locale.ROOT, "No quantization state found for field %s", fieldName)); + } + return tempCollector.getQuantizationState(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java index a0d7694c9..56d291470 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/FilteredIdsKNNIterator.java @@ -9,6 +9,8 @@ import 
org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo; +import org.opensearch.knn.index.query.SegmentLevelQuantizationUtil; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.io.IOException; @@ -24,16 +26,29 @@ public class FilteredIdsKNNIterator implements KNNIterator { protected final BitSet filterIdsBitSet; protected final BitSetIterator bitSetIterator; protected final float[] queryVector; + private final byte[] quantizedQueryVector; protected final KNNFloatVectorValues knnFloatVectorValues; protected final SpaceType spaceType; protected float currentScore = Float.NEGATIVE_INFINITY; protected int docId; + private final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo; - public FilteredIdsKNNIterator( + FilteredIdsKNNIterator( final BitSet filterIdsBitSet, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType + ) { + this(filterIdsBitSet, queryVector, knnFloatVectorValues, spaceType, null, null); + } + + public FilteredIdsKNNIterator( + final BitSet filterIdsBitSet, + final float[] queryVector, + final KNNFloatVectorValues knnFloatVectorValues, + final SpaceType spaceType, + final byte[] quantizedQueryVector, + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo ) { this.filterIdsBitSet = filterIdsBitSet; this.bitSetIterator = new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length()); @@ -41,6 +56,8 @@ public FilteredIdsKNNIterator( this.knnFloatVectorValues = knnFloatVectorValues; this.spaceType = spaceType; this.docId = bitSetIterator.nextDoc(); + this.quantizedQueryVector = quantizedQueryVector; + this.segmentLevelQuantizationInfo = segmentLevelQuantizationInfo; } /** @@ -68,8 +85,13 @@ public float score() { protected float computeScore() throws IOException { final float[] vector = knnFloatVectorValues.getVector(); - // 
Calculates a similarity score between the two vectors with a specified function. Higher similarity - // scores correspond to closer vectors. - return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + if (segmentLevelQuantizationInfo != null && quantizedQueryVector != null) { + byte[] quantizedVector = SegmentLevelQuantizationUtil.quantizeVector(vector, segmentLevelQuantizationInfo); + return SpaceType.HAMMING.getKnnVectorSimilarityFunction().compare(quantizedQueryVector, quantizedVector); + } else { + // Calculates a similarity score between the two vectors with a specified function. Higher similarity + // scores correspond to closer vectors. + return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector); + } } } diff --git a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java index 259b004f8..53ac72882 100644 --- a/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java +++ b/src/main/java/org/opensearch/knn/index/query/filtered/NestedFilteredIdsKNNIterator.java @@ -8,6 +8,7 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.query.SegmentLevelQuantizationInfo; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import java.io.IOException; @@ -19,14 +20,26 @@ public class NestedFilteredIdsKNNIterator extends FilteredIdsKNNIterator { private final BitSet parentBitSet; - public NestedFilteredIdsKNNIterator( + NestedFilteredIdsKNNIterator( final BitSet filterIdsArray, final float[] queryVector, final KNNFloatVectorValues knnFloatVectorValues, final SpaceType spaceType, final BitSet parentBitSet ) { - super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType); + this(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, 
parentBitSet, null, null); + } + + public NestedFilteredIdsKNNIterator( + final BitSet filterIdsArray, + final float[] queryVector, + final KNNFloatVectorValues knnFloatVectorValues, + final SpaceType spaceType, + final BitSet parentBitSet, + final byte[] quantizedVector, + final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo + ) { + super(filterIdsArray, queryVector, knnFloatVectorValues, spaceType, quantizedVector, segmentLevelQuantizationInfo); this.parentBitSet = parentBitSet; } diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index c13d86554..945da850a 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -20,6 +20,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.opensearch.common.StopWatch; +import org.opensearch.knn.index.query.ExactSearcher; import org.opensearch.knn.index.query.KNNQuery; import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.ResultUtil; @@ -108,7 +109,15 @@ private List> doRescore( int finalI = i; rescoreTasks.add(() -> { BitSet convertedBitSet = ResultUtil.resultMapToMatchBitSet(perLeafResults.get(finalI)); - return knnWeight.exactSearch(leafReaderContext, convertedBitSet, false, k); + final ExactSearcher.ExactSearcherContext exactSearcherContext = ExactSearcher.ExactSearcherContext.builder() + .matchedDocs(convertedBitSet) + // setting to false because in re-scoring we want to do exact search on full precision vectors + .useQuantizedVectorsForSearch(false) + .k(k) + .isParentHits(false) + .knnQuery(knnQuery) + .build(); + return knnWeight.exactSearch(leafReaderContext, exactSearcherContext); }); } return indexSearcher.getTaskExecutor().invokeAll(rescoreTasks); diff --git 
a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java index a2b41804a..810f49c15 100644 --- a/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java +++ b/src/test/java/org/opensearch/knn/index/query/KNNWeightTests.java @@ -84,6 +84,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isNull; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mockStatic; import static org.mockito.Mockito.times; @@ -1363,9 +1364,9 @@ private KNNQueryResult[] getFilteredKNNQueryResults() { public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { try (MockedStatic quantizationServiceMockedStatic = Mockito.mockStatic(QuantizationService.class)) { QuantizationService quantizationService = Mockito.mock(QuantizationService.class); + quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); QuantizationParams quantizationParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); Mockito.when(quantizationService.getQuantizationParams(any(FieldInfo.class))).thenReturn(quantizationParams); - quantizationServiceMockedStatic.when(QuantizationService::getInstance).thenReturn(quantizationService); // Given int k = 3; @@ -1413,6 +1414,8 @@ public void testANNWithQuantizationParams_whenStateNotFound_thenFail() { when(reader.getFieldInfos()).thenReturn(fieldInfos); when(fieldInfos.fieldInfo(any())).thenReturn(fieldInfo); when(fieldInfo.attributes()).thenReturn(attributesMap); + // fieldName, new float[0], tempCollector, null) + doNothing().when(reader).searchNearestVectors(any(), eq(new float[0]), any(), any()); expectThrows(IllegalStateException.class, () -> knnWeight.scorer(leafReaderContext)); } diff --git 
a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java index ee53818f1..06350f39c 100644 --- a/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/nativelib/NativeEngineKNNVectorQueryTests.java @@ -36,7 +36,6 @@ import java.util.concurrent.Callable; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -186,8 +185,9 @@ public void testRescore() { when(knnWeight.getQuery()).thenReturn(knnQuery); when(knnWeight.searchLeaf(leaf1, firstPassK)).thenReturn(initialLeaf1Results); when(knnWeight.searchLeaf(leaf2, firstPassK)).thenReturn(initialLeaf2Results); - when(knnWeight.exactSearch(eq(leaf1), any(), anyBoolean(), anyInt())).thenReturn(rescoredLeaf1Results); - when(knnWeight.exactSearch(eq(leaf2), any(), anyBoolean(), anyInt())).thenReturn(rescoredLeaf2Results); + + when(knnWeight.exactSearch(eq(leaf1), any())).thenReturn(rescoredLeaf1Results); + when(knnWeight.exactSearch(eq(leaf2), any())).thenReturn(rescoredLeaf2Results); try (MockedStatic mockedResultUtil = mockStatic(ResultUtil.class)) { mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod); mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(eq(rescoredLeaf1Results), anyInt())).thenAnswer(t -> topDocs1);