diff --git a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java index f69ad850e2..55ff655167 100644 --- a/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java +++ b/src/main/java/org/opensearch/knn/index/KNNVectorScriptDocValues.java @@ -25,10 +25,21 @@ public abstract class KNNVectorScriptDocValues extends ScriptDocValues @Getter private final VectorDataType vectorDataType; private boolean docExists = false; + private int lastDocID = -1; @Override public void setNextDocId(int docId) throws IOException { - docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId; + if (docId < lastDocID) { + throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId); + } + + lastDocID = docId; + + int curDocID = vectorValues.docID(); + if (lastDocID > curDocID) { + curDocID = vectorValues.advance(docId); + } + docExists = lastDocID == curDocID; } public float[] getValue() { diff --git a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java index 11c626ff7b..71c99eb5fe 100644 --- a/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java +++ b/src/test/java/org/opensearch/knn/plugin/script/KNNScriptScoringIT.java @@ -7,6 +7,7 @@ import java.util.function.BiFunction; import java.util.function.Function; +import org.opensearch.ExceptionsHelper; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.KNNResult; import org.opensearch.knn.common.KNNConstants; @@ -193,7 +194,7 @@ public void testUnequalDimensions() throws Exception { } @SuppressWarnings("unchecked") - public void testKNNScoreforNonVectorDocument() throws Exception { + public void testKNNScoreForNonVectorDocument() throws Exception { /* * Create knn index and populate data */ @@ -599,7 +600,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception { if (spaceType != SpaceType.HAMMING_BIT) { final float[] queryVector = randomVector(dimensions); final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); - createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector); + createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true); } } } @@ -634,12 +635,22 @@ private float[] randomVector(int dimensions) { return vector; } - private Map createDataset(Function scoreFunction, int dimensions, int numDocs) { - final Map dataset = new HashMap<>(numDocs); - for (int i = 0; i < numDocs; i++) { + private Map createDataset( + Function scoreFunction, + int dimensions, + int numDocsWithField, + boolean dense + ) { + final Map dataset = new HashMap<>(dense ? numDocsWithField : numDocsWithField * 3); + int id = 0; + for (int i = 0; i < numDocsWithField; i++) { + final int dummyDocs = dense ? 0 : randomIntBetween(1, 5); + for (int j = 0; j < dummyDocs; j++) { + dataset.put(Integer.toString(id++), null); + } final float[] vector = randomVector(dimensions); final float score = scoreFunction.apply(vector); - dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score)); + dataset.put(Integer.toString(id), new KNNResult(Integer.toString(id++), vector, score)); } return dataset; } @@ -678,7 +689,8 @@ private void testKNNScriptScore(SpaceType spaceType) throws Exception { final float[] queryVector = randomVector(dims); final BiFunction scoreFunction = getScoreFunction(spaceType, queryVector); for (String mapper : createMappers(dims)) { - createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true); + createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false); } } @@ -687,16 +699,20 @@ private void createIndexAndAssertScriptScore( SpaceType spaceType, BiFunction scoreFunction, int dimensions, - float[] queryVector + float[] queryVector, + boolean dense ) throws Exception { /* * Create knn index and populate data */ createKnnIndex(INDEX_NAME, mapper); - Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10)); - for (Map.Entry entry : dataset.entrySet()) { - addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector()); - } + final int numDocsWithField = randomIntBetween(4, 10); + Map dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, numDocsWithField, dense); + final float[] dummyVector = new float[1]; + dataset.forEach((k, v) -> { + final float[] vector = (v != null) ? v.getVector() : dummyVector; + ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector)); + }); /** * Construct Search Request @@ -712,7 +728,7 @@ private void createIndexAndAssertScriptScore( params.put("field", FIELD_NAME); params.put("query_value", queryVector); params.put("space_type", spaceType.getValue()); - Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params); + Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField); Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); diff --git a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java index 5010ff6ee3..6ba1eae656 100644 --- a/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java +++ b/src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java @@ -39,6 +39,7 @@ import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder; import org.opensearch.core.rest.RestStatus; import org.opensearch.script.Script; +import org.opensearch.search.SearchService; import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder; import javax.management.MBeanServerInvocationHandler; @@ -954,9 +955,16 @@ protected Request constructScriptScoreContextSearchRequest( } protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params) throws Exception { + return constructKNNScriptQueryRequest(indexName, qb, params, SearchService.DEFAULT_SIZE); + } + + protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map params, int size) + throws Exception { Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, params); ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script); - XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); + XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); + builder.field("size", size); + builder.startObject("query"); builder.startObject("script_score"); builder.field("query"); sc.query().toXContent(builder, ToXContent.EMPTY_PARAMS);