Skip to content

Commit

Permalink
fix misusing doc value
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr committed May 9, 2024
1 parent 638c310 commit 33620b8
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,21 @@ public abstract class KNNVectorScriptDocValues extends ScriptDocValues<float[]>
@Getter
private final VectorDataType vectorDataType;
private boolean docExists = false;
private int lastDocID = -1;

@Override
public void setNextDocId(int docId) throws IOException {
docExists = vectorValues.docID() == docId || vectorValues.advance(docId) == docId;
if (docId < lastDocID) {
throw new IllegalArgumentException("docs were sent out-of-order: lastDocID=" + lastDocID + " vs docID=" + docId);
}

lastDocID = docId;

int curDocID = vectorValues.docID();
if (lastDocID > curDocID) {
curDocID = vectorValues.advance(docId);
}
docExists = lastDocID == curDocID;
}

public float[] getValue() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import java.util.function.BiFunction;
import java.util.function.Function;
import org.opensearch.ExceptionsHelper;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.common.KNNConstants;
Expand Down Expand Up @@ -193,7 +194,7 @@ public void testUnequalDimensions() throws Exception {
}

@SuppressWarnings("unchecked")
public void testKNNScoreforNonVectorDocument() throws Exception {
public void testKNNScoreForNonVectorDocument() throws Exception {
/*
* Create knn index and populate data
*/
Expand Down Expand Up @@ -599,7 +600,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
if (spaceType != SpaceType.HAMMING_BIT) {
final float[] queryVector = randomVector(dimensions);
final BiFunction<float[], float[], Float> scoreFunction = getScoreFunction(spaceType, queryVector);
createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector);
createIndexAndAssertScriptScore(testMapping, spaceType, scoreFunction, dimensions, queryVector, true);
}
}
}
Expand Down Expand Up @@ -634,12 +635,22 @@ private float[] randomVector(int dimensions) {
return vector;
}

private Map<String, KNNResult> createDataset(Function<float[], Float> scoreFunction, int dimensions, int numDocs) {
final Map<String, KNNResult> dataset = new HashMap<>(numDocs);
for (int i = 0; i < numDocs; i++) {
private Map<String, KNNResult> createDataset(
Function<float[], Float> scoreFunction,
int dimensions,
int numDocsWithField,
boolean dense
) {
final Map<String, KNNResult> dataset = new HashMap<>(dense ? numDocsWithField : numDocsWithField * 3);
int id = 0;
for (int i = 0; i < numDocsWithField; i++) {
final int dummyDocs = dense ? 0 : randomIntBetween(1, 5);
for (int j = 0; j < dummyDocs; j++) {
dataset.put(Integer.toString(id++), null);
}
final float[] vector = randomVector(dimensions);
final float score = scoreFunction.apply(vector);
dataset.put(Integer.toString(i), new KNNResult(Integer.toString(i), vector, score));
dataset.put(Integer.toString(id), new KNNResult(Integer.toString(id++), vector, score));
}
return dataset;
}
Expand Down Expand Up @@ -678,7 +689,8 @@ private void testKNNScriptScore(SpaceType spaceType) throws Exception {
final float[] queryVector = randomVector(dims);
final BiFunction<float[], float[], Float> scoreFunction = getScoreFunction(spaceType, queryVector);
for (String mapper : createMappers(dims)) {
createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector);
createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, true);
createIndexAndAssertScriptScore(mapper, spaceType, scoreFunction, dims, queryVector, false);
}
}

Expand All @@ -687,16 +699,20 @@ private void createIndexAndAssertScriptScore(
SpaceType spaceType,
BiFunction<float[], float[], Float> scoreFunction,
int dimensions,
float[] queryVector
float[] queryVector,
boolean dense
) throws Exception {
/*
* Create knn index and populate data
*/
createKnnIndex(INDEX_NAME, mapper);
Map<String, KNNResult> dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, randomIntBetween(4, 10));
for (Map.Entry<String, KNNResult> entry : dataset.entrySet()) {
addKnnDoc(INDEX_NAME, entry.getKey(), FIELD_NAME, entry.getValue().getVector());
}
final int numDocsWithField = randomIntBetween(4, 10);
Map<String, KNNResult> dataset = createDataset(v -> scoreFunction.apply(queryVector, v), dimensions, numDocsWithField, dense);
final float[] dummyVector = new float[1];
dataset.forEach((k, v) -> {
final float[] vector = (v != null) ? v.getVector() : dummyVector;
ExceptionsHelper.catchAsRuntimeException(() -> addKnnDoc(INDEX_NAME, k, (v != null) ? FIELD_NAME : "dummy", vector));
});

/**
* Construct Search Request
Expand All @@ -712,7 +728,7 @@ private void createIndexAndAssertScriptScore(
params.put("field", FIELD_NAME);
params.put("query_value", queryVector);
params.put("space_type", spaceType.getValue());
Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params);
Request request = constructKNNScriptQueryRequest(INDEX_NAME, qb, params, numDocsWithField);
Response response = client().performRequest(request);
assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

Expand Down
10 changes: 9 additions & 1 deletion src/testFixtures/java/org/opensearch/knn/KNNRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.opensearch.index.query.functionscore.ScriptScoreQueryBuilder;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.script.Script;
import org.opensearch.search.SearchService;
import org.opensearch.search.aggregations.metrics.ScriptedMetricAggregationBuilder;

import javax.management.MBeanServerInvocationHandler;
Expand Down Expand Up @@ -954,9 +955,16 @@ protected Request constructScriptScoreContextSearchRequest(
}

protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map<String, Object> params) throws Exception {
return constructKNNScriptQueryRequest(indexName, qb, params, SearchService.DEFAULT_SIZE);
}

protected Request constructKNNScriptQueryRequest(String indexName, QueryBuilder qb, Map<String, Object> params, int size)
throws Exception {
Script script = new Script(Script.DEFAULT_SCRIPT_TYPE, KNNScoringScriptEngine.NAME, KNNScoringScriptEngine.SCRIPT_SOURCE, params);
ScriptScoreQueryBuilder sc = new ScriptScoreQueryBuilder(qb, script);
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query");
XContentBuilder builder = XContentFactory.jsonBuilder().startObject();
builder.field("size", size);
builder.startObject("query");
builder.startObject("script_score");
builder.field("query");
sc.query().toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand Down

0 comments on commit 33620b8

Please sign in to comment.