Skip to content

Commit

Permalink
Switch from byte stream to byte ref (opensearch-project#1825)
Browse files Browse the repository at this point in the history
Avoids copy during serialization and deserialization by switching from
requiring byte streams to only requiring byte refs. This can speed up
operations by 15-20% for exact search.

This will not have an impact on KnnVectorsFormat structures. Those will use
Lucene's vector serde, which is already optimized. This change is meant to
add a boost in performance until that is used completely.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Jul 16, 2024
1 parent 48478b4 commit e5b90ce
Show file tree
Hide file tree
Showing 16 changed files with 114 additions and 117 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Adds dynamic query parameter ef_search in radial search faiss engine [#1790](https://github.com/opensearch-project/k-NN/pull/1790)
* Add binary format support with HNSW method in Faiss Engine [#1781](https://github.com/opensearch-project/k-NN/pull/1781)
### Enhancements
* Switch from byte stream to byte ref for serde [#1825](https://github.com/opensearch-project/k-NN/pull/1825)
### Bug Fixes
* Fixing the arithmetic to find the number of vectors to stream from java to jni layer. [#1804](https://github.com/opensearch-project/k-NN/pull/1804)
* Release memory properly for an array type [#1820](https://github.com/opensearch-project/k-NN/pull/1820)
Expand Down
6 changes: 2 additions & 4 deletions src/main/java/org/opensearch/knn/index/VectorDataType.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.util.Arrays;
import java.util.Locale;
import java.util.Objects;
Expand Down Expand Up @@ -71,9 +70,8 @@ public FieldType createKnnVectorFieldType(int dimension, VectorSimilarityFunctio

@Override
public float[] getVectorFromBytesRef(BytesRef binaryValue) {
ByteArrayInputStream byteStream = new ByteArrayInputStream(binaryValue.bytes, binaryValue.offset, binaryValue.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
return vectorSerializer.byteToFloatArray(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(binaryValue);
return vectorSerializer.byteToFloatArray(binaryValue);
}

};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
package org.opensearch.knn.index.codec.transfer;

import lombok.Data;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.SerializationMode;

import java.io.ByteArrayInputStream;

/**
* Abstract class to transfer vector value from Java to native memory
*/
Expand All @@ -36,9 +35,9 @@ public VectorTransfer(final long vectorsStreamingMemoryLimit) {
/**
* Transfer a single vector
*
* @param byteStream a vector in byte stream format
* @param bytesRef a vector in bytes format
*/
abstract public void transfer(final ByteArrayInputStream byteStream);
abstract public void transfer(final BytesRef bytesRef);

/**
* Close the transfer
Expand All @@ -48,8 +47,8 @@ public VectorTransfer(final long vectorsStreamingMemoryLimit) {
/**
* Get serialization mode of the given bytes ref
*
* @param byteStream byte stream of a vector
* @param bytesRef bytes of a vector
* @return serialization mode
*/
abstract public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream);
abstract public SerializationMode getSerializationMode(final BytesRef bytesRef);
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

package org.opensearch.knn.index.codec.transfer;

import org.apache.lucene.util.ArrayUtil;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.SerializationMode;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -30,19 +31,18 @@ public void init(final long totalLiveDocs) {
}

@Override
public void transfer(final ByteArrayInputStream byteStream) {
final byte[] vector = byteStream.readAllBytes();
dimension = vector.length * 8;
public void transfer(final BytesRef bytesRef) {
dimension = bytesRef.length * 8;
if (vectorsPerTransfer == Integer.MIN_VALUE) {
vectorsPerTransfer = (vector.length * totalLiveDocs) / vectorsStreamingMemoryLimit;
vectorsPerTransfer = (bytesRef.length * totalLiveDocs) / vectorsStreamingMemoryLimit;
// This condition occurs if vectorsStreamingMemoryLimit is higher than the total number of bytes to transfer
// Doing this will reduce 1 extra trip to JNI layer.
if (vectorsPerTransfer == 0) {
vectorsPerTransfer = totalLiveDocs;
}
}

vectorList.add(vector);
vectorList.add(ArrayUtil.copyOfSubArray(bytesRef.bytes, bytesRef.offset, bytesRef.offset + bytesRef.length));
if (vectorList.size() == vectorsPerTransfer) {
transfer();
}
Expand All @@ -54,7 +54,7 @@ public void close() {
}

@Override
public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream) {
public SerializationMode getSerializationMode(final BytesRef bytesRef) {
return SerializationMode.COLLECTIONS_OF_BYTES;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

package org.opensearch.knn.index.codec.transfer;

import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;
import org.opensearch.knn.index.codec.util.SerializationMode;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.util.ArrayList;
import java.util.List;

Expand All @@ -32,9 +32,9 @@ public void init(final long totalLiveDocs) {
}

@Override
public void transfer(final ByteArrayInputStream byteStream) {
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
public void transfer(final BytesRef bytesRef) {
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(bytesRef);
final float[] vector = vectorSerializer.byteToFloatArray(bytesRef);
dimension = vector.length;

if (vectorsPerTransfer == Integer.MIN_VALUE) {
Expand All @@ -58,8 +58,8 @@ public void close() {
}

@Override
public SerializationMode getSerializationMode(final ByteArrayInputStream byteStream) {
return KNNVectorSerializerFactory.getSerializerModeFromStream(byteStream);
public SerializationMode getSerializationMode(final BytesRef bytesRef) {
return KNNVectorSerializerFactory.getSerializerModeFromBytesRef(bytesRef);
}

private void transfer() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.knn.index.codec.KNN80Codec.KNN80BinaryDocValues;
import org.opensearch.knn.index.codec.transfer.VectorTransfer;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -41,16 +40,22 @@ public static final class Pair {
public SerializationMode serializationMode;
}

/**
* Extract docIds and vectors from binary doc values.
*
* @param values Binary doc values
* @param vectorTransfer Utility to make transfer
* @return KNNCodecUtil.Pair representing doc ids and corresponding vectors
* @throws IOException thrown when unable to get binary of vectors
*/
public static KNNCodecUtil.Pair getPair(final BinaryDocValues values, final VectorTransfer vectorTransfer) throws IOException {
List<Integer> docIdList = new ArrayList<>();
SerializationMode serializationMode = SerializationMode.COLLECTION_OF_FLOATS;
vectorTransfer.init(getTotalLiveDocsCount(values));
for (int doc = values.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = values.nextDoc()) {
BytesRef bytesref = values.binaryValue();
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesref.bytes, bytesref.offset, bytesref.length)) {
serializationMode = vectorTransfer.getSerializationMode(byteStream);
vectorTransfer.transfer(byteStream);
}
serializationMode = vectorTransfer.getSerializationMode(bytesref);
vectorTransfer.transfer(bytesref);
docIdList.add(doc);
}
vectorTransfer.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.codec.util;

import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -31,8 +33,8 @@ public byte[] floatToByteArray(float[] input) {
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
try {
public float[] byteToFloatArray(BytesRef bytesRef) {
try (ByteArrayInputStream byteStream = new ByteArrayInputStream(bytesRef.bytes, bytesRef.offset, bytesRef.length)) {
final ObjectInputStream objectStream = new ObjectInputStream(byteStream);
final float[] vector = (float[]) objectStream.readObject();
return vector;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import org.apache.lucene.util.BytesRef;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.stream.IntStream;
Expand All @@ -26,15 +27,13 @@ public byte[] floatToByteArray(float[] input) {
}

@Override
public float[] byteToFloatArray(ByteArrayInputStream byteStream) {
if (byteStream == null || byteStream.available() % BYTES_IN_FLOAT != 0) {
public float[] byteToFloatArray(BytesRef bytesRef) {
if (bytesRef == null || bytesRef.length % BYTES_IN_FLOAT != 0) {
throw new IllegalArgumentException("Byte stream cannot be deserialized to array of floats");
}
final byte[] vectorAsByteArray = new byte[byteStream.available()];
byteStream.read(vectorAsByteArray, 0, byteStream.available());
final int sizeOfFloatArray = vectorAsByteArray.length / BYTES_IN_FLOAT;
final int sizeOfFloatArray = bytesRef.length / BYTES_IN_FLOAT;
final float[] vector = new float[sizeOfFloatArray];
ByteBuffer.wrap(vectorAsByteArray).asFloatBuffer().get(vector);
ByteBuffer.wrap(bytesRef.bytes, bytesRef.offset, bytesRef.length).asFloatBuffer().get(vector);
return vector;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

package org.opensearch.knn.index.codec.util;

import java.io.ByteArrayInputStream;
import org.apache.lucene.util.BytesRef;

/**
* Interface abstracts the vector serializer object that is responsible for serialization and de-serialization of k-NN vector
Expand All @@ -20,8 +20,9 @@ public interface KNNVectorSerializer {

/**
* Deserializes all bytes from the given bytes ref to an array of floats
* @param byteStream stream of bytes that will be used for deserialization to array of floats
*
* @param bytesRef bytes that will be used for deserialization to array of floats
* @return array of floats deserialized from the bytes ref
*/
float[] byteToFloatArray(ByteArrayInputStream byteStream);
float[] byteToFloatArray(BytesRef bytesRef);
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
package org.opensearch.knn.index.codec.util;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.util.BytesRef;

import java.io.ByteArrayInputStream;
import java.io.ObjectStreamConstants;
import java.util.Arrays;
import java.util.Map;

import static org.opensearch.knn.index.codec.util.SerializationMode.ARRAY;
Expand Down Expand Up @@ -51,25 +50,24 @@ public static KNNVectorSerializer getDefaultSerializer() {
return getSerializerBySerializationMode(COLLECTION_OF_FLOATS);
}

public static KNNVectorSerializer getSerializerByStreamContent(final ByteArrayInputStream byteStream) {
final SerializationMode serializationMode = getSerializerModeFromStream(byteStream);
public static KNNVectorSerializer getSerializerByBytesRef(final BytesRef bytesRef) {
final SerializationMode serializationMode = getSerializerModeFromBytesRef(bytesRef);
return getSerializerBySerializationMode(serializationMode);
}

public static SerializationMode getSerializerModeFromStream(ByteArrayInputStream byteStream) {
int numberOfAvailableBytesInStream = byteStream.available();
if (numberOfAvailableBytesInStream < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
public static SerializationMode getSerializerModeFromBytesRef(BytesRef bytesRef) {
int numberOfAvailableBytes = bytesRef.length;
if (numberOfAvailableBytes < ARRAY_HEADER_OFFSET) {
return getSerializerOrThrowError(numberOfAvailableBytes, COLLECTION_OF_FLOATS);
}
final byte[] byteArray = new byte[SERIALIZATION_PROTOCOL_HEADER_PREFIX.length];
byteStream.read(byteArray, 0, SERIALIZATION_PROTOCOL_HEADER_PREFIX.length);
byteStream.reset();
// checking if stream protocol grammar in header is valid for serialized array
if (Arrays.equals(SERIALIZATION_PROTOCOL_HEADER_PREFIX, byteArray)) {
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytesInStream - ARRAY_HEADER_OFFSET;
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY);

for (int i = 0; i < SERIALIZATION_PROTOCOL_HEADER_PREFIX.length; i++) {
if (bytesRef.bytes[i + bytesRef.offset] != SERIALIZATION_PROTOCOL_HEADER_PREFIX[i]) {
return getSerializerOrThrowError(numberOfAvailableBytes, COLLECTION_OF_FLOATS);
}
}
return getSerializerOrThrowError(numberOfAvailableBytesInStream, COLLECTION_OF_FLOATS);
int numberOfAvailableBytesAfterHeader = numberOfAvailableBytes - ARRAY_HEADER_OFFSET;
return getSerializerOrThrowError(numberOfAvailableBytesAfterHeader, ARRAY);
}

private static SerializationMode getSerializerOrThrowError(int numberOfRemainingBytes, final SerializationMode serializationMode) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import org.opensearch.knn.index.codec.util.KNNVectorSerializer;
import org.opensearch.knn.index.codec.util.KNNVectorSerializerFactory;

import java.io.ByteArrayInputStream;
import java.io.IOException;

/**
Expand Down Expand Up @@ -72,9 +71,8 @@ public float score() {

protected float computeScore() throws IOException {
final BytesRef value = binaryDocValues.binaryValue();
final ByteArrayInputStream byteStream = new ByteArrayInputStream(value.bytes, value.offset, value.length);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByStreamContent(byteStream);
final float[] vector = vectorSerializer.byteToFloatArray(byteStream);
final KNNVectorSerializer vectorSerializer = KNNVectorSerializerFactory.getSerializerByBytesRef(value);
final float[] vector = vectorSerializer.byteToFloatArray(value);
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@

import junit.framework.TestCase;
import lombok.SneakyThrows;
import org.apache.lucene.util.BytesRef;
import org.opensearch.knn.index.codec.util.SerializationMode;
import org.opensearch.knn.jni.JNICommons;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.util.Random;

Expand All @@ -19,17 +19,17 @@
public class VectorTransferByteTests extends TestCase {
@SneakyThrows
public void testTransfer_whenCalled_thenAdded() {
final ByteArrayInputStream bais1 = getByteArrayOfVectors(20);
final ByteArrayInputStream bais2 = getByteArrayOfVectors(20);
final BytesRef bytesRef1 = getByteArrayOfVectors(20);
final BytesRef bytesRef2 = getByteArrayOfVectors(20);
VectorTransferByte vectorTransfer = new VectorTransferByte(1000);
try {
vectorTransfer.init(2);

vectorTransfer.transfer(bais1);
vectorTransfer.transfer(bytesRef1);
// flush is not called
assertEquals(0, vectorTransfer.getVectorAddress());

vectorTransfer.transfer(bais2);
vectorTransfer.transfer(bytesRef2);
// flush should be called
assertNotEquals(0, vectorTransfer.getVectorAddress());
} finally {
Expand All @@ -41,16 +41,16 @@ public void testTransfer_whenCalled_thenAdded() {

@SneakyThrows
public void testSerializationMode_whenCalled_thenReturn() {
final ByteArrayInputStream bais = getByteArrayOfVectors(20);
final BytesRef bytesRef = getByteArrayOfVectors(20);
VectorTransferByte vectorTransfer = new VectorTransferByte(1000);

// Verify
assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bais));
assertEquals(SerializationMode.COLLECTIONS_OF_BYTES, vectorTransfer.getSerializationMode(bytesRef));
}

private ByteArrayInputStream getByteArrayOfVectors(int vectorLength) throws IOException {
private BytesRef getByteArrayOfVectors(int vectorLength) throws IOException {
byte[] vector = new byte[vectorLength];
new Random().nextBytes(vector);
return new ByteArrayInputStream(vector);
return new BytesRef(vector);
}
}
Loading

0 comments on commit e5b90ce

Please sign in to comment.