Skip to content

Commit

Permalink
Multiple innerHit in nested fields
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <[email protected]>
  • Loading branch information
heemin32 committed Dec 3, 2024
1 parent 9276c77 commit 33dd900
Show file tree
Hide file tree
Showing 40 changed files with 1,923 additions and 206 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/k-NN/compare/2.18...2.x)
### Features
- Add Support for Multi Values in innerHit for Nested k-NN Fields in Lucene and FAISS (#2283)[https://github.com/opensearch-project/k-NN/pull/2283]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
### Bug Fixes
Expand Down
1 change: 1 addition & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ dependencies {
testImplementation group: 'net.bytebuddy', name: 'byte-buddy', version: '1.15.10'
testImplementation group: 'org.objenesis', name: 'objenesis', version: '3.3'
testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.15.4'
testFixturesImplementation 'com.jayway.jsonpath:json-path:2.8.0'
testFixturesImplementation "org.opensearch:common-utils:${version}"
implementation 'com.github.oshi:oshi-core:6.4.13'
api "net.java.dev.jna:jna:5.13.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ public enum KNNEngine implements KNNLibrary {
private static final Set<KNNEngine> CUSTOM_SEGMENT_FILE_ENGINES = ImmutableSet.of(KNNEngine.NMSLIB, KNNEngine.FAISS);
private static final Set<KNNEngine> ENGINES_SUPPORTING_FILTERS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_RADIAL_SEARCH = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);
public static final Set<KNNEngine> ENGINES_SUPPORTING_MULTI_VECTORS = ImmutableSet.of(KNNEngine.LUCENE, KNNEngine.FAISS);

private static Map<KNNEngine, Integer> MAX_DIMENSIONS_BY_ENGINE = Map.of(
KNNEngine.NMSLIB,
Expand Down
10 changes: 5 additions & 5 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.HitQueue;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.lucene.Lucene;
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
Expand Down Expand Up @@ -68,8 +67,8 @@ public Map<Integer, Float> searchLeaf(final LeafReaderContext leafReaderContext,
if (exactSearcherContext.getKnnQuery().getRadius() != null) {
return doRadialSearch(leafReaderContext, exactSearcherContext, iterator);
}
if (exactSearcherContext.getMatchedDocs() != null
&& exactSearcherContext.getMatchedDocs().cardinality() <= exactSearcherContext.getK()) {
if (exactSearcherContext.getMatchedDocsIterator() != null
&& exactSearcherContext.numberOfMatchedDocs <= exactSearcherContext.getK()) {
return scoreAllDocs(iterator);
}
return searchTopCandidates(iterator, exactSearcherContext.getK(), Predicates.alwaysTrue());
Expand Down Expand Up @@ -155,7 +154,7 @@ private Map<Integer, Float> filterDocsByMinScore(ExactSearcherContext context, K

private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSearcherContext exactSearcherContext) throws IOException {
final KNNQuery knnQuery = exactSearcherContext.getKnnQuery();
final BitSet matchedDocs = exactSearcherContext.getMatchedDocs();
final DocIdSetIterator matchedDocs = exactSearcherContext.getMatchedDocsIterator();
final SegmentReader reader = Lucene.segmentReader(leafReaderContext.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
Expand Down Expand Up @@ -245,7 +244,8 @@ public static class ExactSearcherContext {
*/
boolean useQuantizedVectorsForSearch;
int k;
BitSet matchedDocs;
DocIdSetIterator matchedDocsIterator;
long numberOfMatchedDocs;
KNNQuery knnQuery;
/**
* whether the matchedDocs contains parent ids or child ids. This is relevant in the case of
Expand Down
1 change: 1 addition & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public class KNNQuery extends Query {

@Setter
private Query filterQuery;
@Getter
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
Expand Down
55 changes: 41 additions & 14 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.DiversifyingChildrenByteKnnVectorQuery;
import org.apache.lucene.search.join.DiversifyingChildrenFloatKnnVectorQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.query.common.QueryUtils;
import org.opensearch.knn.index.query.lucenelib.NestedKnnVectorQueryFactory;
import org.opensearch.knn.index.query.nativelib.NativeEngineKnnVectorQuery;
import org.opensearch.knn.index.query.rescore.RescoreContext;

Expand All @@ -24,13 +24,13 @@
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_EF_SEARCH;
import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;
import static org.opensearch.knn.index.VectorDataType.SUPPORTED_VECTOR_DATA_TYPES;
import static org.opensearch.knn.index.engine.KNNEngine.ENGINES_SUPPORTING_MULTI_VECTORS;

/**
* Creates the Lucene k-NN queries
*/
@Log4j2
public class KNNQueryFactory extends BaseQueryFactory {

/**
* Creates a Lucene query for a particular engine.
* @param createQueryRequest request object that has all required fields to construct the query
Expand All @@ -48,11 +48,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
final Query filterQuery = getFilterQuery(createQueryRequest);
final Map<String, ?> methodParameters = createQueryRequest.getMethodParameters();
final RescoreContext rescoreContext = createQueryRequest.getRescoreContext().orElse(null);
final KNNEngine knnEngine = createQueryRequest.getKnnEngine();

BitSetProducer parentFilter = null;
boolean isInnerHitQuery = false;
if (createQueryRequest.getContext().isPresent()) {
QueryShardContext context = createQueryRequest.getContext().get();
parentFilter = context.getParentFilter();
isInnerHitQuery = context.isInnerHitQuery();
}

if (KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(createQueryRequest.getKnnEngine())) {
Expand Down Expand Up @@ -95,7 +98,14 @@ public static Query create(CreateQueryRequest createQueryRequest) {
.rescoreContext(rescoreContext)
.build();
}
return createQueryRequest.getRescoreContext().isPresent() ? new NativeEngineKnnVectorQuery(knnQuery) : knnQuery;

if (createQueryRequest.getRescoreContext().isPresent()) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, isInnerHitQuery);
} else if (ENGINES_SUPPORTING_MULTI_VECTORS.contains(knnEngine) && isInnerHitQuery) {
return new NativeEngineKnnVectorQuery(knnQuery, QueryUtils.INSTANCE, isInnerHitQuery);
} else {
return knnQuery;
}
}

Integer requestEfSearch = null;
Expand All @@ -106,9 +116,9 @@ public static Query create(CreateQueryRequest createQueryRequest) {
log.debug(String.format("Creating Lucene k-NN query for index: %s \"\", field: %s \"\", k: %d", indexName, fieldName, k));
switch (vectorDataType) {
case BYTE:
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter);
return getKnnByteVectorQuery(fieldName, byteVector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
case FLOAT:
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter);
return getKnnFloatVectorQuery(fieldName, vector, luceneK, filterQuery, parentFilter, isInnerHitQuery);
default:
throw new IllegalArgumentException(
String.format(
Expand All @@ -131,38 +141,55 @@ private static Query validateFilterQuerySupport(final Query filterQuery, final K
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenByteKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnByteVectorQuery(
final String fieldName,
final byte[] byteVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
assert isInnerHitQuery == false;
return new KnnByteVectorQuery(fieldName, byteVector, k, filterQuery);
} else {
return new DiversifyingChildrenByteKnnVectorQuery(fieldName, byteVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
byteVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}

/**
* If parentFilter is not null, it is a nested query. Therefore, we return {@link DiversifyingChildrenFloatKnnVectorQuery}
* which will dedupe search result per parent so that we can get k parent results at the end.
* If parentFilter is not null, it is a nested query. Therefore, we delegate creation of query to {@link NestedKnnVectorQueryFactory}
* which will create query to dedupe search result per parent so that we can get k parent results at the end.
*/
private static Query getKnnFloatVectorQuery(
final String fieldName,
final float[] floatVector,
final int k,
final Query filterQuery,
final BitSetProducer parentFilter
final BitSetProducer parentFilter,
final boolean isInnerHitQuery
) {
if (parentFilter == null) {
return new KnnFloatVectorQuery(fieldName, floatVector, k, filterQuery);
} else {
return new DiversifyingChildrenFloatKnnVectorQuery(fieldName, floatVector, filterQuery, k, parentFilter);
return NestedKnnVectorQueryFactory.createNestedKnnVectorQuery(
fieldName,
floatVector,
k,
filterQuery,
parentFilter,
isInnerHitQuery
);
}
}
}
32 changes: 20 additions & 12 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.LeafReaderContext;
Expand Down Expand Up @@ -66,6 +67,7 @@ public class KNNWeight extends Weight {
private final float boost;

private final NativeMemoryCacheManager nativeMemoryCacheManager;
@Getter
private final Weight filterWeight;
private final ExactSearcher exactSearcher;

Expand Down Expand Up @@ -109,7 +111,7 @@ public Explanation explain(LeafReaderContext context, int doc) {

@Override
public Scorer scorer(LeafReaderContext context) throws IOException {
final Map<Integer, Float> docIdToScoreMap = searchLeaf(context, knnQuery.getK());
final Map<Integer, Float> docIdToScoreMap = searchLeaf(context, knnQuery.getK()).getResult();
if (docIdToScoreMap.isEmpty()) {
return KNNScorer.emptyScorer(this);
}
Expand All @@ -125,32 +127,34 @@ public Scorer scorer(LeafReaderContext context) throws IOException {
* @param k Number of results to return
* @return A Map of docId to scores for top k results
*/
public Map<Integer, Float> searchLeaf(LeafReaderContext context, int k) throws IOException {
public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException {
final BitSet filterBitSet = getFilteredDocsBitSet(context);
int cardinality = filterBitSet.cardinality();
// We don't need to go to JNI layer if no documents are found which satisfy the filters
// We should give this condition a deeper look that where it should be placed. For now I feel this is a good
// place,
if (filterWeight != null && cardinality == 0) {
return Collections.emptyMap();
return PerLeafResult.EMPTY_RESULT;
}
/*
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
* This improves the recall.
*/
if (isFilteredExactSearchPreferred(cardinality)) {
return doExactSearch(context, filterBitSet, k);
Map<Integer, Float> result = doExactSearch(context, new BitSetIterator(filterBitSet, cardinality), cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, filterBitSet, cardinality, k);
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
if (isExactSearchRequire(context, cardinality, docIdsToScoreMap.size())) {
final BitSet docs = filterWeight != null ? filterBitSet : null;
return doExactSearch(context, docs, k);
final BitSetIterator docs = filterWeight != null ? new BitSetIterator(filterBitSet, cardinality) : null;
Map<Integer, Float> result = doExactSearch(context, docs, cardinality, k);
return new PerLeafResult(filterWeight == null ? null : filterBitSet, result);
}
return docIdsToScoreMap;
return new PerLeafResult(filterWeight == null ? null : filterBitSet, docIdsToScoreMap);
}

private BitSet getFilteredDocsBitSet(final LeafReaderContext ctx) throws IOException {
Expand Down Expand Up @@ -205,17 +209,21 @@ private int[] bitSetToIntArray(final BitSet bitSet) {
return intArray;
}

private Map<Integer, Float> doExactSearch(final LeafReaderContext context, final BitSet acceptedDocs, int k) throws IOException {
private Map<Integer, Float> doExactSearch(
final LeafReaderContext context,
final DocIdSetIterator acceptedDocs,
final long numberOfAcceptedDocs,
int k
) throws IOException {
final ExactSearcherContextBuilder exactSearcherContextBuilder = ExactSearcher.ExactSearcherContext.builder()
.isParentHits(true)
.k(k)
// setting to true, so that if quantization details are present we want to do search on the quantized
// vectors as this flow is used in first pass of search.
.useQuantizedVectorsForSearch(true)
.knnQuery(knnQuery);
if (acceptedDocs != null) {
exactSearcherContextBuilder.matchedDocs(acceptedDocs);
}
.knnQuery(knnQuery)
.matchedDocsIterator(acceptedDocs)
.numberOfMatchedDocs(numberOfAcceptedDocs);
return exactSearch(context, exactSearcherContextBuilder.build());
}

Expand Down
26 changes: 26 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/PerLeafResult.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query;

import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.util.Bits;

import java.util.Collections;
import java.util.Map;

@Getter
public class PerLeafResult {
public static final PerLeafResult EMPTY_RESULT = new PerLeafResult(new Bits.MatchNoBits(0), Collections.emptyMap());
private final Bits filterBits;
@Setter
private Map<Integer, Float> result;

public PerLeafResult(Bits filterBits, Map<Integer, Float> result) {
this.filterBits = filterBits == null ? new Bits.MatchAllBits(0) : filterBits;
this.result = result;
}
}
24 changes: 11 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/ResultUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.DocIdSetBuilder;

import java.io.IOException;
Expand All @@ -30,14 +29,14 @@ public final class ResultUtil {
* @param perLeafResults Results from the list
* @param k the number of results across all leaf results to return
*/
public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k) {
public static void reduceToTopK(List<PerLeafResult> perLeafResults, int k) {
// Iterate over all scores to get min competitive score
PriorityQueue<Float> topKMinQueue = new PriorityQueue<>(k);

int count = 0;
for (Map<Integer, Float> perLeafResult : perLeafResults) {
count += perLeafResult.size();
for (Float score : perLeafResult.values()) {
for (PerLeafResult perLeafResult : perLeafResults) {
count += perLeafResult.getResult().size();
for (Float score : perLeafResult.getResult().values()) {
if (topKMinQueue.size() < k) {
topKMinQueue.add(score);
} else if (topKMinQueue.peek() != null && score > topKMinQueue.peek()) {
Expand All @@ -54,23 +53,22 @@ public static void reduceToTopK(List<Map<Integer, Float>> perLeafResults, int k)

// Reduce the results based on min competitive score
float minScore = topKMinQueue.peek() == null ? -Float.MAX_VALUE : topKMinQueue.peek();
perLeafResults.forEach(results -> results.entrySet().removeIf(entry -> entry.getValue() < minScore));
perLeafResults.forEach(results -> results.getResult().entrySet().removeIf(entry -> entry.getValue() < minScore));
}

/**
* Convert map to bit set, if resultMap is empty or null then returns an Optional. Returning an optional here to
* ensure that the caller is aware that BitSet may not be present
* Convert map of docs to doc id set iterator
*
* @param resultMap Map of results
* @return BitSet of results; null is returned if the result map is empty
* @return Doc id set iterator
* @throws IOException If an error occurs during the search.
*/
public static BitSet resultMapToMatchBitSet(Map<Integer, Float> resultMap) throws IOException {
if (resultMap == null || resultMap.isEmpty()) {
return null;
public static DocIdSetIterator resultMapToDocIds(Map<Integer, Float> resultMap) throws IOException {
if (resultMap.isEmpty()) {
return DocIdSetIterator.empty();
}
final int maxDoc = Collections.max(resultMap.keySet()) + 1;
return BitSet.of(resultMapToDocIds(resultMap, maxDoc), maxDoc);
return resultMapToDocIds(resultMap, maxDoc);
}

/**
Expand Down
Loading

0 comments on commit 33dd900

Please sign in to comment.