Skip to content

Commit

Permalink
Merge branch 'main' into filterbug
Browse files Browse the repository at this point in the history
Signed-off-by: sahil <[email protected]>
  • Loading branch information
buddharajusahil authored Dec 31, 2024
2 parents bdd753d + c728f02 commit 4124d06
Show file tree
Hide file tree
Showing 20 changed files with 95 additions and 44 deletions.
2 changes: 2 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1180,6 +1180,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
throw;
}
jniUtil->ReleaseLongArrayElements(env, filterIdsJ, filteredIdsArray, JNI_ABORT);
} else {
faiss::SearchParameters *searchParameters = nullptr;
faiss::SearchParametersHNSW hnswParams;
Expand All @@ -1202,6 +1203,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter
throw;
}
}
jniUtil->ReleaseFloatArrayElements(env, queryVectorJ, rawQueryVector, JNI_ABORT);

// lims is structured to support batched queries, it has a length of nq + 1 (where nq is the number of queries),
// lims[i] - lims[i-1] gives the number of results for the i-th query. With a single query we used in k-NN,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,13 @@ protected static final boolean isRunningAgainstOldCluster() {
protected final Optional<String> getBWCVersion() {
return Optional.ofNullable(System.getProperty(BWC_VERSION, null));
}

@Override
protected Settings getKNNDefaultIndexSettings() {
if (isApproximateThresholdSupported(getBWCVersion())) {
return super.getKNNDefaultIndexSettings();
}
// for bwc will return old default setting without approximate value threshold setting
return getDefaultIndexSettings();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,13 @@ protected final Optional<String> getBWCVersion() {
return Optional.ofNullable(System.getProperty(BWC_VERSION, null));
}

@Override
protected Settings getKNNDefaultIndexSettings() {
if (isApproximateThresholdSupported(getBWCVersion())) {
return super.getKNNDefaultIndexSettings();
}
// for bwc will return old default setting without approximate value threshold setting
return getDefaultIndexSettings();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public void testKNNDefaultIndexSettings() throws Exception {
waitForClusterHealthGreen(NODES_BWC_CLUSTER);
switch (getClusterType()) {
case OLD:
createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
createKnnIndex(testIndex, getDefaultIndexSettings(), createKnnIndexMapping(TEST_FIELD, DIMENSIONS));
int docIdOld = 0;
addKNNDocs(testIndex, TEST_FIELD, DIMENSIONS, docIdOld, NUM_DOCS);
break;
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/opensearch/knn/index/KNNSettings.java
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ public static Integer getFilteredExactSearchThreshold(final String indexName) {
.getAsInt(ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD, ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD_DEFAULT_VALUE);
}

public static boolean isShardLevelRescoringEnabledForDiskBasedVector(String indexName) {
public static boolean isShardLevelRescoringDisabledForDiskBasedVector(String indexName) {
return KNNSettings.state().clusterService.state()
.getMetadata()
.index(indexName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,11 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo
if (rescoreContext == null) {
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, finalK);
} else {
boolean isShardLevelRescoringEnabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(knnQuery.getIndexName());
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension);
int firstPassK = rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension);
perLeafResults = doSearch(indexSearcher, leafReaderContexts, knnWeight, firstPassK);
if (isShardLevelRescoringEnabled == true) {
if (isShardLevelRescoringDisabled == false) {
ResultUtil.reduceToTopK(perLeafResults, firstPassK);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,17 @@ public static RescoreContext getDefault() {
* based on the vector dimension if shard-level rescoring is disabled.
*
* @param finalK The final number of results to return for the entire shard.
* @param isShardLevelRescoringEnabled A boolean flag indicating whether shard-level rescoring is enabled.
* If true, the dimension-based oversampling logic is bypassed.
* @param isShardLevelRescoringDisabled A boolean flag indicating whether shard-level rescoring is disabled.
* If false, the dimension-based oversampling logic is bypassed.
* @param dimension The dimension of the vector. This is used to determine the oversampling factor when
* shard-level rescoring is disabled.
* @return The number of results to return for the first pass of rescoring, adjusted by the oversample factor.
*/
public int getFirstPassK(int finalK, boolean isShardLevelRescoringEnabled, int dimension) {
public int getFirstPassK(int finalK, boolean isShardLevelRescoringDisabled, int dimension) {
// Only apply default dimension-based oversampling logic when:
// 1. Shard-level rescoring is disabled
// 2. The oversample factor was not provided by the user
if (!isShardLevelRescoringEnabled && !userProvided) {
if (isShardLevelRescoringDisabled && !userProvided) {
// Apply new dimension-based oversampling logic when shard-level rescoring is disabled
if (dimension >= DIMENSION_THRESHOLD_1000) {
oversampleFactor = OVERSAMPLE_FACTOR_1000; // No oversampling for dimensions >= 1000
Expand Down
9 changes: 1 addition & 8 deletions src/test/java/org/opensearch/knn/KNNSingleNodeTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public void tearDown() throws Exception {
* Create a k-NN index with default settings
*/
protected IndexService createKNNIndex(String indexName) {
return createIndex(indexName, getKNNDefaultIndexSettings());
return createIndex(indexName, getKNNDefaultIndexSettingsBuildsGraphAlways());
}

/**
Expand Down Expand Up @@ -161,13 +161,6 @@ protected void createKnnNestedIndexMapping(String indexName, String fieldPath, I
OpenSearchAssertions.assertAcked(client().admin().indices().putMapping(request).actionGet());
}

/**
* Get default k-NN settings for test cases
*/
protected Settings getKNNDefaultIndexSettings() {
return Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", true).build();
}

/**
* Get default k-NN settings for test cases with build graph always
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,11 @@ public void testCreateIndexWithInvalidSpaceType() throws IOException {
}

public void testUpdateIndexSetting() throws IOException {
Settings settings = Settings.builder().put("index.knn", true).put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 512).build();
Settings settings = Settings.builder()
.put("index.knn", true)
.put(KNNSettings.KNN_ALGO_PARAM_EF_SEARCH, 512)
.put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)
.build();
createKnnIndex(INDEX_NAME, settings, createKnnIndexMapping(FIELD_NAME, 2));
assertEquals("512", getIndexSettingByName(INDEX_NAME, KNNSettings.KNN_ALGO_PARAM_EF_SEARCH));

Expand All @@ -122,7 +126,7 @@ public void testUpdateIndexSetting() throws IOException {

@SuppressWarnings("unchecked")
public void testCacheRebuiltAfterUpdateIndexSettings() throws Exception {
createKnnIndex(INDEX_NAME, buildKNNIndexSettings(0), createKnnIndexMapping(FIELD_NAME, 2));
createKnnIndex(INDEX_NAME, getKNNDefaultIndexSettings(), createKnnIndexMapping(FIELD_NAME, 2));

Float[] vector = { 6.0f, 6.0f };
addKnnDoc(INDEX_NAME, "1", FIELD_NAME, vector);
Expand Down
6 changes: 3 additions & 3 deletions src/test/java/org/opensearch/knn/index/KNNSettingsTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,15 @@ public void testGetEfSearch_whenEFSearchValueSetByUser_thenReturnValue() {
}

@SneakyThrows
public void testShardLevelRescoringEnabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
public void testShardLevelRescoringDisabled_whenNoValuesProvidedByUser_thenDefaultSettingsUsed() {
Node mockNode = createMockNode(Collections.emptyMap());
mockNode.start();
ClusterService clusterService = mockNode.injector().getInstance(ClusterService.class);
mockNode.client().admin().cluster().state(new ClusterStateRequest()).actionGet();
mockNode.client().admin().indices().create(new CreateIndexRequest(INDEX_NAME)).actionGet();
KNNSettings.state().setClusterService(clusterService);

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertFalse(shardLevelRescoringDisabled);
}
Expand All @@ -188,7 +188,7 @@ public void testShardLevelRescoringDisabled_whenValueProvidedByUser_thenSettingA

mockNode.client().admin().indices().updateSettings(new UpdateSettingsRequest(rescoringDisabledSetting, INDEX_NAME)).actionGet();

boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(INDEX_NAME);
boolean shardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(INDEX_NAME);
mockNode.close();
assertEquals(userDefinedRescoringDisabled, shardLevelRescoringDisabled);
}
Expand Down
4 changes: 2 additions & 2 deletions src/test/java/org/opensearch/knn/index/OpenSearchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ public void testKNNIndex_whenBuildGraphThresholdIsPresent_thenGetThresholdValue(
public void testKNNIndex_whenBuildThresholdIsNotProvided_thenShouldNotReturnSetting() throws Exception {
final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT));
final String indexName = "test-index-with-build-graph-settings";
createKnnIndex(indexName, knnIndexMapping);
createKnnIndex(indexName, getDefaultIndexSettings(), knnIndexMapping);
final String buildVectorDataStructureThresholdSetting = getIndexSettingByName(
indexName,
KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD
Expand All @@ -655,7 +655,7 @@ public void testKNNIndex_whenBuildThresholdIsNotProvided_thenShouldNotReturnSett
public void testKNNIndex_whenGetIndexSettingWithDefaultIsCalled_thenReturnDefaultBuildGraphThresholdValue() throws Exception {
final String knnIndexMapping = createKnnIndexMapping(FIELD_NAME, KNNEngine.getMaxDimensionByEngine(KNNEngine.DEFAULT));
final String indexName = "test-index-with-build-vector-graph-settings";
createKnnIndex(indexName, knnIndexMapping);
createKnnIndex(indexName, getDefaultIndexSettings(), knnIndexMapping);
final String buildVectorDataStructureThresholdSetting = getIndexSettingByName(
indexName,
KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ public void testRescoreWhenShardLevelRescoringEnabled() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

// Mock ResultUtil to return valid TopDocs
mockedResultUtil.when(() -> ResultUtil.resultMapToTopDocs(any(), anyInt()))
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testRescore() {
) {

// When shard-level re-scoring is enabled
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringEnabledForDiskBasedVector(any())).thenReturn(true);
mockedKnnSettings.when(() -> KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(any())).thenReturn(false);

mockedResultUtil.when(() -> ResultUtil.reduceToTopK(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
mockedResultUtil.when(() -> ResultUtil.resultMapToDocIds(any(), anyInt())).thenAnswer(InvocationOnMock::callRealMethod);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,23 @@ public void testGetFirstPassK() {
float oversample = 2.6f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build();
int finalK = 100;
boolean isShardLevelRescoringEnabled = true;
boolean isShardLevelRescoringDisabled = false;
int dimension = 500;

// Case 1: Test with standard oversample factor when shard-level rescoring is enabled
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(260, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test with a very small finalK that should result in a value less than MIN_FIRST_PASS_RESULTS
finalK = 1;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 3: Test with finalK = 0, should return MIN_FIRST_PASS_RESULTS
finalK = 0;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 4: Test with finalK = MAX_FIRST_PASS_RESULTS, should cap at MAX_FIRST_PASS_RESULTS
finalK = MAX_FIRST_PASS_RESULTS;
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MAX_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}

public void testGetFirstPassKWithDimensionBasedOversampling() {
Expand All @@ -42,44 +42,44 @@ public void testGetFirstPassKWithDimensionBasedOversampling() {
// Case 1: Test no oversampling for dimensions >= 1000 when shard-level rescoring is disabled
dimension = 1000;
RescoreContext rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensuring dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // No oversampling
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // No oversampling

// Case 2: Test 2x oversampling for dimensions >= 768 but < 1000 when shard-level rescoring is disabled
dimension = 800;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(200, rescoreContext.getFirstPassK(finalK, false, dimension)); // 2x oversampling
assertEquals(200, rescoreContext.getFirstPassK(finalK, true, dimension)); // 2x oversampling

// Case 3: Test 3x oversampling for dimensions < 768 when shard-level rescoring is disabled
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure previous values don't carry over
assertEquals(300, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling
assertEquals(300, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling

// Case 4: Shard-level rescoring enabled, oversample factor should be used as provided by the user (ignore dimension)
rescoreContext = RescoreContext.builder().oversampleFactor(5.0f).userProvided(true).build(); // Provided by user
dimension = 500;
assertEquals(500, rescoreContext.getFirstPassK(finalK, true, dimension)); // User-defined oversample factor should be used
assertEquals(500, rescoreContext.getFirstPassK(finalK, false, dimension)); // User-defined oversample factor should be used

// Case 5: Test finalK where oversampling factor results in a value less than MIN_FIRST_PASS_RESULTS
finalK = 10;
dimension = 700;
rescoreContext = RescoreContext.builder().userProvided(false).build(); // Ensure dimension-based logic applies
assertEquals(100, rescoreContext.getFirstPassK(finalK, false, dimension)); // 3x oversampling results in 30
assertEquals(100, rescoreContext.getFirstPassK(finalK, true, dimension)); // 3x oversampling results in 30
}

public void testGetFirstPassKWithMinPassK() {
float oversample = 0.5f;
RescoreContext rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
boolean isShardLevelRescoringEnabled = false;
boolean isShardLevelRescoringDisabled = true;

// Case 1: Test where finalK * oversample is smaller than MIN_FIRST_PASS_RESULTS
int finalK = 10;
int dimension = 700;
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));

// Case 2: Test where finalK * oversample results in exactly MIN_FIRST_PASS_RESULTS
finalK = 100;
oversample = 1.0f; // This will result in exactly 100 (MIN_FIRST_PASS_RESULTS)
rescoreContext = RescoreContext.builder().oversampleFactor(oversample).userProvided(true).build(); // User provided
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringEnabled, dimension));
assertEquals(MIN_FIRST_PASS_RESULTS, rescoreContext.getFirstPassK(finalK, isShardLevelRescoringDisabled, dimension));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.NestedKnnDocBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.mapper.Mode;
Expand Down Expand Up @@ -329,6 +330,7 @@ private void createKnnIndex(
.put("number_of_shards", numOfShards)
.put("number_of_replicas", 0)
.put("index.knn", true)
.put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0)
.build();
createKnnIndex(INDEX_NAME, settings, mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.SpaceType;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.opensearch.client.Request;
Expand Down Expand Up @@ -840,7 +841,11 @@ private void createIndexAndAssertScriptScore(
/*
* Create knn index and populate data
*/
Settings settings = Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", enableKnn).build();
Settings.Builder builder = Settings.builder().put("number_of_shards", 1).put("number_of_replicas", 0).put("index.knn", enableKnn);
if (enableKnn) {
builder.put(KNNSettings.INDEX_KNN_ADVANCED_APPROXIMATE_THRESHOLD, 0);
}
Settings settings = builder.build();
createKnnIndex(INDEX_NAME, settings, mapper);
try {
final int numDocsWithField = randomIntBetween(4, 10);
Expand Down
Loading

0 comments on commit 4124d06

Please sign in to comment.