From bbaaaf900e79e311dea3d5ff5fca2bf3976e7655 Mon Sep 17 00:00:00 2001 From: Vikasht34 Date: Mon, 26 Aug 2024 08:00:09 -0700 Subject: [PATCH] Integration of Quantization Framework for Binary Quantization with Indexing Flow (#1996) * Integration of Quantization Framework for Binary Quantization with Indexing Flow Signed-off-by: VIKASH TIWARI * Integration With Qunatization Config Signed-off-by: VIKASH TIWARI --------- Signed-off-by: VIKASH TIWARI --- .../NativeEngines990KnnVectorsWriter.java | 132 ++++++++++++--- .../DefaultIndexBuildStrategy.java | 38 +++-- .../codec/nativeindex/IndexBuildSetup.java | 40 +++++ .../MemOptimizedNativeIndexBuildStrategy.java | 39 +++-- .../codec/nativeindex/NativeIndexWriter.java | 65 ++++++- .../nativeindex/QuantizationIndexUtils.java | 68 ++++++++ .../nativeindex/model/BuildIndexParams.java | 8 +- .../engine/faiss/AbstractFaissMethod.java | 4 +- .../KNNVectorQuantizationTrainingRequest.java | 55 ++++++ .../QuantizationService.java | 128 ++++++++++++++ .../DefaultQuantizationState.java | 15 ++ .../MultiBitScalarQuantizationState.java | 46 +++++ .../OneBitScalarQuantizationState.java | 33 ++++ .../quantizationState/QuantizationState.java | 24 +++ .../models/requests/TrainingRequest.java | 6 +- .../quantizer/MultiBitScalarQuantizer.java | 16 +- .../quantizer/OneBitScalarQuantizer.java | 4 +- .../knn/quantization/quantizer/Quantizer.java | 4 +- .../quantizer/QuantizerHelper.java | 60 +++++-- ...NativeEngines990KnnVectorsFormatTests.java | 72 +++++++- .../DefaultIndexBuildStrategyTests.java | 116 +++++++++++++ ...ptimizedNativeIndexBuildStrategyTests.java | 118 +++++++++++++ .../QuantizationIndexUtilsTests.java | 109 ++++++++++++ .../knn/index/engine/faiss/FaissTests.java | 4 +- .../QuantizationServiceTests.java | 159 ++++++++++++++++++ .../opensearch/knn/integ/QFrameworkIT.java | 36 ++-- .../factory/QuantizerFactoryTests.java | 48 ++---- .../factory/QuantizerRegistryTests.java | 28 +-- .../QuantizationStateTests.java | 73 ++++++++ .../MultiBitScalarQuantizerTests.java | 8 +- .../quantizer/OneBitScalarQuantizerTests.java | 31 +++- 31 files changed, 1414 insertions(+), 173 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java create mode 100644 src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java create mode 100644 src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java create mode 100644 src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java create mode 100644 src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java create mode 100644 src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 65736a63e..43f4d7ad6 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -24,10 +24,13 @@ import org.apache.lucene.index.Sorter; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.NativeIndexWriter; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.util.ArrayList; @@ -46,6 +49,7 @@ public class NativeEngines990KnnVectorsWriter extends KnnVectorsWriter { private final FlatVectorsWriter flatVectorsWriter; private final List> fields = new ArrayList<>(); private boolean finished; + private final QuantizationService quantizationService = QuantizationService.getInstance(); /** * Add new field for indexing. @@ -68,17 +72,14 @@ public KnnFieldVectorsWriter addField(final FieldInfo fieldInfo) throws IOExc */ @Override public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { - // simply write data in the flat file flatVectorsWriter.flush(maxDoc, sortMap); for (final NativeEngineFieldVectorsWriter field : fields) { - final VectorDataType vectorDataType = extractVectorDataType(field.getFieldInfo()); - final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( - vectorDataType, - field.getDocsWithField(), - field.getVectors() + trainAndIndex( + field.getFieldInfo(), + (vectorDataType, fieldInfo, fieldVectorsWriter) -> getKNNVectorValues(vectorDataType, fieldVectorsWriter), + NativeIndexWriter::flushIndex, + field ); - - NativeIndexWriter.getWriter(field.getFieldInfo(), segmentWriteState).flushIndex(knnVectorValues); } } @@ -86,24 +87,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState) throws IOException { // This will ensure that we are merging the FlatIndex during force merge. flatVectorsWriter.mergeOneField(fieldInfo, mergeState); - // For merge, pick values from flat vector and reindex again. This will use the flush operation to create graphs - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); - final KNNVectorValues knnVectorValues; - switch (fieldInfo.getVectorEncoding()) { - case FLOAT32: - final FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); - break; - case BYTE: - final ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - knnVectorValues = KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); - break; - default: - throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); - } + trainAndIndex(fieldInfo, this::getKNNVectorValuesForMerge, NativeIndexWriter::mergeIndex, mergeState); - NativeIndexWriter.getWriter(fieldInfo, segmentWriteState).mergeIndex(knnVectorValues); } /** @@ -146,4 +132,102 @@ public long ramBytesUsed() { .sum(); } + /** + * Retrieves the {@link KNNVectorValues} for a specific field based on the vector data type and field writer. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. + * @param field The {@link NativeEngineFieldVectorsWriter} representing the field from which to retrieve vectors. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field. + */ + private KNNVectorValues getKNNVectorValues(final VectorDataType vectorDataType, final NativeEngineFieldVectorsWriter field) { + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, field.getDocsWithField(), field.getVectors()); + } + + /** + * Retrieves the {@link KNNVectorValues} for a specific field during a merge operation, based on the vector data type. + * + * @param vectorDataType The {@link VectorDataType} representing the type of vectors stored. + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param mergeState The {@link MergeState} representing the state of the merge operation. + * @param The type of vectors being processed. + * @return The {@link KNNVectorValues} associated with the field during the merge. + * @throws IOException If an I/O error occurs during the retrieval. + */ + private KNNVectorValues getKNNVectorValuesForMerge( + final VectorDataType vectorDataType, + final FieldInfo fieldInfo, + final MergeState mergeState + ) throws IOException { + switch (fieldInfo.getVectorEncoding()) { + case FLOAT32: + FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedFloats); + case BYTE: + ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); + return (KNNVectorValues) KNNVectorValuesFactory.getVectorValues(vectorDataType, mergedBytes); + default: + throw new IllegalStateException("Unsupported vector encoding [" + fieldInfo.getVectorEncoding() + "]"); + } + } + + /** + * Functional interface representing an operation that indexes the provided {@link KNNVectorValues}. + * + * @param The type of vectors being processed. + */ + @FunctionalInterface + private interface IndexOperation { + void buildAndWrite(NativeIndexWriter writer, KNNVectorValues knnVectorValues) throws IOException; + } + + /** + * Functional interface representing a method that retrieves {@link KNNVectorValues} based on + * the vector data type, field information, and the merge state. + * + * @param The type of the data representing the vector (e.g., {@link VectorDataType}). + * @param The metadata about the field. + * @param The state of the merge operation. + * @param The result of the retrieval, typically {@link KNNVectorValues}. + */ + @FunctionalInterface + private interface VectorValuesRetriever { + Result apply(DataType vectorDataType, FieldInfo fieldInfo, MergeState mergeState) throws IOException; + } + + /** + * Unified method for processing a field during either the indexing or merge operation. This method retrieves vector values + * based on the provided vector data type and applies the specified index operation, potentially including quantization if needed. + * + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field. + * @param vectorValuesRetriever A functional interface that retrieves {@link KNNVectorValues} based on the vector data type, + * field information, and additional context (e.g., merge state or field writer). + * @param indexOperation A functional interface that performs the indexing operation using the retrieved + * {@link KNNVectorValues}. + * @param VectorProcessingContext The additional context required for retrieving the vector values (e.g., {@link MergeState} or {@link NativeEngineFieldVectorsWriter}). + * From Flush we need NativeFieldWriter which contains total number of vectors while from Merge we need merge state which contains vector information + * @param The type of vectors being processed. + * @param The type of the context needed for retrieving the vector values. + * @throws IOException If an I/O error occurs during the processing. + */ + private void trainAndIndex( + final FieldInfo fieldInfo, + final VectorValuesRetriever> vectorValuesRetriever, + final IndexOperation indexOperation, + final C VectorProcessingContext + ) throws IOException { + final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + KNNVectorValues knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); + QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); + QuantizationState quantizationState = null; + if (quantizationParams != null) { + quantizationState = quantizationService.train(quantizationParams, knnVectorValues); + } + NativeIndexWriter writer = (quantizationParams != null) + ? NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState) + : NativeIndexWriter.getWriter(fieldInfo, segmentWriteState); + + knnVectorValues = vectorValuesRetriever.apply(vectorDataType, fieldInfo, VectorProcessingContext); + indexOperation.buildAndWrite(writer, knnVectorValues); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 5787ea76b..d2a6027db 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -39,16 +39,32 @@ public static DefaultIndexBuildStrategy getInstance() { return INSTANCE; } + /** + * Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both + * quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services. + * + *

The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is + * enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are + * flushed and used to build the index. The index is then written to the specified path using JNI calls.

+ * + * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. + * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. + * @throws IOException If an I/O error occurs during the process of building and writing the index. + */ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { - iterateVectorValuesOnce(knnVectorValues); // to get bytesPerVector - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + // Needed to make sure we don't get 0 dimensions while initializing index + iterateVectorValuesOnce(knnVectorValues); + IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); + + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { + final List transferredDocIds = new ArrayList<>((int) knnVectorValues.totalLiveDocs()); - final List tranferredDocIds = new ArrayList<>(); while (knnVectorValues.docId() != NO_MORE_DOCS) { + Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); // append is true here so off heap memory buffer isn't overwritten - vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), true); - tranferredDocIds.add(knnVectorValues.docId()); + vectorTransfer.transfer(vector, true); + transferredDocIds.add(knnVectorValues.docId()); knnVectorValues.nextDoc(); } vectorTransfer.flush(true); @@ -60,12 +76,12 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector if (params.containsKey(MODEL_ID)) { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexInfo.getIndexPath(), (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), - indexInfo.getParameters(), + params, indexInfo.getKnnEngine() ); return null; @@ -73,11 +89,11 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector } else { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexInfo.getIndexPath(), - indexInfo.getParameters(), + params, indexInfo.getKnnEngine() ); return null; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java new file mode 100644 index 000000000..c6c999c07 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/IndexBuildSetup.java @@ -0,0 +1,40 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +/** + * IndexBuildSetup encapsulates the configuration and parameters required for building an index. + * This includes the size of each vector, the dimensions of the vectors, and any quantization-related + * settings such as the output and state of quantization. + */ +@Getter +@AllArgsConstructor +final class IndexBuildSetup { + /** + * The number of bytes per vector. + */ + private final int bytesPerVector; + + /** + * Dimension of Vector for Indexing + */ + private final int dimensions; + + /** + * The quantization output that will hold the quantized vector. + */ + private final QuantizationOutput quantizationOutput; + + /** + * The state of quantization, which may include parameters and trained models. + */ + private final QuantizationState quantizationState; +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index af80215b6..1115bfe05 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -40,45 +40,60 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { return INSTANCE; } + /** + * Builds and writes a k-NN index using the provided vector values and index parameters. This method handles both + * quantized and non-quantized vectors, transferring them off-heap before building the index using native JNI services. + * + *

The method first iterates over the vector values to calculate the necessary bytes per vector. If quantization is + * enabled, the vectors are quantized before being transferred off-heap. Once all vectors are transferred, they are + * flushed and used to build the index. The index is then written to the specified path using JNI calls.

+ * + * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. + * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. + * @throws IOException If an I/O error occurs during the process of building and writing the index. + */ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVectorValues knnVectorValues) throws IOException { - // Needed to make sure we dont get 0 dimensions while initializing index + // Needed to make sure we don't get 0 dimensions while initializing index iterateVectorValuesOnce(knnVectorValues); KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); + IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); // Initialize the index long indexMemoryAddress = AccessController.doPrivileged( (PrivilegedAction) () -> JNIService.initIndex( knnVectorValues.totalLiveDocs(), - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexParameters, engine ) ); - int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / knnVectorValues.bytesPerVector()); + int transferLimit = (int) Math.max(1, KNNSettings.getVectorStreamingMemoryLimit().getBytes() / indexBuildSetup.getBytesPerVector()); try (final OffHeapVectorTransfer vectorTransfer = getVectorTransfer(indexInfo.getVectorDataType(), transferLimit)) { - final List tranferredDocIds = new ArrayList<>(transferLimit); + final List transferredDocIds = new ArrayList<>(transferLimit); + while (knnVectorValues.docId() != NO_MORE_DOCS) { + Object vector = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, indexBuildSetup); // append is false to be able to reuse the memory location - boolean transferred = vectorTransfer.transfer(knnVectorValues.conditionalCloneVector(), false); - tranferredDocIds.add(knnVectorValues.docId()); + boolean transferred = vectorTransfer.transfer(vector, false); + transferredDocIds.add(knnVectorValues.docId()); if (transferred) { // Insert vectors long vectorAddress = vectorTransfer.getVectorAddress(); AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.insertToIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexParameters, indexMemoryAddress, engine ); return null; }); - tranferredDocIds.clear(); + transferredDocIds.clear(); } knnVectorValues.nextDoc(); } @@ -89,16 +104,16 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo, final KNNVector long vectorAddress = vectorTransfer.getVectorAddress(); AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.insertToIndex( - intListToArray(tranferredDocIds), + intListToArray(transferredDocIds), vectorAddress, - knnVectorValues.dimension(), + indexBuildSetup.getDimensions(), indexParameters, indexMemoryAddress, engine ); return null; }); - tranferredDocIds.clear(); + transferredDocIds.clear(); } // Write vector diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 61500371b..ed0e8149a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -12,6 +12,7 @@ import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.store.FilterDirectory; +import org.opensearch.common.Nullable; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.DeprecationHandler; @@ -23,11 +24,13 @@ import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; import java.io.OutputStream; @@ -60,6 +63,8 @@ public class NativeIndexWriter { private final SegmentWriteState state; private final FieldInfo fieldInfo; private final NativeIndexBuildStrategy indexBuilder; + @Nullable + private final QuantizationState quantizationState; /** * Gets the correct writer type from fieldInfo @@ -68,13 +73,29 @@ public class NativeIndexWriter { * @return correct NativeIndexWriter to make index specified in fieldInfo */ public static NativeIndexWriter getWriter(final FieldInfo fieldInfo, SegmentWriteState state) { - final KNNEngine knnEngine = extractKNNEngine(fieldInfo); - boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); - boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; - if (iterative) { - return new NativeIndexWriter(state, fieldInfo, MemOptimizedNativeIndexBuildStrategy.getInstance()); - } - return new NativeIndexWriter(state, fieldInfo, DefaultIndexBuildStrategy.getInstance()); + return createWriter(fieldInfo, state, null); + } + + /** + * Gets the correct writer type for the specified field, using a given QuantizationModel. + * + * This method returns a NativeIndexWriter instance that is tailored to the specific characteristics + * of the field described by the provided FieldInfo. It determines whether to use a template-based + * writer or an iterative approach based on the engine type and whether the field is associated with a template. + * + * If quantization is required, the QuantizationModel is passed to the writer to facilitate the quantization process. + * + * @param fieldInfo The FieldInfo object containing metadata about the field for which the writer is needed. + * @param state The SegmentWriteState representing the current segment's writing context. + * @param quantizationState The QuantizationState that contains quantization state required for quantization + * @return A NativeIndexWriter instance appropriate for the specified field, configured with or without quantization. + */ + public static NativeIndexWriter getWriter( + final FieldInfo fieldInfo, + final SegmentWriteState state, + final QuantizationState quantizationState + ) { + return createWriter(fieldInfo, state, quantizationState); } /** @@ -137,7 +158,12 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues) throws // TODO: Refactor this so its scalable. Possibly move it out of this class private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNEngine knnEngine) throws IOException { final Map parameters; - final VectorDataType vectorDataType = extractVectorDataType(fieldInfo); + VectorDataType vectorDataType; + if (quantizationState != null) { + vectorDataType = QuantizationService.getInstance().getVectorDataTypeForTransfer(fieldInfo); + } else { + vectorDataType = extractVectorDataType(fieldInfo); + } if (fieldInfo.attributes().containsKey(MODEL_ID)) { Model model = getModel(fieldInfo); parameters = getTemplateParameters(fieldInfo, model); @@ -151,6 +177,7 @@ private BuildIndexParams indexParams(FieldInfo fieldInfo, String indexPath, KNNE .vectorDataType(vectorDataType) .knnEngine(knnEngine) .indexPath(indexPath) + .quantizationState(quantizationState) .build(); } @@ -295,4 +322,26 @@ private void writeFooter(String indexPath, String engineFileName, SegmentWriteSt os.write(byteBuffer.array()); os.close(); } + + /** + * Helper method to create the appropriate NativeIndexWriter based on the field info and quantization state. + * + * @param fieldInfo The FieldInfo object containing metadata about the field for which the writer is needed. + * @param state The SegmentWriteState representing the current segment's writing context. + * @param quantizationState The QuantizationState that contains quantization state required for quantization, can be null. + * @return A NativeIndexWriter instance appropriate for the specified field, configured with or without quantization. + */ + private static NativeIndexWriter createWriter( + final FieldInfo fieldInfo, + final SegmentWriteState state, + @Nullable final QuantizationState quantizationState + ) { + final KNNEngine knnEngine = extractKNNEngine(fieldInfo); + boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); + boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + NativeIndexBuildStrategy strategy = iterative + ? MemOptimizedNativeIndexBuildStrategy.getInstance() + : DefaultIndexBuildStrategy.getInstance(); + return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java new file mode 100644 index 000000000..8fec1af6d --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import lombok.experimental.UtilityClass; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; + +@UtilityClass +class QuantizationIndexUtils { + + /** + * Processes and returns the vector based on whether quantization is applied or not. + * + * @param knnVectorValues the KNN vector values to be processed. + * @param indexBuildSetup the setup containing quantization state and output, along with other parameters. + * @return the processed vector, either quantized or original. + * @throws IOException if an I/O error occurs during processing. + */ + static Object processAndReturnVector(KNNVectorValues knnVectorValues, IndexBuildSetup indexBuildSetup) throws IOException { + QuantizationService quantizationService = QuantizationService.getInstance(); + if (indexBuildSetup.getQuantizationState() != null && indexBuildSetup.getQuantizationOutput() != null) { + quantizationService.quantize( + indexBuildSetup.getQuantizationState(), + knnVectorValues.getVector(), + indexBuildSetup.getQuantizationOutput() + ); + return indexBuildSetup.getQuantizationOutput().getQuantizedVector(); + } else { + return knnVectorValues.conditionalCloneVector(); + } + } + + /** + * Prepares the quantization setup including bytes per vector and dimensions. + * + * @param knnVectorValues the KNN vector values. + * @param indexInfo the index build parameters. + * @return an instance of QuantizationSetup containing relevant information. + */ + static IndexBuildSetup prepareIndexBuild(KNNVectorValues knnVectorValues, BuildIndexParams indexInfo) { + QuantizationState quantizationState = indexInfo.getQuantizationState(); + QuantizationOutput quantizationOutput = null; + QuantizationService quantizationService = QuantizationService.getInstance(); + + int bytesPerVector; + int dimensions; + + if (quantizationState != null) { + bytesPerVector = quantizationState.getBytesPerVector(); + dimensions = quantizationState.getDimensions(); + quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams()); + } else { + bytesPerVector = knnVectorValues.bytesPerVector(); + dimensions = knnVectorValues.dimension(); + } + + return new IndexBuildSetup(bytesPerVector, dimensions, quantizationOutput, quantizationState); + } +} diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index af43ff37e..78674c64b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -8,8 +8,10 @@ import lombok.Builder; import lombok.ToString; import lombok.Value; +import org.opensearch.common.Nullable; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.util.Map; @@ -22,5 +24,9 @@ public class BuildIndexParams { String indexPath; VectorDataType vectorDataType; Map parameters; - // TODO: Add quantization state as parameter to build index + /** + * An optional quantization state that contains required information for quantization + */ + @Nullable + QuantizationState quantizationState; } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 8c2dfc126..7ae403445 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -99,9 +99,7 @@ static KNNLibraryIndexingContext adjustIndexDescription( // We need to update the prefix used to create the faiss index if we are using the quantization // framework if (encoderContext != null && Objects.equals(encoderContext.getName(), QFrameBitEncoder.NAME)) { - // TODO: Uncomment to use Quantization framework - // leaving commented now just so it wont fail creating faiss indices. - // prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; + prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java new file mode 100644 index 000000000..4cf68d16c --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/quantizationService/KNNVectorQuantizationTrainingRequest.java @@ -0,0 +1,55 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; + +import java.io.IOException; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +/** + * KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class. + * It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID. + */ +@Log4j2 +final class KNNVectorQuantizationTrainingRequest extends TrainingRequest { + + private final KNNVectorValues knnVectorValues; + private int lastIndex; + + /** + * Constructs a new QuantizationFloatVectorTrainingRequest. + * + * @param knnVectorValues the KNNVectorValues instance containing the vectors. + */ + KNNVectorQuantizationTrainingRequest(KNNVectorValues knnVectorValues) { + super((int) knnVectorValues.totalLiveDocs()); + this.knnVectorValues = knnVectorValues; + this.lastIndex = 0; + } + + /** + * Retrieves the vector associated with the specified document ID. + * + * @param position the document ID. + * @return the float vector corresponding to the specified document ID, or null if the docId is invalid. + */ + @Override + public T getVectorAtThePosition(int position) throws IOException { + while (lastIndex <= position) { + lastIndex++; + if (knnVectorValues.docId() == NO_MORE_DOCS) { + return null; + } + knnVectorValues.nextDoc(); + } + // Return the vector and the updated index + return knnVectorValues.getVector(); + } +} diff --git a/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java new file mode 100644 index 000000000..a9e3cc715 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/quantizationService/QuantizationService.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import lombok.AccessLevel; +import lombok.NoArgsConstructor; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.factory.QuantizerFactory; +import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.quantizer.Quantizer; +import java.io.IOException; + +import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig; + +/** + * A singleton class responsible for handling the quantization process, including training a quantizer + * and applying quantization to vectors. This class is designed to be thread-safe. + * + * @param The type of the input vectors to be quantized. + * @param The type of the quantized output vectors. + */ +@NoArgsConstructor(access = AccessLevel.PRIVATE) +public final class QuantizationService { + + /** + * The singleton instance of the {@link QuantizationService} class. + */ + private static final QuantizationService INSTANCE = new QuantizationService<>(); + + /** + * Returns the singleton instance of the {@link QuantizationService} class. + * + * @param The type of the input vectors to be quantized. + * @param The type of the quantized output vectors. + * @return The singleton instance of {@link QuantizationService}. + */ + public static QuantizationService getInstance() { + return (QuantizationService) INSTANCE; + } + + /** + * Trains a quantizer using the provided {@link KNNVectorValues} and returns the resulting + * {@link QuantizationState}. The quantizer is determined based on the given {@link QuantizationParams}. + * + * @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization. + * @param knnVectorValues The {@link KNNVectorValues} representing the vector data to be used for training. + * @return The {@link QuantizationState} containing the state of the trained quantizer. + * @throws IOException If an I/O error occurs during the training process. + */ + public QuantizationState train(final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues) + throws IOException { + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); + + // Create the training request from the vector values + KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues); + + // Train the quantizer and return the quantization state + return quantizer.train(trainingRequest); + } + + /** + * Applies quantization to the given vector using the specified {@link QuantizationState} and + * {@link QuantizationOutput}. + * + * @param quantizationState The {@link QuantizationState} containing the state of the trained quantizer. + * @param vector The vector to be quantized. + * @param quantizationOutput The {@link QuantizationOutput} to store the quantized vector. + * @return The quantized vector as an object of type {@code R}. + */ + public R quantize(final QuantizationState quantizationState, final T vector, final QuantizationOutput quantizationOutput) { + Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationState.getQuantizationParams()); + quantizer.quantize(vector, quantizationState, quantizationOutput); + return quantizationOutput.getQuantizedVector(); + } + + /** + * Retrieves quantization parameters from the FieldInfo. + */ + public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) { + QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); + if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { + return new ScalarQuantizationParams(quantizationConfig.getQuantizationType()); + } + return null; + } + + /** + * Retrieves the appropriate {@link VectorDataType} to be used during the transfer of vectors for indexing or merging. + * This method is intended to determine the correct vector data type based on the provided {@link FieldInfo}. + * + * @param fieldInfo The {@link FieldInfo} object containing metadata about the field for which the vector data type + * is being determined. + * @return The {@link VectorDataType} to be used during the vector transfer process + */ + public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { + QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); + if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { + return VectorDataType.BINARY; + } + return null; + } + + /** + * Creates the appropriate {@link QuantizationOutput} based on the given {@link QuantizationParams}. + * + * @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization. + * @return The {@link QuantizationOutput} corresponding to the provided parameters. + * @throws IllegalArgumentException If the quantization parameters are unsupported. + */ + @SuppressWarnings("unchecked") + public QuantizationOutput createQuantizationOutput(final QuantizationParams quantizationParams) { + if (quantizationParams instanceof ScalarQuantizationParams) { + ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams; + return (QuantizationOutput) new BinaryQuantizationOutput(scalarParams.getSqType().getId()); + } + throw new IllegalArgumentException("Unsupported quantization parameters: " + quantizationParams.getClass().getName()); + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java index 33e775cad..531a70851 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/DefaultQuantizationState.java @@ -64,4 +64,19 @@ public byte[] toByteArray() throws IOException { public static DefaultQuantizationState fromByteArray(final byte[] bytes) throws IOException, ClassNotFoundException { return (DefaultQuantizationState) QuantizationStateSerializer.deserialize(bytes, DefaultQuantizationState::new); } + + @Override + public int getBytesPerVector() { + return 0; + } + + @Override + public int getDimensions() { + return 0; + } + + @Override + public long ramBytesUsed() { + return 0; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java index 2778a6cf4..79ce7b955 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/MultiBitScalarQuantizationState.java @@ -8,6 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -124,4 +125,49 @@ public byte[] toByteArray() throws IOException { public static MultiBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { return (MultiBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, MultiBitScalarQuantizationState::new); } + + /** + * Calculates and returns the number of bytes stored per vector after quantization. + * + * @return the number of bytes stored per vector. + */ + @Override + public int getBytesPerVector() { + // Check if thresholds are null or have invalid structure + if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) { + throw new IllegalStateException("Error in getBytesStoredPerVector: The thresholds array is not initialized."); + } + + // Calculate the number of bytes required for multi-bit quantization + return thresholds.length * thresholds[0].length; + } + + @Override + public int getDimensions() { + // For multi-bit quantization, the dimension for indexing is the number of rows * columns in the thresholds array. + // Where number of column reprensents Dimesion of Original vector and number of rows equals to number of bits + // Check if thresholds are null or have invalid structure + if (thresholds == null || thresholds.length == 0 || thresholds[0] == null) { + throw new IllegalStateException("Error in getting Dimension: The thresholds array is not initialized."); + } + return thresholds.length * thresholds[0].length; + } + + /** + * Calculates the memory usage of the MultiBitScalarQuantizationState object in bytes. + * This method computes the shallow size of the instance itself, the shallow size of the + * quantization parameters, and the memory usage of the 2D thresholds array. + * + * @return The estimated memory usage of the MultiBitScalarQuantizationState object in bytes. + */ + @Override + public long ramBytesUsed() { + long size = RamUsageEstimator.shallowSizeOfInstance(MultiBitScalarQuantizationState.class); + size += RamUsageEstimator.shallowSizeOf(quantizationParams); + size += RamUsageEstimator.shallowSizeOf(thresholds); // shallow size of the 2D array (array of references to rows) + for (float[] row : thresholds) { + size += RamUsageEstimator.sizeOf(row); // size of each row in the 2D array + } + return size; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java index 9998b87e8..9c4ff7460 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/OneBitScalarQuantizationState.java @@ -8,6 +8,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.NoArgsConstructor; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.Version; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -107,4 +108,36 @@ public byte[] toByteArray() throws IOException { public static OneBitScalarQuantizationState fromByteArray(final byte[] bytes) throws IOException { return (OneBitScalarQuantizationState) QuantizationStateSerializer.deserialize(bytes, OneBitScalarQuantizationState::new); } + + /** + * Calculates and returns the number of bytes stored per vector after quantization. + * + * @return the number of bytes stored per vector. + */ + @Override + public int getBytesPerVector() { + // Calculate the number of bytes required for one-bit quantization + return meanThresholds.length; + } + + @Override + public int getDimensions() { + // For one-bit quantization, the dimension for indexing is just the length of the thresholds array. + return meanThresholds.length; + } + + /** + * Calculates the memory usage of the OneBitScalarQuantizationState object in bytes. + * This method computes the shallow size of the instance itself, the shallow size of the + * quantization parameters, and the memory usage of the mean thresholds array. + * + * @return The estimated memory usage of the OneBitScalarQuantizationState object in bytes. + */ + @Override + public long ramBytesUsed() { + long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class); + size += RamUsageEstimator.shallowSizeOf(quantizationParams); + size += RamUsageEstimator.sizeOf(meanThresholds); + return size; + } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java index e32df8bc3..18ee813fc 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/QuantizationState.java @@ -29,4 +29,28 @@ public interface QuantizationState extends Writeable { * @throws IOException if an I/O error occurs during serialization. */ byte[] toByteArray() throws IOException; + + /** + * Calculates the number of bytes stored per vector after quantization. + * This method can be overridden by implementing classes to provide the specific calculation. + * + * @return the number of bytes stored per vector. + */ + int getBytesPerVector(); + + /** + * Returns the effective dimension used for indexing after quantization. + * For one-bit quantization, this might correspond to the length of thresholds. + * For multi-bit quantization, this might correspond to rows * columns of the thresholds matrix. + * + * @return the effective dimension for indexing. + */ + int getDimensions(); + + /** + * Estimates the memory usage of the quantization state in bytes. + * + * @return the memory usage in bytes. + */ + long ramBytesUsed(); } diff --git a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java index 54ebe311c..d8b0eab10 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java +++ b/src/main/java/org/opensearch/knn/quantization/models/requests/TrainingRequest.java @@ -8,6 +8,8 @@ import lombok.AllArgsConstructor; import lombok.Getter; +import java.io.IOException; + /** * TrainingRequest represents a request for training a quantizer. * @@ -24,8 +26,8 @@ public abstract class TrainingRequest { /** * Returns the vector corresponding to the specified document ID. * - * @param docId the document ID. + * @param position the document position. * @return the vector corresponding to the specified document ID. */ - public abstract T getVectorByDocId(int docId); + public abstract T getVectorAtThePosition(int position) throws IOException; } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 12a5d1013..a0e6ec402 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -15,6 +15,9 @@ import org.opensearch.knn.quantization.sampler.Sampler; import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; +import oshi.util.tuples.Pair; + +import java.io.IOException; /** * MultiBitScalarQuantizer is responsible for quantizing vectors into multi-bit representations per dimension. @@ -105,14 +108,11 @@ public MultiBitScalarQuantizer(final int bitsPerCoordinate, final int samplingSi * @return a MultiBitScalarQuantizationState containing the computed thresholds. */ @Override - public QuantizationState train(final TrainingRequest trainingRequest) { + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - int dimension = trainingRequest.getVectorByDocId(sampledIndices[0]).length; - float[] meanArray = new float[dimension]; - float[] stdDevArray = new float[dimension]; // Calculate sum, mean, and standard deviation in one pass - QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices, meanArray, stdDevArray); - float[][] thresholds = calculateThresholds(meanArray, stdDevArray, dimension); + Pair meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices); + float[][] thresholds = calculateThresholds(meanAndStdDev.getA(), meanAndStdDev.getB()); ScalarQuantizationParams params = (bitsPerCoordinate == 2) ? new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT) : new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); @@ -148,10 +148,10 @@ public void quantize(final float[] vector, final QuantizationState state, final * * @param meanArray the mean for each dimension. * @param stdDevArray the standard deviation for each dimension. - * @param dimension the number of dimensions in the vectors. * @return the thresholds for quantization. */ - private float[][] calculateThresholds(final float[] meanArray, final float[] stdDevArray, final int dimension) { + private float[][] calculateThresholds(final float[] meanArray, final float[] stdDevArray) { + int dimension = meanArray.length; float[][] thresholds = new float[bitsPerCoordinate][dimension]; float coef = bitsPerCoordinate + 1; for (int i = 0; i < bitsPerCoordinate; i++) { diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index a0f6a26b4..ac48a9523 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -15,6 +15,8 @@ import org.opensearch.knn.quantization.sampler.SamplerType; import org.opensearch.knn.quantization.sampler.SamplingFactory; +import java.io.IOException; + /** * OneBitScalarQuantizer is responsible for quantizing vectors using a single bit per dimension. * It computes the mean of each dimension during training and then uses these means as thresholds @@ -56,7 +58,7 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) { * @return a OneBitScalarQuantizationState containing the calculated means. */ @Override - public QuantizationState train(final TrainingRequest trainingRequest) { + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); float[] meanThresholds = QuantizerHelper.calculateMeanThresholds(trainingRequest, sampledDocIds); return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index c0b297f5d..521863205 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -9,6 +9,8 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import java.io.IOException; + /** * The Quantizer interface defines the methods required for training and quantizing vectors * in the context of K-Nearest Neighbors (KNN) and similar machine learning tasks. @@ -27,7 +29,7 @@ public interface Quantizer { * @param trainingRequest the request containing data and parameters for training. * @return a QuantizationState containing the learned parameters. */ - QuantizationState train(TrainingRequest trainingRequest); + QuantizationState train(TrainingRequest trainingRequest) throws IOException; /** * Quantizes the provided vector using the specified quantization state. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java index 16f969973..bac2067c0 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/QuantizerHelper.java @@ -7,6 +7,9 @@ import org.opensearch.knn.quantization.models.requests.TrainingRequest; import lombok.experimental.UtilityClass; +import oshi.util.tuples.Pair; + +import java.io.IOException; /** * Utility class providing common methods for quantizer operations, such as parameter validation and @@ -19,16 +22,17 @@ class QuantizerHelper { * Calculates the mean vector from a set of sampled vectors. * * @param samplingRequest The {@link TrainingRequest} containing the dataset and methods to access vectors by their indices. - * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. + * @param sampledIndices An array of indices representing the sampled vectors to be used for mean calculation. * @return A float array representing the mean vector of the sampled vectors. * @throws IllegalArgumentException If any of the vectors at the sampled indices are null. - * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. + * @throws IllegalStateException If the mean array is unexpectedly null after processing the vectors. */ - static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) { + static float[] calculateMeanThresholds(TrainingRequest samplingRequest, int[] sampledIndices) throws IOException { int totalSamples = sampledIndices.length; float[] mean = null; + int lastIndex = 0; for (int docId : sampledIndices) { - float[] vector = samplingRequest.getVectorByDocId(docId); + float[] vector = samplingRequest.getVectorAtThePosition(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } @@ -49,36 +53,56 @@ static float[] calculateMeanThresholds(TrainingRequest samplingRequest, } /** - * Calculates the mean and StdDev per dimension for sampled vectors. + * Calculates the mean and standard deviation for each dimension of the vectors in the training request. + *

+ * This method processes the vectors specified by the sampled indices and calculates both the mean and + * standard deviation in one pass. The results are returned as a pair of arrays: one for the means and + * one for the standard deviations. * - * @param trainingRequest the request containing the data and parameters for training. - * @param sampledIndices the indices of the sampled vectors. - * @param meanArray the array to store the sum and then the mean of each dimension. - * @param stdDevArray the array to store the sum of squares and then the standard deviation of each dimension. + * @param trainingRequest The request containing the data and parameters for training. + * @param sampledIndices An array of document IDs representing the sampled indices to be processed. + * @return A Pair containing two float arrays: the first array represents the mean of each dimension, + * and the second array represents the standard deviation of each dimension. + * @throws IllegalArgumentException if any of the vectors at the sampled indices are null. + * @throws IllegalStateException if the mean or standard deviation arrays are not initialized after processing. */ - static void calculateMeanAndStdDev( - TrainingRequest trainingRequest, - int[] sampledIndices, - float[] meanArray, - float[] stdDevArray - ) { + static Pair calculateMeanAndStdDev(TrainingRequest trainingRequest, int[] sampledIndices) + throws IOException { + float[] meanArray = null; + float[] stdDevArray = null; int totalSamples = sampledIndices.length; - int dimension = meanArray.length; + int lastIndex = 0; for (int docId : sampledIndices) { - float[] vector = trainingRequest.getVectorByDocId(docId); + float[] vector = trainingRequest.getVectorAtThePosition(docId); if (vector == null) { throw new IllegalArgumentException("Vector at sampled index " + docId + " is null."); } + int dimension = vector.length; + + // Initialize meanArray and stdDevArray on the first iteration + if (meanArray == null) { + meanArray = new float[dimension]; + } + if (stdDevArray == null) { + stdDevArray = new float[dimension]; + } + for (int j = 0; j < dimension; j++) { meanArray[j] += vector[j]; stdDevArray[j] += vector[j] * vector[j]; } } + if (meanArray == null || stdDevArray == null) { + throw new IllegalStateException("Mean and StdDev should not be null after processing vectors."); + } // Calculate mean and standard deviation in one pass - for (int j = 0; j < dimension; j++) { + for (int j = 0; j < meanArray.length; j++) { meanArray[j] = meanArray[j] / totalSamples; stdDevArray[j] = (float) Math.sqrt((stdDevArray[j] / totalSamples) - (meanArray[j] * meanArray[j])); } + + // Return both arrays as a Pair + return new Pair<>(meanArray, stdDevArray); } } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java index 3810d46fd..85b5d07e6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsFormatTests.java @@ -47,8 +47,11 @@ import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.engine.qframe.QuantizationConfigParser; import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; import java.util.Arrays; @@ -61,6 +64,7 @@ public class NativeEngines990KnnVectorsFormatTests extends KNNTestCase { private static final String FLAT_VECTOR_FILE_EXT = ".vec"; private static final String HNSW_FILE_EXT = ".hnsw"; private static final String FLOAT_VECTOR_FIELD = "float_field"; + private static final String FLOAT_VECTOR_FIELD_BINARY = "float_field_binary"; private static final String BYTE_VECTOR_FIELD = "byte_field"; private Directory dir; private RandomIndexWriter indexWriter; @@ -99,14 +103,32 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc float[] floatVector = { 1.0f, 3.0f, 4.0f }; byte[] byteVector = { 6, 14 }; + FieldType fieldTypeForFloat = createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForFloat.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForFloat.freeze(); + addFieldToIndex(new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, fieldTypeForFloat), indexWriter); + FieldType fieldTypeForByte = createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY); + fieldTypeForByte.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); + fieldTypeForByte.freeze(); + addFieldToIndex(new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, fieldTypeForByte), indexWriter); + + float[] floatVectorForBinaryQuantization_1 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + float[] floatVectorForBinaryQuantization_2 = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); + fieldTypeForBinaryQuantization.freeze(); + addFieldToIndex( - new KnnFloatVectorField(FLOAT_VECTOR_FIELD, floatVector, createVectorField(3, VectorEncoding.FLOAT32, VectorDataType.FLOAT)), + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization_1, fieldTypeForBinaryQuantization), indexWriter ); addFieldToIndex( - new KnnByteVectorField(BYTE_VECTOR_FIELD, byteVector, createVectorField(2, VectorEncoding.BYTE, VectorDataType.BINARY)), + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization_2, fieldTypeForBinaryQuantization), indexWriter ); + final IndexReader indexReader = indexWriter.getReader(); // ensuring segments are created indexWriter.flush(); @@ -129,7 +151,7 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc if (segmentReader.getSegmentInfo().info.getUseCompoundFile() == false) { final List vecfiles = getFilesFromSegment(dir, FLAT_VECTOR_FILE_EXT); // 2 .vec files will be created as we are using per field vectors format. - assertEquals(2, vecfiles.size()); + assertEquals(3, vecfiles.size()); } final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD); @@ -144,6 +166,12 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc assertEquals(1, byteVectorValues.size()); assertEquals(2, byteVectorValues.dimension()); + final FloatVectorValues floatVectorValuesForBinaryQuantization = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); + floatVectorValuesForBinaryQuantization.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization_1, floatVectorValuesForBinaryQuantization.vectorValue(), 0.0f); + assertEquals(2, floatVectorValuesForBinaryQuantization.size()); + assertEquals(8, floatVectorValuesForBinaryQuantization.dimension()); + Assert.assertThrows( UnsupportedOperationException.class, () -> leafReader.searchNearestVectors(FLOAT_VECTOR_FIELD, floatVector, 10, new Bits.MatchAllBits(1), 10) @@ -157,6 +185,42 @@ public void testNativeEngineVectorFormat_whenMultipleVectorFieldIndexed_thenSucc indexReader.close(); } + @SneakyThrows + public void testNativeEngineVectorFormat_whenBinaryQuantizationApplied_thenSuccess() { + setup(); + float[] floatVectorForBinaryQuantization = { 1.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f }; + FieldType fieldTypeForBinaryQuantization = createVectorField(8, VectorEncoding.FLOAT32, VectorDataType.FLOAT); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"BHNSW32\", \"spaceType\": \"l2\"}"); + QuantizationConfig quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.ONE_BIT).build(); + fieldTypeForBinaryQuantization.putAttribute(KNNConstants.QFRAMEWORK_CONFIG, QuantizationConfigParser.toCsv(quantizationConfig)); + + addFieldToIndex( + new KnnFloatVectorField(FLOAT_VECTOR_FIELD_BINARY, floatVectorForBinaryQuantization, fieldTypeForBinaryQuantization), + indexWriter + ); + + final IndexReader indexReader = indexWriter.getReader(); + // ensuring segments are created + indexWriter.flush(); + indexWriter.commit(); + indexWriter.close(); + IndexSearcher searcher = new IndexSearcher(indexReader); + final LeafReader leafReader = searcher.getLeafContexts().get(0).reader(); + SegmentReader segmentReader = Lucene.segmentReader(leafReader); + if (segmentReader.getSegmentInfo().info.getUseCompoundFile() == false) { + final List vecfiles = getFilesFromSegment(dir, FLAT_VECTOR_FILE_EXT); + // 2 .vec files will be created as we are using per field vectors format. + assertEquals(1, vecfiles.size()); + } + + final FloatVectorValues floatVectorValues = leafReader.getFloatVectorValues(FLOAT_VECTOR_FIELD_BINARY); + floatVectorValues.nextDoc(); + assertArrayEquals(floatVectorForBinaryQuantization, floatVectorValues.vectorValue(), 0.0f); + assertEquals(1, floatVectorValues.size()); + assertEquals(8, floatVectorValues.dimension()); + indexReader.close(); + } + private List getFilesFromSegment(Directory dir, String fileFormat) throws IOException { return Arrays.stream(dir.listAll()).filter(x -> x.contains(fileFormat)).collect(Collectors.toList()); } @@ -203,13 +267,11 @@ private FieldType createVectorField(int dimension, VectorEncoding vectorEncoding nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_M, "32"); nativeVectorField.putAttribute(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, "512"); nativeVectorField.putAttribute(KNNConstants.VECTOR_DATA_TYPE_FIELD, vectorDataType.getValue()); - nativeVectorField.putAttribute(KNNConstants.PARAMETERS, "{ \"index_description\":\"HNSW16,Flat\", \"spaceType\": \"l2\"}"); nativeVectorField.setVectorAttributes( dimension, vectorEncoding, SpaceType.L2.getKnnVectorSimilarityFunction().getVectorSimilarityFunction() ); - nativeVectorField.freeze(); return nativeVectorField; } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index 34a333471..0b5a06dfc 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -17,11 +17,14 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; import java.util.List; @@ -102,6 +105,119 @@ public void testBuildAndWrite() { } } + @SneakyThrows + public void testBuildAndWrite_withQuantization() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(Object.class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = mockStatic(OffHeapVectorTransferFactory.class); + MockedStatic mockedQuantizationIntegration = mockStatic(QuantizationService.class) + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + QuantizationService quantizationService = mock(QuantizationService.class); + mockedQuantizationIntegration.when(QuantizationService::getInstance).thenReturn(quantizationService); + + QuantizationState quantizationState = mock(QuantizationState.class); + ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); + // New: Create QuantizationOutput and mock the quantization process + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( + quantizationOutput + ); + + // Quantize the vector with the quantization output + when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( + invocation -> { + quantizationOutput.getQuantizedVector(); + return quantizationOutput.getQuantizedVector(); + } + ); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .quantizationState(quantizationState) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + for (Object vector : vectorTransferCapture.getAllValues()) { + // Assert that the vector is in byte[] format due to quantization + assertTrue(vector instanceof byte[]); + } + } + } + @SneakyThrows public void testBuildAndWriteWithModel() { // Given diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 2ecfe9259..3bfec4104 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -16,10 +16,13 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.quantizationService.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; import java.util.List; @@ -126,4 +129,119 @@ public void testBuildAndWrite() { } } } + + @SneakyThrows + public void testBuildAndWrite_withQuantization() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(Object.class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedKNNSettings = Mockito.mockStatic(KNNSettings.class); + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( + OffHeapVectorTransferFactory.class + ); + MockedStatic mockedQuantizationIntegration = Mockito.mockStatic(QuantizationService.class) + ) { + + // Limits transfer to 2 vectors + mockedKNNSettings.when(KNNSettings::getVectorStreamingMemoryLimit).thenReturn(new ByteSizeValue(16)); + mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); + + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 2)) + .thenReturn(offHeapVectorTransfer); + + QuantizationService quantizationService = mock(QuantizationService.class); + mockedQuantizationIntegration.when(QuantizationService::getInstance).thenReturn(quantizationService); + + QuantizationState quantizationState = mock(QuantizationState.class); + ArgumentCaptor vectorCaptor = ArgumentCaptor.forClass(float[].class); + // New: Create QuantizationOutput and mock the quantization process + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 1, 2 }); + when(quantizationService.createQuantizationOutput(eq(quantizationState.getQuantizationParams()))).thenReturn( + quantizationOutput + ); + + // Quantize the vector with the quantization output + when(quantizationService.quantize(eq(quantizationState), vectorCaptor.capture(), eq(quantizationOutput))).thenAnswer( + invocation -> { + quantizationOutput.getQuantizedVector(); + return quantizationOutput.getQuantizedVector(); + } + ); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexPath("indexPath") + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .quantizationState(quantizationState) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams, knnVectorValues); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndex( + knnVectorValues.totalLiveDocs(), + knnVectorValues.dimension(), + Map.of("index", "param"), + KNNEngine.FAISS + ) + ); + + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + for (Object vector : vectorTransferCapture.getAllValues()) { + // Assert that the vector is in byte[] format due to quantization + assertTrue(vector instanceof byte[]); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java new file mode 100644 index 000000000..30a2098dd --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.codec.nativeindex; + +import org.junit.Before; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.nativeindex.model.BuildIndexParams; +import org.opensearch.knn.index.quantizationService.QuantizationService; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; + +import java.io.IOException; +import java.util.List; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class QuantizationIndexUtilsTests extends KNNTestCase { + + private KNNVectorValues knnVectorValues; + private BuildIndexParams buildIndexParams; + private QuantizationService quantizationService; + + @Before + public void setUp() throws Exception { + super.setUp(); + quantizationService = mock(QuantizationService.class); + + // Predefined float vectors for testing + List floatVectors = List.of( + new float[] { 1.0f, 2.0f, 3.0f }, + new float[] { 4.0f, 5.0f, 6.0f }, + new float[] { 7.0f, 8.0f, 9.0f } + ); + + // Use the predefined vectors to create KNNVectorValues + knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) + ); + + // Mocking BuildIndexParams + buildIndexParams = mock(BuildIndexParams.class); + } + + public void testPrepareIndexBuild_withQuantization_success() { + QuantizationState quantizationState = mock(OneBitScalarQuantizationState.class); + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + when(quantizationOutput.getQuantizedVector()).thenReturn(new byte[] { 0x01 }); + when(quantizationState.getDimensions()).thenReturn(2); + when(quantizationState.getBytesPerVector()).thenReturn(8); + when(quantizationState.getQuantizationParams()).thenReturn(params); + + when(buildIndexParams.getQuantizationState()).thenReturn(quantizationState); + + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + + assertNotNull(setup.getQuantizationState()); + assertEquals(8, setup.getBytesPerVector()); + assertEquals(2, setup.getDimensions()); + } + + public void testPrepareIndexBuild_withoutQuantization_success() throws IOException { + when(buildIndexParams.getQuantizationState()).thenReturn(null); + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + assertNull(setup.getQuantizationState()); + assertEquals(knnVectorValues.bytesPerVector(), setup.getBytesPerVector()); + assertEquals(knnVectorValues.dimension(), setup.getDimensions()); + } + + public void testProcessAndReturnVector_withoutQuantization_success() throws IOException { + // Set up the BuildIndexParams to return no quantization + when(buildIndexParams.getQuantizationState()).thenReturn(null); + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + // Process and return the vector + assertNotNull(QuantizationIndexUtils.processAndReturnVector(knnVectorValues, setup)); + } + + public void testProcessAndReturnVector_withQuantization_success() throws IOException { + // Set up quantization state and output + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + knnVectorValues.nextDoc(); + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + QuantizationOutput quantizationOutput = mock(QuantizationOutput.class); + when(buildIndexParams.getQuantizationState()).thenReturn(state); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + // Process and return the vector + Object result = QuantizationIndexUtils.processAndReturnVector(knnVectorValues, setup); + assertTrue(result instanceof byte[]); + assertArrayEquals(new byte[] { 0x00 }, (byte[]) result); + } +} diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java index c9ce50f22..75da6811e 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java @@ -246,7 +246,7 @@ public void testGetKNNLibraryIndexingContext_whenMethodIsHNSWWithQFrame_thenCrea .vectorDataType(VectorDataType.FLOAT) .build(); int m = 88; - String expectedIndexDescription = "HNSW" + m + ",Flat"; + String expectedIndexDescription = "BHNSW" + m + ",Flat"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(NAME, METHOD_HNSW) @@ -285,7 +285,7 @@ public void testGetKNNLibraryIndexingContext_whenMethodIsIVFWithQFrame_thenCreat .vectorDataType(VectorDataType.FLOAT) .build(); int nlist = 88; - String expectedIndexDescription = "IVF" + nlist + ",Flat"; + String expectedIndexDescription = "BIVF" + nlist + ",Flat"; XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(NAME, METHOD_IVF) diff --git a/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java new file mode 100644 index 000000000..886dbeabc --- /dev/null +++ b/src/test/java/org/opensearch/knn/index/quantizationService/QuantizationServiceTests.java @@ -0,0 +1,159 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.quantizationService; + +import org.opensearch.knn.KNNTestCase; +import org.junit.Before; + +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; +import org.opensearch.knn.index.vectorvalues.TestVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import java.io.IOException; +import java.util.List; + +public class QuantizationServiceTests extends KNNTestCase { + private QuantizationService quantizationService; + private KNNVectorValues knnVectorValues; + + @Before + public void setUp() throws Exception { + super.setUp(); + quantizationService = QuantizationService.getInstance(); + + // Predefined float vectors for testing + List floatVectors = List.of( + new float[] { 1.0f, 2.0f, 3.0f }, + new float[] { 4.0f, 5.0f, 6.0f }, + new float[] { 7.0f, 8.0f, 9.0f } + ); + + // Use the predefined vectors to create KNNVectorValues + knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) + ); + } + + public void testTrain_oneBitQuantizer_success() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof OneBitScalarQuantizationState); + OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; + + // Validate the mean thresholds obtained from the training + float[] thresholds = oneBitState.getMeanThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Thresholds array length should match the dimension", 3, thresholds.length); + + // Example expected thresholds based on the provided vectors + assertArrayEquals(new float[] { 4.0f, 5.0f, 6.0f }, thresholds, 0.1f); + } + + public void testTrain_twoBitQuantizer_success() throws IOException { + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; + + // Validate the thresholds obtained from the training + float[][] thresholds = multiBitState.getThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Number of bits should match the number of rows", 2, thresholds.length); + assertEquals("Thresholds array length should match the dimension", 3, thresholds[0].length); + + // // Example expected thresholds for two-bit quantization + float[][] expectedThresholds = { + { 3.1835034f, 4.1835036f, 5.1835036f }, // First bit level + { 4.816497f, 5.816497f, 6.816497f } // Second bit level + }; + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(expectedThresholds[i], thresholds[i], 0.1f); + } + } + + public void testTrain_fourBitQuantizer_success() throws IOException { + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + + assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); + MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; + + // Validate the thresholds obtained from the training + float[][] thresholds = multiBitState.getThresholds(); + assertNotNull("Thresholds should not be null", thresholds); + assertEquals("Number of bits should match the number of rows", 4, thresholds.length); + assertEquals("Thresholds array length should match the dimension", 3, thresholds[0].length); + + // // Example expected thresholds for four-bit quantization + float[][] expectedThresholds = { + { 2.530306f, 3.530306f, 4.530306f }, // First bit level + { 3.510102f, 4.5101023f, 5.5101023f }, // Second bit level + { 4.489898f, 5.489898f, 6.489898f }, // Third bit level + { 5.469694f, 6.469694f, 7.469694f } // Fourth bit level + }; + for (int i = 0; i < thresholds.length; i++) { + assertArrayEquals(expectedThresholds[i], thresholds[i], 0.1f); + } + } + + public void testQuantize_oneBitQuantizer_success() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); + + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 1.0f, 2.0f, 3.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for one-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { 0 }; // 00000000 (all bits are 0) + assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_twoBitQuantizer_success() throws IOException { + ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for two-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { (byte) 0b11100000 }; + assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_fourBitQuantizer_success() throws IOException { + ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); + + byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); + + assertNotNull("Quantized vector should not be null", quantizedVector); + + // Expected quantized vector values for four-bit quantization (packed bits) + byte[] expectedQuantizedVector = new byte[] { (byte) 0xFF, (byte) 0xF0 }; + assertArrayEquals(expectedQuantizedVector, quantizedVector); + } + + public void testQuantize_whenInvalidInput_thenThrows() throws IOException { + ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues); + QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); + assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); + } +} diff --git a/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java b/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java index e3f8b607a..38371d8c3 100644 --- a/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java +++ b/src/test/java/org/opensearch/knn/integ/QFrameworkIT.java @@ -5,7 +5,6 @@ package org.opensearch.knn.integ; -import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.knn.KNNRestTestCase; @@ -29,23 +28,24 @@ public class QFrameworkIT extends KNNRestTestCase { public void testBaseCase() throws IOException { createTestIndex(4); - addKnnDoc(INDEX_NAME, "1", FIELD_NAME, TEST_VECTOR); - Response response = searchKNNIndex( - INDEX_NAME, - XContentFactory.jsonBuilder() - .startObject() - .startObject("query") - .startObject("knn") - .startObject(FIELD_NAME) - .field("vector", TEST_VECTOR) - .field("k", K) - .endObject() - .endObject() - .endObject() - .endObject(), - 1 - ); - assertOK(response); + // TODO :- UnComment this once Search is Integrated and KNN_USE_LUCENE_VECTOR_FORMAT_ENABLED_SETTING is enabled + // addKnnDoc(INDEX_NAME, "1", FIELD_NAME, TEST_VECTOR); + // Response response = searchKNNIndex( + // INDEX_NAME, + // XContentFactory.jsonBuilder() + // .startObject() + // .startObject("query") + // .startObject("knn") + // .startObject(FIELD_NAME) + // .field("vector", TEST_VECTOR) + // .field("k", K) + // .endObject() + // .endObject() + // .endObject() + // .endObject(), + // 1 + // ); + // assertOK(response); } private void createTestIndex(int bitCount) throws IOException { diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java index b95123e21..f6974aea2 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerFactoryTests.java @@ -5,7 +5,6 @@ package org.opensearch.knn.quantization.factory; -import org.junit.Before; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -13,31 +12,22 @@ import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; -import java.lang.reflect.Field; -import java.util.concurrent.atomic.AtomicBoolean; - public class QuantizerFactoryTests extends KNNTestCase { - @Before - public void resetIsRegisteredFlag() throws NoSuchFieldException, IllegalAccessException { - Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); - isRegisteredField.setAccessible(true); - AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); - isRegistered.set(false); - } - public void test_Lazy_Registration() { - ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - assertFalse(isRegisteredFieldAccessible()); - Quantizer oneBitQuantizer = QuantizerFactory.getQuantizer(params); - Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); - Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); - assertEquals(quantizerFourBit.getClass(), MultiBitScalarQuantizer.class); - assertEquals(quantizerTwoBit.getClass(), MultiBitScalarQuantizer.class); - assertEquals(oneBitQuantizer.getClass(), OneBitScalarQuantizer.class); - assertTrue(isRegisteredFieldAccessible()); + try { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + ScalarQuantizationParams paramsTwoBit = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + ScalarQuantizationParams paramsFourBit = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); + Quantizer oneBitQuantizer = QuantizerFactory.getQuantizer(params); + Quantizer quantizerTwoBit = QuantizerFactory.getQuantizer(paramsTwoBit); + Quantizer quantizerFourBit = QuantizerFactory.getQuantizer(paramsFourBit); + assertEquals(OneBitScalarQuantizer.class, oneBitQuantizer.getClass()); + assertEquals(MultiBitScalarQuantizer.class, quantizerTwoBit.getClass()); + assertEquals(MultiBitScalarQuantizer.class, quantizerFourBit.getClass()); + } catch (Exception e) { + assertTrue(e.getMessage().contains("already registered")); + } } public void testGetQuantizer_withNullParams() { @@ -48,16 +38,4 @@ public void testGetQuantizer_withNullParams() { assertEquals("Quantization parameters must not be null.", e.getMessage()); } } - - private boolean isRegisteredFieldAccessible() { - try { - Field isRegisteredField = QuantizerFactory.class.getDeclaredField("isRegistered"); - isRegisteredField.setAccessible(true); - AtomicBoolean isRegistered = (AtomicBoolean) isRegisteredField.get(null); - return isRegistered.get(); - } catch (NoSuchFieldException | IllegalAccessException e) { - fail("Failed to access isRegistered field."); - return false; - } - } } diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 62d31ab61..7c974e517 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -17,18 +17,22 @@ public class QuantizerRegistryTests extends KNNTestCase { @BeforeClass public static void setup() { - QuantizerRegistry.register( - ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), - new OneBitScalarQuantizer() - ); - QuantizerRegistry.register( - ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), - new MultiBitScalarQuantizer(2) - ); - QuantizerRegistry.register( - ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), - new MultiBitScalarQuantizer(4) - ); + try { + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT), + new OneBitScalarQuantizer() + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.TWO_BIT), + new MultiBitScalarQuantizer(2) + ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), + new MultiBitScalarQuantizer(4) + ); + } catch (Exception e) { + assertTrue(e.getMessage().contains("already registered")); + } } public void testRegisterAndGetQuantizer() { diff --git a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java index 35edf49e2..298256127 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizationState/QuantizationStateTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizationState; +import org.apache.lucene.util.RamUsageEstimator; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; @@ -65,4 +66,76 @@ public void testSerializationWithDifferentVersions() throws IOException { assertArrayEquals(mean, deserializedState.getMeanThresholds(), delta); assertEquals(params.getSqType(), deserializedState.getQuantizationParams().getSqType()); } + + public void testOneBitScalarQuantizationStateRamBytesUsed() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); + float[] mean = { 1.0f, 2.0f, 3.0f }; + + OneBitScalarQuantizationState state = new OneBitScalarQuantizationState(params, mean); + + // 1. Manual Calculation of RAM Usage + long manualEstimatedRamBytesUsed = 0L; + + // OneBitScalarQuantizationState object overhead for Object Header + manualEstimatedRamBytesUsed += alignSize(16L); + + // ScalarQuantizationParams object overhead Object Header + manualEstimatedRamBytesUsed += alignSize(16L); + + // Mean array overhead (array header + size of elements) + manualEstimatedRamBytesUsed += alignSize(16L + 4L * mean.length); + + // 3. RAM Usage from RamUsageEstimator + long expectedRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.sizeOf(mean); + + long actualRamBytesUsed = state.ramBytesUsed(); + + // Allow a difference between manual estimation, serialization size, and actual RAM usage + assertTrue( + "The difference between manual and actual RAM usage exceeds 8 bytes", + Math.abs(manualEstimatedRamBytesUsed - actualRamBytesUsed) <= 8 + ); + + assertEquals(expectedRamBytesUsed, actualRamBytesUsed); + } + + public void testMultiBitScalarQuantizationStateRamBytesUsedManualCalculation() throws IOException { + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); + float[][] thresholds = { { 0.5f, 1.5f, 2.5f }, { 1.0f, 2.0f, 3.0f } }; + + MultiBitScalarQuantizationState state = new MultiBitScalarQuantizationState(params, thresholds); + + // Manually estimate RAM usage with alignment + long manualEstimatedRamBytesUsed = 0L; + + // Estimate for MultiBitScalarQuantizationState object + manualEstimatedRamBytesUsed += alignSize(16L); // Example overhead for object + + // Estimate for ScalarQuantizationParams object + manualEstimatedRamBytesUsed += alignSize(16L); // Overhead for params object (including fields) + + // Estimate for thresholds array + manualEstimatedRamBytesUsed += alignSize(16L + 4L * thresholds.length); // Overhead for array + references to sub-arrays + + for (float[] row : thresholds) { + manualEstimatedRamBytesUsed += alignSize(16L + 4L * row.length); // Overhead for each sub-array + size of each float + } + + long ramEstimatorRamBytesUsed = RamUsageEstimator.shallowSizeOfInstance(MultiBitScalarQuantizationState.class) + RamUsageEstimator + .shallowSizeOf(params) + RamUsageEstimator.shallowSizeOf(thresholds); + + for (float[] row : thresholds) { + ramEstimatorRamBytesUsed += RamUsageEstimator.sizeOf(row); + } + + long difference = Math.abs(manualEstimatedRamBytesUsed - ramEstimatorRamBytesUsed); + assertTrue("The difference between manual and actual RAM usage exceeds 8 bytes", difference <= 8); + assertEquals(ramEstimatorRamBytesUsed, state.ramBytesUsed()); + } + + private long alignSize(long size) { + return (size + 7) & ~7; // Align to 8 bytes boundary + } + } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java index 45acaf357..de815d8ad 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizerTests.java @@ -17,7 +17,7 @@ public class MultiBitScalarQuantizerTests extends KNNTestCase { - public void testTrain_twoBit() { + public void testTrain_twoBit() throws IOException { float[][] vectors = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f }, @@ -33,7 +33,7 @@ public void testTrain_twoBit() { assertEquals(2, mbState.getThresholds().length); // 2-bit quantization should have 2 thresholds } - public void testTrain_fourBit() { + public void testTrain_fourBit() throws IOException { MultiBitScalarQuantizer fourBitQuantizer = new MultiBitScalarQuantizer(4); float[][] vectors = { { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f }, @@ -220,8 +220,8 @@ public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) { } @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } } } diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java index 6f8c2de87..a6b907ccb 100644 --- a/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizerTests.java @@ -22,14 +22,14 @@ public class OneBitScalarQuantizerTests extends KNNTestCase { - public void testTrain_withTrainingRequired() { + public void testTrain_withTrainingRequired() throws IOException { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest originalRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } }; OneBitScalarQuantizer quantizer = new OneBitScalarQuantizer(); @@ -77,6 +77,21 @@ public ScalarQuantizationParams getQuantizationParams() { return new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); } + @Override + public int getBytesPerVector() { + return 0; + } + + @Override + public int getDimensions() { + return 0; + } + + @Override + public long ramBytesUsed() { + return 0; + } + @Override public byte[] toByteArray() { return new byte[0]; @@ -103,14 +118,14 @@ public void testQuantize_withMismatchedDimensions() throws IOException { expectThrows(IllegalArgumentException.class, () -> quantizer.quantize(vector, state, output)); } - public void testCalculateMean() { + public void testCalculateMean() throws IOException { float[][] vectors = { { 1.0f, 2.0f, 3.0f }, { 4.0f, 5.0f, 6.0f }, { 7.0f, 8.0f, 9.0f } }; ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } }; @@ -126,8 +141,8 @@ public void testCalculateMean_withNullVector() { ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); TrainingRequest samplingRequest = new TrainingRequest(vectors.length) { @Override - public float[] getVectorByDocId(int docId) { - return vectors[docId]; + public float[] getVectorAtThePosition(int position) { + return vectors[position]; } };