diff --git a/README.md b/README.md index bce27d563..f5ed9e5c5 100644 --- a/README.md +++ b/README.md @@ -27,6 +27,7 @@ The graph is represented by an on-disk adjacency list per node, with additional The second pass can be performed with * Full resolution float32 vectors +* NVQ, which uses a non-uniform technique to quantize vectors with high-accuracy [This two-pass design reduces memory usage and reduces latency while preserving accuracy](https://thenewstack.io/why-vector-size-matters/). diff --git a/UPGRADING.md b/UPGRADING.md index e5f5beb86..0d4982d36 100644 --- a/UPGRADING.md +++ b/UPGRADING.md @@ -16,6 +16,12 @@ `CompressedVectors` directly from `encodeAll()`. - `PQVectors::getProductQuantization` is removed; it duplicated `CompressedVectors::getCompressor` unnecessarily +## New features +- Support for Non-uniform Vector Quantization (NVQ, pronounced as "new vec"). This new technique quantizes the values + in each vector with high accuracy by first applying a nonlinear transformation that is individually fit to each + vector. These nonlinearities are designed to be lightweight and have a negligible impact on distance computation + performance. + # Upgrading from 2.0.x to 3.0.x ## Critical API changes diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java index b6b7a274c..3f35b72b1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/GraphIndexBuilder.java @@ -211,20 +211,20 @@ public static GraphIndexBuilder rescore(GraphIndexBuilder other, BuildScoreProvi var neighbors = other.graph.getNeighbors(i); var sf = newProvider.searchProviderFor(i).scoreFunction(); var newNeighbors = new NodeArray(neighbors.size()); - + // Copy edges, compute new scores for (var it = neighbors.iterator(); it.hasNext(); ) { int neighbor = it.nextInt(); // since we're using a different score provider, use insertSorted instead of addInOrder newNeighbors.insertSorted(neighbor, sf.similarityTo(neighbor)); } - + newBuilder.graph.addNode(i, newNeighbors); } // Set the entry node newBuilder.graph.updateEntryNode(other.graph.entry()); - + return newBuilder; } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java index 79339e82a..1964fe87a 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FeatureId.java @@ -31,7 +31,8 @@ */ public enum FeatureId { INLINE_VECTORS(InlineVectors::load), - FUSED_ADC(FusedADC::load); + FUSED_ADC(FusedADC::load), + NVQ_VECTORS(NVQ::load); public static final Set ALL = Collections.unmodifiableSet(EnumSet.allOf(FeatureId.class)); diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java index 59e86da1c..e76dd2738 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/FusedADC.java @@ -19,9 +19,9 @@ import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.GraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; -import io.github.jbellis.jvector.pq.FusedADCPQDecoder; 
-import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.FusedADCPQDecoder; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.ExplicitThreadLocal; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorizationProvider; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java new file mode 100644 index 000000000..ef52fb10d --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/NVQ.java @@ -0,0 +1,104 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.graph.disk; + +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.quantization.NVQScorer; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.quantization.NVQuantization.QuantizedVector; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.io.DataOutput; +import java.io.IOException; +import java.io.UncheckedIOException; + +/** + * Implements the storage of NuVeQ vectors in an on-disk graph index. These can be used for reranking. 
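+ * Each vector is quantized independently with a lightweight nonlinear transform (see NVQuantization), so the
+ * reranking pass does not require keeping full-resolution vectors on hand.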
+ */ +public class NVQ implements Feature { + private final NVQuantization nvq; + private final NVQScorer scorer; + private final ThreadLocal reusableQuantizedVector; + + public NVQ(NVQuantization nvq) { + this.nvq = nvq; + scorer = new NVQScorer(this.nvq); + reusableQuantizedVector = ThreadLocal.withInitial(() -> NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension)); + } + + @Override + public FeatureId id() { + return FeatureId.NVQ_VECTORS; + } + + @Override + public int headerSize() { + return nvq.compressorSize(); + } + + @Override + public int inlineSize() { return nvq.compressedVectorSize();} + + public int dimension() { + return nvq.globalMean.length(); + } + + static NVQ load(CommonHeader header, RandomAccessReader reader) { + try { + return new NVQ(NVQuantization.load(reader)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public void writeHeader(DataOutput out) throws IOException { + nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION); + } + + @Override + public void writeInline(DataOutput out, Feature.State state_) throws IOException { + var state = (NVQ.State) state_; + state.vector.write(out); + } + + public static class State implements Feature.State { + public final QuantizedVector vector; + + public State(QuantizedVector vector) { + this.vector = vector; + } + } + + ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, + VectorSimilarityFunction vsf, + FeatureSource source) { + var function = scorer.scoreFunctionFor(queryVector, vsf); + + return node2 -> { + try { + var reader = source.inlineReaderForNode(node2, FeatureId.NVQ_VECTORS); + QuantizedVector.loadInto(reader, reusableQuantizedVector.get()); + } catch (IOException e) { + throw new RuntimeException(e); + } + return function.similarityTo(reusableQuantizedVector.get()); + }; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java index 219813102..13c7fdc58 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndex.java @@ -262,10 +262,13 @@ public void close() throws IOException { @Override public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat queryVector, VectorSimilarityFunction vsf) { - if (!features.containsKey(FeatureId.INLINE_VECTORS)) { - throw new UnsupportedOperationException("No inline vectors in this graph"); + if (features.containsKey(FeatureId.INLINE_VECTORS)) { + return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf); + } else if (features.containsKey(FeatureId.NVQ_VECTORS)) { + return ((NVQ) features.get(FeatureId.NVQ_VECTORS)).rerankerFor(queryVector, vsf, this); + } else { + throw new UnsupportedOperationException("No reranker available for this graph"); } - return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf); } @Override diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java index 35728936d..1d809be8b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/disk/OnDiskGraphIndexWriter.java @@ -324,8 +324,10 @@ public OnDiskGraphIndexWriter build() 
throws IOException { int dimension; if (features.containsKey(FeatureId.INLINE_VECTORS)) { dimension = ((InlineVectors) features.get(FeatureId.INLINE_VECTORS)).dimension(); + } else if (features.containsKey(FeatureId.NVQ_VECTORS)) { + dimension = ((NVQ) features.get(FeatureId.NVQ_VECTORS)).dimension(); } else { - throw new IllegalArgumentException("Inline vectors must be provided."); + throw new IllegalArgumentException("Inline or NVQ vectors must be provided."); } if (ordinalMapper == null) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java index f94ae7f3b..a796490ec 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/BuildScoreProvider.java @@ -17,8 +17,8 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; -import io.github.jbellis.jvector.pq.BQVectors; -import io.github.jbellis.jvector.pq.PQVectors; +import io.github.jbellis.jvector.quantization.BQVectors; +import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.VectorizationProvider; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/CachingVectorValues.java b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/CachingVectorValues.java index 27aaa633a..af032b0ca 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/CachingVectorValues.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/graph/similarity/CachingVectorValues.java @@ -17,7 +17,7 @@ package io.github.jbellis.jvector.graph.similarity; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; -import io.github.jbellis.jvector.pq.PQVectors; +import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.vector.types.VectorFloat; import org.agrona.collections.Int2ObjectHashMap; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/BQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java index a79d8b702..46a5c786b 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java index 91bb90dc7..5e956dd86 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/BinaryQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/BinaryQuantization.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.disk.RandomAccessReader; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/CompressedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java similarity index 98% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/CompressedVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java index d9f8eb63a..21726a213 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/CompressedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/CompressedVectors.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/FusedADCPQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/FusedADCPQDecoder.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java index 0d013d361..2b99d3329 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/FusedADCPQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/FusedADCPQDecoder.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.graph.disk.FusedADC; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutableBQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java similarity index 94% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutableBQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java index 85b2313f8..4acc4744d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutableBQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutableBQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; public class ImmutableBQVectors extends BQVectors { public ImmutableBQVectors(BinaryQuantization bq, long[][] compressedVectors) { diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java similarity index 96% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java index fef169e12..b1d538b34 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ImmutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ImmutablePQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.vector.types.ByteSequence; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java index d8a156d5a..573c00b5f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/KMeansPlusPlusClusterer.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/KMeansPlusPlusClusterer.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.vector.Matrix; import io.github.jbellis.jvector.vector.VectorUtil; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableBQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java similarity index 97% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableBQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java index a9cf8a474..10923f7f4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableBQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableBQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; public class MutableBQVectors extends BQVectors implements MutableCompressedVectors { private static final int INITIAL_CAPACITY = 1024; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableCompressedVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java similarity index 96% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableCompressedVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java index 276923cb6..33d4a77aa 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutableCompressedVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutableCompressedVectors.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; public interface MutableCompressedVectors extends CompressedVectors { /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java similarity index 98% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java index 6820155c6..62e6a522e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/MutablePQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/MutablePQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.vector.VectorizationProvider; import io.github.jbellis.jvector.vector.types.ByteSequence; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java new file mode 100644 index 000000000..a82520871 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQScorer.java @@ -0,0 +1,145 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.quantization; + +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +public class NVQScorer { + final NVQuantization nvq; + + /** + * Initialize the NVQScorer with an instance of NVQuantization. + */ + public NVQScorer(NVQuantization nvq) { + this.nvq = nvq; + } + + public NVQScoreFunction scoreFunctionFor(VectorFloat query, VectorSimilarityFunction similarityFunction) { + switch (similarityFunction) { + case DOT_PRODUCT: + return dotProductScoreFunctionFor(query); + case EUCLIDEAN: + return euclideanScoreFunctionFor(query); + case COSINE: + return cosineScoreFunctionFor(query); + default: + throw new IllegalArgumentException("Unsupported similarity function " + similarityFunction); + } + } + + private NVQScoreFunction dotProductScoreFunctionFor(VectorFloat query) { + /* Each sub-vector of query vector (full resolution) will be compared to NVQ quantized sub-vectors that were + * first de-meaned by subtracting the global mean. 
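+         * The query itself stays at full resolution: its dot product with the global mean is precomputed once
+         * as queryGlobalBias and added back to the per-sub-vector quantized dot products; the final
+         * (1 + dot) / 2 maps the result into the [0, 1] score range.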
+ */ + var queryGlobalBias = VectorUtil.dotProduct(query, this.nvq.globalMean); + var querySubVectors = this.nvq.getSubVectors(query); + + switch (this.nvq.bitsPerDimension) { + case EIGHT: + for (VectorFloat querySubVector : querySubVectors) { + VectorUtil.nvqShuffleQueryInPlace8bit(querySubVector); + } + + return vector2 -> { + float nvqDot = 0; + for (int i = 0; i < querySubVectors.length; i++) { + var svDB = vector2.subVectors[i]; + nvqDot += VectorUtil.nvqDotProduct8bit(querySubVectors[i], + svDB.bytes, svDB.growthRate, svDB.midpoint, + svDB.minValue, svDB.maxValue + ); + } + return (1 + nvqDot + queryGlobalBias) / 2; + }; + default: + throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension); + } + } + + private NVQScoreFunction euclideanScoreFunctionFor(VectorFloat query) { + /* Each sub-vector of query vector (full resolution) will be compared to NVQ quantized sub-vectors that were + * first de-meaned by subtracting the global mean. + */ + var shiftedQuery = VectorUtil.sub(query, this.nvq.globalMean); + var querySubVectors = this.nvq.getSubVectors(shiftedQuery); + + switch (this.nvq.bitsPerDimension) { + case EIGHT: + for (VectorFloat querySubVector : querySubVectors) { + VectorUtil.nvqShuffleQueryInPlace8bit(querySubVector); + } + + return vector2 -> { + float dist = 0; + for (int i = 0; i < querySubVectors.length; i++) { + var svDB = vector2.subVectors[i]; + dist += VectorUtil.nvqSquareL2Distance8bit( + querySubVectors[i], + svDB.bytes, svDB.growthRate, svDB.midpoint, + svDB.minValue, svDB.maxValue + ); + } + + return 1 / (1 + dist); + }; + default: + throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension); + } + } + + private NVQScoreFunction cosineScoreFunctionFor(VectorFloat query) { + float queryNorm = (float) Math.sqrt(VectorUtil.dotProduct(query, query)); + var querySubVectors = this.nvq.getSubVectors(query); + var meanSubVectors = this.nvq.getSubVectors(this.nvq.globalMean); + + switch (this.nvq.bitsPerDimension) { + case EIGHT: + for (var i = 0; i < querySubVectors.length; i++) { + VectorUtil.nvqShuffleQueryInPlace8bit(querySubVectors[i]); + VectorUtil.nvqShuffleQueryInPlace8bit(meanSubVectors[i]); + } + + return vector2 -> { + float cos = 0; + float squaredNormalization = 0; + for (int i = 0; i < querySubVectors.length; i++) { + var svDB = vector2.subVectors[i]; + var partialCosSim = VectorUtil.nvqCosine8bit(querySubVectors[i], + svDB.bytes, svDB.growthRate, svDB.midpoint, + svDB.minValue, svDB.maxValue, + meanSubVectors[i]); + cos += partialCosSim[0]; + squaredNormalization += partialCosSim[1]; + } + float cosine = (cos / queryNorm) / (float) Math.sqrt(squaredNormalization); + + return (1 + cosine) / 2; + }; + default: + throw new IllegalArgumentException("Unsupported bits per dimension " + this.nvq.bitsPerDimension); + } + } + + public interface NVQScoreFunction { + /** + * @return the similarity to another vector + */ + float similarityTo(NVQuantization.QuantizedVector vector2); + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java new file mode 100644 index 000000000..6b4eb1ca1 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQVectors.java @@ -0,0 +1,153 @@ +/* + * Copyright DataStax, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.quantization; + +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.util.RamUsageEstimator; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; + +public class NVQVectors implements CompressedVectors { + final NVQuantization nvq; + final NVQScorer scorer; + final NVQuantization.QuantizedVector[] compressedVectors; + + /** + * Initialize the NVQVectors with an initial array of vectors. This array may be + * mutated, but caller is responsible for thread safety issues when doing so. + */ + public NVQVectors(NVQuantization nvq, NVQuantization.QuantizedVector[] compressedVectors) { + this.nvq = nvq; + this.scorer = new NVQScorer(nvq); + this.compressedVectors = compressedVectors; + } + + @Override + public int count() { + return compressedVectors.length; + } + + @Override + public void write(DataOutput out, int version) throws IOException + { + // serializing NVQ at the given version + nvq.write(out, version); + + // compressed vectors + out.writeInt(compressedVectors.length); + for (var v : compressedVectors) { + v.write(out); + } + } + + public static NVQVectors load(RandomAccessReader in) throws IOException { + var nvq = NVQuantization.load(in); + + // read the vectors + int size = in.readInt(); + if (size < 0) { + throw new IOException("Invalid compressed vector count " + size); + } + NVQuantization.QuantizedVector[] compressedVectors = new NVQuantization.QuantizedVector[size]; + + for (int i = 0; i < size; i++) { + compressedVectors[i] = NVQuantization.QuantizedVector.load(in); + } + + return new NVQVectors(nvq, compressedVectors); + } + + public static NVQVectors load(RandomAccessReader in, long offset) throws IOException { + in.seek(offset); + return load(in); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NVQVectors that = (NVQVectors) o; + if (!Objects.equals(nvq, that.nvq)) return false; + return Arrays.deepEquals(compressedVectors, that.compressedVectors); + } + + @Override + public int hashCode() { + return Objects.hash(nvq, Arrays.hashCode(compressedVectors)); + } + + @Override + public ScoreFunction.ApproximateScoreFunction precomputedScoreFunctionFor(VectorFloat query, VectorSimilarityFunction similarityFunction) { + return scoreFunctionFor(query, similarityFunction); + } + + @Override + public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat query, VectorSimilarityFunction similarityFunction) { + var function = scorer.scoreFunctionFor(query, similarityFunction); + return node2 -> function.similarityTo(compressedVectors[node2]); + } + + public NVQuantization.QuantizedVector 
get(int ordinal) { + return compressedVectors[ordinal]; + } + + public NVQuantization getNVQuantization() { + return nvq; + } + + @Override + public int getOriginalSize() { + return nvq.originalDimension * Float.BYTES; + } + + @Override + public int getCompressedSize() { + return nvq.compressedVectorSize(); + } + + @Override + public NVQuantization getCompressor() { + return nvq; + } + + @Override + public long ramBytesUsed() { + int REF_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_REF; + int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; + int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; + + long codebooksSize = nvq.ramBytesUsed(); + long listSize = (long) REF_BYTES * (1 + compressedVectors.length); + long dataSize = (long) (OH_BYTES + AH_BYTES + nvq.compressedVectorSize()) * compressedVectors.length; + return codebooksSize + listSize + dataSize; + } + + @Override + public String toString() { + return "NVQVectors{" + + "NVQ=" + nvq + + ", count=" + compressedVectors.length + + '}'; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java new file mode 100644 index 000000000..0354e82be --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/NVQuantization.java @@ -0,0 +1,658 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.quantization; + +import io.github.jbellis.jvector.annotations.VisibleForTesting; +import io.github.jbellis.jvector.disk.RandomAccessReader; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; +import io.github.jbellis.jvector.util.Accountable; +import io.github.jbellis.jvector.vector.VectorUtil; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.ByteSequence; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; + +import java.io.DataOutput; +import java.io.IOException; +import java.util.Arrays; +import java.util.Objects; +import java.util.concurrent.ForkJoinPool; +import java.util.stream.IntStream; + + +/** + * Non-uniform Vector Quantization for float vectors. + * It divides each vector in subvectors and then quantizes each one individually using a non-uniform quantizer. 
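+ * Quantization proceeds in three steps: the global mean of the dataset is subtracted from the vector, the
+ * centered vector is split into sub-vectors, and each sub-vector is quantized at 8 bits per dimension using a
+ * logistic nonlinearity whose growth rate is fitted individually to that sub-vector.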
+ */ +public class NVQuantization implements VectorCompressor, Accountable { + public enum BitsPerDimension { + EIGHT { + @Override + public int getInt() { + return 8; + } + + @Override + public ByteSequence createByteSequence(int nDimensions) { + return vectorTypeSupport.createByteSequence(nDimensions); + } + }, + FOUR { + @Override + public int getInt() { + return 4; + } + + @Override + public ByteSequence createByteSequence(int nDimensions) { + return vectorTypeSupport.createByteSequence((int) Math.ceil(nDimensions / 2.)); + } + }; + + /** + * Writes the BitsPerDimension to DataOutput. + * @param out the DataOutput into which to write the object + * @throws IOException if there is a problem writing to out. + */ + public void write(DataOutput out) throws IOException { + out.writeInt(getInt()); + } + + /** + * Returns the integer 4 for FOUR and 8 for EIGHT + */ + public abstract int getInt(); + + /** + * Creates a ByteSequence of the proper length to store the quantized vector. + * @param nDimensions the number of dimensions of the original vector + * @return the byte sequence + */ + public abstract ByteSequence createByteSequence(int nDimensions); + + /** + * Loads the BitsPerDimension from a RandomAccessReader. + * @param in the RandomAccessReader to read from. + * @throws IOException if there is a problem reading from in. + */ + public static BitsPerDimension load(RandomAccessReader in) throws IOException { + int nBitsPerDimension = in.readInt(); + switch (nBitsPerDimension) { + case 8: + return BitsPerDimension.EIGHT; + default: + throw new IllegalArgumentException("Unsupported BitsPerDimension " + nBitsPerDimension); + } + } + } + + private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + + // How many bits to use for each dimension when quantizing the vector: + public final BitsPerDimension bitsPerDimension; + + // We subtract the global mean vector to make it robust against center datasets with a large mean: + public final VectorFloat globalMean; + + // The number of dimensions of the original (uncompressed) vectors: + public final int originalDimension; + + // A matrix that stores the size and starting point of each subvector: + public final int[][] subvectorSizesAndOffsets; + + // Whether we want to skip the optimization of the NVQ parameters. Here for debug purposes only. + @VisibleForTesting + public boolean learn = true; + + /** + * Class constructor. + * @param subvectorSizesAndOffsets a matrix that stores the size and starting point of each subvector + * @param globalMean the mean of the database (its average vector) + */ + private NVQuantization(int[][] subvectorSizesAndOffsets, VectorFloat globalMean) { + this.bitsPerDimension = BitsPerDimension.EIGHT; + this.globalMean = globalMean; + this.subvectorSizesAndOffsets = subvectorSizesAndOffsets; + this.originalDimension = Arrays.stream(subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum(); + + if (globalMean.length() != originalDimension) { + var msg = String.format("Global mean length %d does not match vector dimensionality %d", globalMean.length(), originalDimension); + throw new IllegalArgumentException(msg); + } + } + + /** + * Computes the global mean vector and the data structures used to divide each vector into subvectors. 
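+ * The mean is the element-wise average of all vectors in ravv; it is subtracted from each vector before
+ * quantization so the quantizer operates on centered data.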
+ * + * @param ravv the vectors to quantize + * @param nSubVectors number of subvectors + */ + public static NVQuantization compute(RandomAccessVectorValues ravv, int nSubVectors) { + var subvectorSizesAndOffsets = getSubvectorSizesAndOffsets(ravv.dimension(), nSubVectors); + + var ravvCopy = ravv.threadLocalSupplier().get(); + var dim = ravvCopy.getVector(0).length(); + var globalMean = vectorTypeSupport.createFloatVector(dim); + for (int i = 0; i < ravvCopy.size(); i++) { + VectorUtil.addInPlace(globalMean, ravvCopy.getVector(i)); + } + VectorUtil.scale(globalMean, 1.0f / ravvCopy.size()); + return new NVQuantization(subvectorSizesAndOffsets, globalMean); + } + + + @Override + public CompressedVectors createCompressedVectors(Object[] compressedVectors) { + return new NVQVectors(this, (QuantizedVector[]) compressedVectors); + } + + /** + * Encodes the given vectors in parallel using NVQ. + */ + @Override + public NVQVectors encodeAll(RandomAccessVectorValues ravv, ForkJoinPool parallelExecutor) { + var ravvCopy = ravv.threadLocalSupplier(); + return new NVQVectors(this, + parallelExecutor.submit(() -> IntStream.range(0, ravv.size()) + .parallel() + .mapToObj(i -> { + var localRavv = ravvCopy.get(); + VectorFloat v = localRavv.getVector(i); + return encode(v); + }) + .toArray(QuantizedVector[]::new)) + .join()); + } + + /** + * Encodes the input vector using NVQ. + * @return one subvector per subspace + */ + @Override + public QuantizedVector encode(VectorFloat vector) { + var qv = QuantizedVector.createEmpty(subvectorSizesAndOffsets, bitsPerDimension); + encodeTo(vector, qv); + return qv; + } + + /** + * Encodes the input vector using NVQ into dest + */ + @Override + public void encodeTo(VectorFloat v, NVQuantization.QuantizedVector dest) { + var tempVector = VectorUtil.sub(v, globalMean); + QuantizedVector.quantizeTo(getSubVectors(tempVector), bitsPerDimension, learn, dest); + } + + /** + * Creates an array of subvectors from a given vector + */ + public VectorFloat[] getSubVectors(VectorFloat vector) { + VectorFloat[] subvectors = new VectorFloat[subvectorSizesAndOffsets.length]; + + // Iterate through the subvectorSizesAndOffsets to create each subvector and copy slices into them + for (int i = 0; i < subvectorSizesAndOffsets.length; i++) { + int size = subvectorSizesAndOffsets[i][0]; // Size of the subvector + int offset = subvectorSizesAndOffsets[i][1]; // Offset from the start of the input vector + VectorFloat subvector = vectorTypeSupport.createFloatVector(size); + subvector.copyFrom(vector, offset, 0, size); + subvectors[i] = subvector; + } + return subvectors; + } + + /** + * Splits the vector dimension into M subvectors of roughly equal size. + */ + static int[][] getSubvectorSizesAndOffsets(int dimensions, int M) { + if (M > dimensions) { + throw new IllegalArgumentException("Number of subspaces must be less than or equal to the vector dimension"); + } + int[][] sizes = new int[M][2]; + int baseSize = dimensions / M; + int remainder = dimensions % M; + // distribute the remainder among the subvectors + int offset = 0; + for (int i = 0; i < M; i++) { + int size = baseSize + (i < remainder ? 1 : 0); + sizes[i] = new int[]{size, offset}; + offset += size; + } + return sizes; + } + + /** + * Writes the instance to a DataOutput. + * @param out DataOutput to write to + * @param version serialization version. 
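+ *                Must not exceed OnDiskGraphIndex.CURRENT_VERSION; it is serialized first so load() can read it back.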
+ * @throws IOException fails if we cannot write to the DataOutput + */ + public void write(DataOutput out, int version) throws IOException + { + if (version > OnDiskGraphIndex.CURRENT_VERSION) { + throw new IllegalArgumentException("Unsupported serialization version " + version); + } + + out.writeInt(version); + + out.writeInt(globalMean.length()); + vectorTypeSupport.writeFloatVector(out, globalMean); + + bitsPerDimension.write(out); + + out.writeInt(subvectorSizesAndOffsets.length); + assert Arrays.stream(subvectorSizesAndOffsets).mapToInt(m -> m[0]).sum() == originalDimension; + for (var a : subvectorSizesAndOffsets) { + out.writeInt(a[0]); + } + } + + /** + * Returns the size in bytes of this class when writing it using the write method. + * @return the size in bytes + */ + @Override + public int compressorSize() { + int size = 0; + size += Integer.BYTES; // STORAGE_VERSION + size += Integer.BYTES; // globalCentroidLength + size += Float.BYTES * globalMean.length(); + size += Integer.BYTES; // bitsPerDimension + size += Integer.BYTES; // nSubVectors + size += subvectorSizesAndOffsets.length * Integer.BYTES; + return size; + } + + /** + * Loads an instance from a RandomAccessReader + * @param in the RandomAccessReader + * @return the instance + * @throws IOException fails if we cannot read from the RandomAccessReader + */ + public static NVQuantization load(RandomAccessReader in) throws IOException { + int version = in.readInt(); + int globalMeanLength = in.readInt(); + + VectorFloat globalMean = null; + if (globalMeanLength > 0) { + globalMean = vectorTypeSupport.readFloatVector(in, globalMeanLength); + } + + BitsPerDimension bitsPerDimension = BitsPerDimension.load(in); + + int nSubVectors = in.readInt(); + int[][] subvectorSizes = new int[nSubVectors][]; + int offset = 0; + for (int i = 0; i < nSubVectors; i++) { + subvectorSizes[i] = new int[2]; + int size = in.readInt(); + subvectorSizes[i][0] = size; + subvectorSizes[i][1] = offset; + offset += size; + } + + return new NVQuantization(subvectorSizes, globalMean); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NVQuantization that = (NVQuantization) o; + return originalDimension == that.originalDimension + && Objects.equals(globalMean, that.globalMean) + && Arrays.deepEquals(subvectorSizesAndOffsets, that.subvectorSizesAndOffsets); + } + + @Override + public int hashCode() { + int result = Objects.hash(originalDimension); + result = 31 * result + Objects.hashCode(globalMean); + result = 31 * result + Arrays.deepHashCode(subvectorSizesAndOffsets); + return result; + } + + @Override + public int compressedVectorSize() { + int size = Integer.BYTES; // number of subvectors + for (int[] subvectorSizesAndOffset : subvectorSizesAndOffsets) { + size += QuantizedSubVector.compressedVectorSize(subvectorSizesAndOffset[0], bitsPerDimension); + } + return size; + } + + @Override + public long ramBytesUsed() { + return globalMean.ramBytesUsed(); + } + + @Override + public String toString() { + return String.format("NVQuantization(sub-vectors=%d)", subvectorSizesAndOffsets.length); + } + + /** + * A NuVeQ vector. + */ + public static class QuantizedVector { + public final QuantizedSubVector[] subVectors; + + /** + * Class constructor. 
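+ * Quantizes each of the given sub-vectors into the corresponding sub-vector of dest.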
+ * @param subVectors receives the subvectors to quantize + * @param bitsPerDimension the number of bits per dimension + * @param learn whether to use optimization to find the parameters of the nonlinearity + */ + public static void quantizeTo(VectorFloat[] subVectors, BitsPerDimension bitsPerDimension, boolean learn, QuantizedVector dest) { + for (int i = 0; i < subVectors.length; i++) { + QuantizedSubVector.quantizeTo(subVectors[i], bitsPerDimension, learn, dest.subVectors[i]); + } + } + + /** + * Constructs an instance from existing subvectors. Used when loading from a RandomAccessReader. + * @param subVectors the subvectors + */ + private QuantizedVector(QuantizedSubVector[] subVectors) { + this.subVectors = subVectors; + } + + /** + * Create an empty instance. Meant to be used as scratch space in conjunction with loadInto + * @param subvectorSizesAndOffsets the array containing the sizes for the subvectors + * @param bitsPerDimension the number of bits per dimension + */ + public static QuantizedVector createEmpty(int[][] subvectorSizesAndOffsets, BitsPerDimension bitsPerDimension) { + var subVectors = new QuantizedSubVector[subvectorSizesAndOffsets.length]; + for (int i = 0; i < subvectorSizesAndOffsets.length; i++) { + subVectors[i] = QuantizedSubVector.createEmpty(bitsPerDimension, subvectorSizesAndOffsets[i][0]); + } + return new QuantizedVector(subVectors); + } + + + /** + * Write the instance to a DataOutput + * @param out the DataOutput + * @throws IOException fails if we cannot write to the DataOutput + */ + public void write(DataOutput out) throws IOException { + out.writeInt(subVectors.length); + + for (var sv : subVectors) { + sv.write(out); + } + } + + /** + * Read the instance from a RandomAccessReader + * @param in the RandomAccessReader + * @throws IOException fails if we cannot read from the RandomAccessReader + */ + public static QuantizedVector load(RandomAccessReader in) throws IOException { + int length = in.readInt(); + var subVectors = new QuantizedSubVector[length]; + for (int i = 0; i < length; i++) { + subVectors[i] = QuantizedSubVector.load(in); + } + + return new QuantizedVector(subVectors); + } + + /** + * Read the instance from a RandomAccessReader + * @param in the RandomAccessReader + * @throws IOException fails if we cannot read from the RandomAccessReader + */ + public static void loadInto(RandomAccessReader in, QuantizedVector qvector) throws IOException { + in.readInt(); + for (int i = 0; i < qvector.subVectors.length; i++) { + QuantizedSubVector.loadInto(in, qvector.subVectors[i]); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + QuantizedVector that = (QuantizedVector) o; + return Arrays.deepEquals(subVectors, that.subVectors); + } + } + + /** + * A NuVeQ sub-vector. 
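+ * Holds the quantized bytes together with the parameters needed to approximately reconstruct the original
+ * values: the sub-vector's [minValue, maxValue] range and the growth rate and midpoint of its nonlinearity.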
+ */ + public static class QuantizedSubVector { + // The byte sequence that stores the quantized subvector + public ByteSequence bytes; + + // The number of bits for each dimension of the input uncompressed subvector + public BitsPerDimension bitsPerDimension; + + // The NVQ parameters + public float growthRate; + public float midpoint; + public float maxValue; + public float minValue; + + // The number of dimensions of the input uncompressed subvector + public int originalDimensions; + + /** + * Return the number of bytes occupied by the serialization of a QuantizedSubVector + * @param nDims the number fof dimensions of the subvector + * @param bitsPerDimension the number of bits per dimensions + * @return the size in bytes of the quantized subvector + */ + public static int compressedVectorSize(int nDims, BitsPerDimension bitsPerDimension) { + // Here we assume that an enum takes 4 bytes + switch (bitsPerDimension) { + case EIGHT: return nDims + 4 * Float.BYTES + 3 * Integer.BYTES; + default: throw new IllegalArgumentException("Unsupported bits per dimension: " + bitsPerDimension); + } + } + + /** + * Quantize the vector using NVQ into dest + * @param vector the subvector to quantize + * @param bitsPerDimension the number of bits per dimension + * @param learn whether to use optimization to find the parameters of the nonlinearity + * @param dest the destination subvector + */ + public static void quantizeTo(VectorFloat vector, BitsPerDimension bitsPerDimension, boolean learn, QuantizedSubVector dest) { + var minValue = VectorUtil.min(vector); + var maxValue = VectorUtil.max(vector); + + //----------------------------------------------------------------- + // Optimization to find the hyperparameters of the quantization + float growthRate = 1e-2f; + float midpoint = 0; + + if (learn) { + NonuniformQuantizationLossFunction lossFunction = new NonuniformQuantizationLossFunction(bitsPerDimension); + lossFunction.setVector(vector, minValue, maxValue); + + float growthRateCoarse = 1e-2f; + float bestLossValue = Float.MIN_VALUE; + float[] tempSolution = {growthRateCoarse, 0.f}; + for (float gr = 1e-6f; gr < 20.f; gr += 1f) { + tempSolution[0] = gr; + float lossValue = lossFunction.compute(tempSolution); + if (lossValue > bestLossValue) { + bestLossValue = lossValue; + growthRateCoarse = gr; + } + } + float growthRateFineTuned = growthRateCoarse; + for (float gr = growthRateCoarse - 1; gr < growthRateCoarse + 1; gr += 0.1f) { + tempSolution[0] = gr; + float lossValue = lossFunction.compute(tempSolution); + if (lossValue > bestLossValue) { + bestLossValue = lossValue; + growthRateFineTuned = gr; + } + } + + growthRate = growthRateFineTuned; + } + //--------------------------------------------------------------------------- + + var quantized = bitsPerDimension.createByteSequence(vector.length()); + switch (bitsPerDimension) { + case EIGHT: + VectorUtil.nvqQuantize8bit(vector, growthRate, midpoint, minValue, maxValue, quantized); + break; + default: + throw new IllegalArgumentException("Unsupported bits per dimension: " + bitsPerDimension); + } + + dest.bitsPerDimension = bitsPerDimension; + dest.minValue = minValue; + dest.maxValue = maxValue; + dest.growthRate = growthRate; + dest.midpoint = midpoint; + dest.bytes = quantized; + dest.originalDimensions = vector.length(); + } + + /** + * Constructor used when loading from a RandomAccessReader. It takes its member fields. 
+ */ + private QuantizedSubVector(ByteSequence bytes, int originalDimensions, BitsPerDimension bitsPerDimension, + float minValue, float maxValue, + float growthRate, float midpoint) { + this.bitsPerDimension = bitsPerDimension; + this.bytes = bytes; + this.minValue = minValue; + this.maxValue = maxValue; + this.growthRate = growthRate; + this.midpoint = midpoint; + this.originalDimensions = originalDimensions; + } + + /** + * Write the instance to a DataOutput + * @param out the DataOutput + * @throws IOException fails if we cannot write to the DataOutput + */ + public void write(DataOutput out) throws IOException { + bitsPerDimension.write(out); + out.writeFloat(minValue); + out.writeFloat(maxValue); + out.writeFloat(growthRate); + out.writeFloat(midpoint); + out.writeInt(originalDimensions); + out.writeInt(bytes.length()); + + vectorTypeSupport.writeByteSequence(out, bytes); + } + + /** + * Create an empty instance. Meant to be used as scratch space in conjunction with loadInto + * @param bitsPerDimension the number of bits per dimension + * @param length the number of dimensions + */ + public static QuantizedSubVector createEmpty(BitsPerDimension bitsPerDimension, int length) { + ByteSequence bytes = bitsPerDimension.createByteSequence(length); + return new QuantizedSubVector(bytes, length, bitsPerDimension, 0.f, 0.f, 0.f, 0.f); + } + + /** + * Read the instance from a RandomAccessReader + * @param in the RandomAccessReader + * @throws IOException fails if we cannot read from the RandomAccessReader + */ + public static QuantizedSubVector load(RandomAccessReader in) throws IOException { + BitsPerDimension bitsPerDimension = BitsPerDimension.load(in); + float minValue = in.readFloat(); + float maxValue = in.readFloat(); + float logisticAlpha = in.readFloat(); + float logisticX0 = in.readFloat(); + int originalDimensions = in.readInt(); + int compressedDimension = in.readInt(); + + ByteSequence bytes = vectorTypeSupport.readByteSequence(in, compressedDimension); + + return new QuantizedSubVector(bytes, originalDimensions, bitsPerDimension, minValue, maxValue, logisticAlpha, logisticX0); + } + + /** + * Read the instance from a RandomAccessReader + * @param in the RandomAccessReader + * @throws IOException fails if we cannot read from the RandomAccessReader + */ + public static void loadInto(RandomAccessReader in, QuantizedSubVector quantizedSubVector) throws IOException { + quantizedSubVector.bitsPerDimension = BitsPerDimension.load(in); + quantizedSubVector.minValue = in.readFloat(); + quantizedSubVector.maxValue = in.readFloat(); + quantizedSubVector.growthRate = in.readFloat(); + quantizedSubVector.midpoint = in.readFloat(); + quantizedSubVector.originalDimensions = in.readInt(); + in.readInt(); + + vectorTypeSupport.readByteSequence(in, quantizedSubVector.bytes); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + QuantizedSubVector that = (QuantizedSubVector) o; + return (maxValue == that.maxValue) + && (minValue == that.minValue) + && (growthRate == that.growthRate) + && (midpoint == that.midpoint) + && (bitsPerDimension == that.bitsPerDimension) + && bytes.equals(that.bytes); + } + } + + /** + * The loss used to optimize for the NVQ hyperparameters + * We use the ratio between the loss given by the uniform quantization and the NVQ loss. 
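+ * Higher is better: a value above 1 means the nonlinear quantizer reconstructs the sub-vector with lower error
+ * than uniform quantization, and the parameter search keeps the growth rate that maximizes this ratio.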
+ */ + private static class NonuniformQuantizationLossFunction { + final private BitsPerDimension bitsPerDimension; + private VectorFloat vector; + private float minValue; + private float maxValue; + private float baseline; + + public NonuniformQuantizationLossFunction(BitsPerDimension bitsPerDimension) { + this.bitsPerDimension = bitsPerDimension; + } + + public void setVector(VectorFloat vector, float minValue, float maxValue) { + this.vector = vector; + this.minValue = minValue; + this.maxValue = maxValue; + baseline = VectorUtil.nvqUniformLoss(vector, minValue, maxValue, bitsPerDimension.getInt()); + } + + public float computeRaw(float[] x) { + return VectorUtil.nvqLoss(vector, x[0], x[1], minValue, maxValue, bitsPerDimension.getInt()); + } + + public float compute(float[] x) { + return baseline / computeRaw(x); + } + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java index bd6783953..129eebee7 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQDecoder.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQDecoder.java @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 280253978..1dfa554d4 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; @@ -38,7 +38,7 @@ public abstract class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom - + final ProductQuantization pq; protected ByteSequence[] compressedDataChunks; protected int vectorCount; @@ -53,18 +53,18 @@ public static ImmutablePQVectors load(RandomAccessReader in) throws IOException var pq = ProductQuantization.load(in); // read the vectors - int vectorCount = in.readInt(); + int vectorCount = in.readInt(); int compressedDimension = in.readInt(); - + int[] params = calculateChunkParameters(vectorCount, compressedDimension); int vectorsPerChunk = params[0]; int totalChunks = params[1]; int fullSizeChunks = params[2]; int remainingVectors = params[3]; - + ByteSequence[] chunks = new ByteSequence[totalChunks]; int chunkBytes = vectorsPerChunk * compressedDimension; - + for (int i = 0; i < fullSizeChunks; i++) { chunks[i] = vectorTypeSupport.readByteSequence(in, chunkBytes); } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java similarity index 99% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java index c56d91f2e..367d9381d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/ProductQuantization.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.annotations.VisibleForTesting; import io.github.jbellis.jvector.disk.RandomAccessReader; @@ -39,7 +39,7 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import static io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer.UNWEIGHTED; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; import static io.github.jbellis.jvector.util.MathUtil.square; import static io.github.jbellis.jvector.vector.VectorUtil.dotProduct; import static io.github.jbellis.jvector.vector.VectorUtil.sub; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java similarity index 98% rename from jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java rename to jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java index 3492a97a2..09eb1e035 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/VectorCompressor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/VectorCompressor.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java index 6205c8faf..0f757f036 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/util/PhysicalCoreExecutor.java @@ -16,7 +16,7 @@ package io.github.jbellis.jvector.util; import io.github.jbellis.jvector.graph.GraphIndexBuilder; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.ProductQuantization; import java.io.Closeable; import java.util.concurrent.ForkJoinPool; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index a5ec14bb9..db169766e 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -243,6 +243,13 @@ public void addInPlace(VectorFloat v1, VectorFloat v2) { } } + /** Adds value to each element of v1, in place (v1 will be modified) */ + public void addInPlace(VectorFloat v1, float value) { + for (int i = 0; i < v1.length(); i++) { + v1.set(i, v1.get(i) + value); + } + } + @Override public void subInPlace(VectorFloat v1, VectorFloat v2) { for (int i = 0; i < v1.length(); i++) { @@ -250,11 +257,27 @@ public void subInPlace(VectorFloat v1, VectorFloat v2) { } } + @Override + public void subInPlace(VectorFloat vector, float value) { + for (int i = 0; i < vector.length(); i++) { + vector.set(i, vector.get(i) - value); + } + } + @Override public VectorFloat sub(VectorFloat a, VectorFloat b) { return sub(a, 0, b, 0, a.length()); } + @Override + public VectorFloat sub(VectorFloat a, float value) { + VectorFloat result = new ArrayVectorFloat(a.length()); + for (int i = 0; i < a.length(); i++) { + result.set(i, a.get(i) - value); + } + return result; + } + @Override public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length) { VectorFloat result = new ArrayVectorFloat(length); @@ -328,29 +351,194 @@ public void quantizePartials(float delta, VectorFloat partials, VectorFloat v) { - float max = -Float.MAX_VALUE; - for (int i = 0; i < v.length(); i++) { - max = Math.max(max, v.get(i)); - } - return max; + float max = -Float.MAX_VALUE; + for (int i = 0; i < v.length(); i++) { + max = Math.max(max, v.get(i)); + } + return max; } @Override public float min(VectorFloat v) { - float min = Float.MAX_VALUE; - for (int i = 0; i < v.length(); i++) { - min = Math.min(min, v.get(i)); - } - return min; + float min = Float.MAX_VALUE; + for (int i = 0; i < v.length(); i++) { + min = Math.min(min, v.get(i)); + } + return min; + } + + @Override + public float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + var delta = maxValue - minValue; + var scaledGrowthRate = growthRate / delta; + var scaledMidpoint = midpoint * delta; + var inverseScaledGrowthRate = 1 / scaledGrowthRate; + var logisticBias = logisticFunctionNQT(minValue, scaledGrowthRate, scaledMidpoint); 
+ var logisticScale = (logisticFunctionNQT(maxValue, scaledGrowthRate, scaledMidpoint) - logisticBias) / 255; + + float dotProd = 0; + float value; + for (int d = 0; d < bytes.length(); d++) { + value = Byte.toUnsignedInt(bytes.get(d)); + value = scaledLogitFunctionNQT(value, inverseScaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + + dotProd = Math.fma(vector.get(d), value, dotProd); + } + return dotProd; + } + + @Override + public float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + var delta = maxValue - minValue; + var scaledGrowthRate = growthRate / delta; + var scaledMidpoint = midpoint * delta; + var inverseScaledGrowthRate = 1 / scaledGrowthRate; + var logisticBias = logisticFunctionNQT(minValue, scaledGrowthRate, scaledMidpoint); + var logisticScale = (logisticFunctionNQT(maxValue, scaledGrowthRate, scaledMidpoint) - logisticBias) / 255; + + float squareSum = 0; + + float value; + + for (int d = 0; d < bytes.length(); d++) { + value = Byte.toUnsignedInt(bytes.get(d)); + value = scaledLogitFunctionNQT(value, inverseScaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + + var temp = value - vector.get(d); + squareSum = Math.fma(temp, temp, squareSum); + } + return squareSum; + } + + @Override + public float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue, VectorFloat centroid) { + var delta = maxValue - minValue; + var scaledGrowthRate = growthRate / delta; + var scaledMidpoint = midpoint * delta; + var inverseScaledGrowthRate = 1 / scaledGrowthRate; + var logisticBias = logisticFunctionNQT(minValue, scaledGrowthRate, scaledMidpoint); + var logisticScale = (logisticFunctionNQT(maxValue, scaledGrowthRate, scaledMidpoint) - logisticBias) / 255; + + float sum = 0; + float normDQ = 0; + + float elem2; + + for (int d = 0; d < bytes.length(); d++) { + elem2 = Byte.toUnsignedInt(bytes.get(d)); + elem2 = scaledLogitFunctionNQT(elem2, inverseScaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + elem2 += centroid.get(d); + + sum = Math.fma(vector.get(d), elem2, sum); + normDQ = Math.fma(elem2, elem2, normDQ); + } + return new float[]{sum, normDQ}; + } + + @Override + public void nvqShuffleQueryInPlace8bit(VectorFloat vector) {} + + static float logisticFunctionNQT(float value, float alpha, float x0) { + float temp = Math.fma(value, alpha, -alpha * x0); + int p = Math.round(temp + 0.5f); + int m = Float.floatToIntBits(Math.fma(temp - p, 0.5f, 1)); + + temp = Float.intBitsToFloat(m + (p << 23)); // temp = m * 2^p + return temp / (temp + 1); + } + + static float logitNQT(float value, float inverseAlpha, float x0) { + float z = value / (1 - value); + + int temp = Float.floatToIntBits(z); + int e = temp & 0x7f800000; + float p = (float) ((e >> 23) - 128); + float m = Float.intBitsToFloat((temp & 0x007fffff) + 0x3f800000); + + return Math.fma(m + p, inverseAlpha, x0); + } + + static float scaledLogisticFunction(float value, float growthRate, float midpoint, float logisticScale, float logisticBias) { + var y = logisticFunctionNQT(value, growthRate, midpoint); + return (y - logisticBias) * (1 / logisticScale); + } + + static float scaledLogitFunctionNQT(float value, float inverseGrowthRate, float midpoint, float logisticScale, float logisticBias) { + var scaledValue = Math.fma(value, logisticScale, logisticBias); + return logitNQT(scaledValue, inverseGrowthRate, midpoint); + } + + @Override + public void 
nvqQuantize8bit(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, ByteSequence destination) { + var delta = maxValue - minValue; + var scaledGrowthRate = growthRate / delta; + var scaledMidpoint = midpoint * delta; + var logisticBias = logisticFunctionNQT(minValue, scaledGrowthRate, scaledMidpoint); + var logisticScale = (logisticFunctionNQT(maxValue, scaledGrowthRate, scaledMidpoint) - logisticBias) / 255; + + + for (int d = 0; d < vector.length(); d++) { + // Ensure the quantized value is within the 0 to constant range + float value = vector.get(d); + value = scaledLogisticFunction(value, scaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + int quantizedValue = Math.round(value); + destination.set(d, (byte) quantizedValue); + } + } + + public float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits) { + float constant = (1 << nBits) - 1; + + var delta = maxValue - minValue; + var scaledGrowthRate = growthRate / delta; + var scaledMidpoint = midpoint * delta; + + var logisticBias = logisticFunctionNQT(minValue, scaledGrowthRate, scaledMidpoint); + var logisticScale = (logisticFunctionNQT(maxValue, scaledGrowthRate, scaledMidpoint) - logisticBias) / constant; + var inverseScaledGrowthRate = 1 / scaledGrowthRate; + + float squaredSum = 0.f; + float originalValue, reconstructedValue; + for (int d = 0; d < vector.length(); d++) { + originalValue = vector.get(d); + + reconstructedValue = scaledLogisticFunction(originalValue, scaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + reconstructedValue = Math.round(reconstructedValue); + reconstructedValue = scaledLogitFunctionNQT(reconstructedValue, inverseScaledGrowthRate, scaledMidpoint, logisticScale, logisticBias); + + var diff = originalValue - reconstructedValue; + squaredSum = Math.fma(diff, diff, squaredSum); + } + + return squaredSum; + } + + public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { + float constant = (1 << nBits) - 1; + + float squaredSum = 0.f; + float originalValue, reconstructedValue; + for (int d = 0; d < vector.length(); d++) { + originalValue = vector.get(d); + + reconstructedValue = (originalValue - minValue) / (maxValue - minValue); + reconstructedValue = Math.round(constant * reconstructedValue) / constant; + reconstructedValue = reconstructedValue * (maxValue - minValue) + minValue; + + var diff = originalValue - reconstructedValue; + squaredSum = Math.fma(diff, diff, squaredSum); + } + + return squaredSum; } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 595cb9153..0b847a344 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -134,14 +134,26 @@ public static void addInPlace(VectorFloat v1, VectorFloat v2) { impl.addInPlace(v1, v2); } + public static void addInPlace(VectorFloat v1, float value) { + impl.addInPlace(v1, value); + } + public static void subInPlace(VectorFloat v1, VectorFloat v2) { impl.subInPlace(v1, v2); } + public static void subInPlace(VectorFloat vector, float value) { + impl.subInPlace(vector, value); + } + public static VectorFloat sub(VectorFloat lhs, VectorFloat rhs) { return impl.sub(lhs, rhs); } + public static VectorFloat sub(VectorFloat lhs, float value) { + return impl.sub(lhs, value); + } + 
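To make the relationship between the new scalar NVQ primitives concrete, here is a minimal, standalone sketch (not part of the patch) that quantizes one vector and compares the NVQ reconstruction error to the uniform-quantization baseline, mirroring what NonuniformQuantizationLossFunction.compute() measures. The growthRate and midpoint values are arbitrary placeholders (the real code fits them per vector), and the allocation via VectorTypeSupport.createByteSequence is assumed here rather than shown in this diff.

import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import io.github.jbellis.jvector.vector.types.VectorFloat;

class NvqRoundTripSketch {
    static void demo(VectorFloat<?> v) {
        float minValue = VectorUtil.min(v);
        float maxValue = VectorUtil.max(v);
        float growthRate = 1e-2f; // placeholder; normally optimized per vector
        float midpoint = 0.5f;    // placeholder; normally optimized per vector

        // one byte per dimension for the 8-bit variant
        ByteSequence<?> quantized = VectorizationProvider.getInstance()
                .getVectorTypeSupport().createByteSequence(v.length());
        VectorUtil.nvqQuantize8bit(v, growthRate, midpoint, minValue, maxValue, quantized);

        // lower is better; uniformError / nvqError is the ratio that
        // NonuniformQuantizationLossFunction.compute() reports
        float nvqError = VectorUtil.nvqLoss(v, growthRate, midpoint, minValue, maxValue, 8);
        float uniformError = VectorUtil.nvqUniformLoss(v, minValue, maxValue, 8);
        System.out.printf("NVQ error %.6f vs uniform error %.6f (ratio %.3f)%n",
                nvqError, uniformError, uniformError / nvqError);
    }
}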
public static VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length) { return impl.sub(a, aOffset, b, bOffset, length); } @@ -198,4 +210,32 @@ public static float min(VectorFloat v) { public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } + + public static float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + return impl.nvqDotProduct8bit(vector, bytes, growthRate, midpoint, minValue, maxValue); + } + + public static float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + return impl.nvqSquareL2Distance8bit(vector, bytes, growthRate, midpoint, minValue, maxValue); + } + + public static float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue, VectorFloat centroid) { + return impl.nvqCosine8bit(vector, bytes, growthRate, midpoint, minValue, maxValue, centroid); + } + + public static void nvqShuffleQueryInPlace8bit(VectorFloat vector) { + impl.nvqShuffleQueryInPlace8bit(vector); + } + + public static void nvqQuantize8bit(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, ByteSequence destination) { + impl.nvqQuantize8bit(vector, growthRate, midpoint, minValue, maxValue, destination); + } + + public static float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits) { + return impl.nvqLoss(vector, growthRate, midpoint, minValue, maxValue, nBits); + } + + public static float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { + return impl.nvqUniformLoss(vector, minValue, maxValue, nBits); + } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index 320f71a12..2aa3be5d1 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -64,12 +64,21 @@ public interface VectorUtilSupport { /** Adds v2 into v1, in place (v1 will be modified) */ void addInPlace(VectorFloat v1, VectorFloat v2); + /** Adds value to each element of v1, in place (v1 will be modified) */ + void addInPlace(VectorFloat v1, float value); + /** Subtracts v2 from v1, in place (v1 will be modified) */ void subInPlace(VectorFloat v1, VectorFloat v2); + /** Subtracts value from each element of v1, in place (v1 will be modified) */ + void subInPlace(VectorFloat vector, float value); + /** @return a - b, element-wise */ VectorFloat sub(VectorFloat a, VectorFloat b); + /** Subtracts value from each element of a */ + VectorFloat sub(VectorFloat a, float value); + /** @return a - b, element-wise, starting at aOffset and bOffset respectively */ VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length); @@ -213,4 +222,82 @@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCoun return (float) (sum / Math.sqrt(aMag * bMagnitude)); } + /** + * Computes the dot product between a vector and a 8-bit quantized vector (described by its parameters). 
+     * We assume that the number of dimensions of the vector and the quantized vector match.
+     * @param vector The query vector
+     * @param bytes The byte sequence where the quantized vector is stored.
+     * @param growthRate The growth rate of the logistic function
+     * @param midpoint The midpoint of the logistic function
+     * @param minValue The minimum value of the subvector
+     * @param maxValue The maximum value of the subvector
+     * @return the dot product
+     */
+    float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue);
+
+    /**
+     * Computes the squared Euclidean distance between a vector and an 8-bit quantized vector (described by its parameters).
+     * We assume that the number of dimensions of the vector and the quantized vector match.
+     * @param vector The query vector
+     * @param bytes The byte sequence where the quantized vector is stored.
+     * @param growthRate The growth rate of the logistic function
+     * @param midpoint The midpoint of the logistic function
+     * @param minValue The minimum value of the subvector
+     * @param maxValue The maximum value of the subvector
+     * @return the squared Euclidean distance
+     */
+    float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue);
+
+    /**
+     * Computes the cosine similarity between a vector and an 8-bit quantized vector (described by its parameters).
+     * We assume that the number of dimensions of the vector and the quantized vector match.
+     * @param vector The query vector
+     * @param bytes The byte sequence where the quantized vector is stored.
+     * @param growthRate The growth rate of the logistic function
+     * @param midpoint The midpoint of the logistic function
+     * @param minValue The minimum value of the subvector
+     * @param maxValue The maximum value of the subvector
+     * @param centroid The global mean vector used to re-center the quantized subvectors.
+     * @return the cosine similarity
+     */
+    float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue, VectorFloat centroid);
+
+    /**
+     * When computing distances, the unpacking of an NVQ-quantized vector is faster if we do not do it in sequential order.
+     * Hence, we shuffle the query so that it matches this order.
+     * See the FastLanes reference in VectorSimdOps#nvqDequantize8bit.
+     * @param vector the vector to be shuffled
+     */
+    void nvqShuffleQueryInPlace8bit(VectorFloat vector);
+
+    /**
+     * Quantizes a subvector using 8 bits per dimension.
+     * @param vector The vector to quantize
+     * @param growthRate The growth rate of the logistic function
+     * @param midpoint The midpoint of the logistic function
+     * @param minValue The minimum value of the subvector
+     * @param maxValue The maximum value of the subvector
+     * @param destination The byte sequence where the quantized values are stored
+     */
+    void nvqQuantize8bit(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, ByteSequence destination);
+
+    /**
+     * Computes the squared error of quantizing the vector with NVQ.
+ * @param vector The vector to quantized + * @param growthRate The growth rate of the logistic function + * @param midpoint the midpoint of the logistic function + * @param minValue The minimum value of the subvector + * @param maxValue The maximum value of the subvector + * @param nBits the number of bits per dimension + */ + float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits); + + /** + * Compute the squared error of quantizing the vector with a uniform quantizer. + * @param vector The vector to quantized + * @param minValue The minimum value of the subvector + * @param maxValue The maximum value of the subvector + * @param nBits the number of bits per dimension + */ + float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java index d0aedae88..5cd6573a9 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Bench.java @@ -33,7 +33,7 @@ import java.util.regex.Pattern; import java.util.stream.Collectors; -import static io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer.UNWEIGHTED; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; /** * Tests GraphIndexes against vectors from various datasets @@ -55,7 +55,8 @@ public static void main(String[] args) throws IOException { ds -> new PQParameters(ds.getDimension() / 8, 256, ds.similarityFunction == VectorSimilarityFunction.EUCLIDEAN, UNWEIGHTED) ); List> featureSets = Arrays.asList( - EnumSet.of(FeatureId.INLINE_VECTORS) + EnumSet.of(FeatureId.INLINE_VECTORS), + EnumSet.of(FeatureId.NVQ_VECTORS) ); // args is list of regexes, possibly needing to be split by whitespace. diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java new file mode 100644 index 000000000..f3869069c --- /dev/null +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/DistancesNVQ.java @@ -0,0 +1,152 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package io.github.jbellis.jvector.example; + +import io.github.jbellis.jvector.example.util.SiftLoader; +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; +import io.github.jbellis.jvector.vector.types.VectorFloat; + +import java.io.IOException; +import java.util.*; + +import static java.lang.Math.abs; + +// this class uses explicit typing instead of `var` for easier reading when excerpted for instructional use +public class DistancesNVQ { + public static void testNVQEncodings(String filenameBase, String filenameQueries, VectorSimilarityFunction vsf) throws IOException { + List> vectors = SiftLoader.readFvecs(filenameBase); + List> queries = SiftLoader.readFvecs(filenameQueries); + + int dimension = vectors.get(0).length(); + int nQueries = 100; + int nVectors = 10_000; + + vectors = vectors.subList(0, nVectors); + + System.out.format("\t%d base and %d query vectors loaded, dimensions %d%n", + vectors.size(), queries.size(), vectors.get(0).length()); + + // Generate a NVQ for random vectors + var ravv = new ListRandomAccessVectorValues(vectors, dimension); + var nvq = NVQuantization.compute(ravv, 2); + + // Compress the vectors + long startTime = System.nanoTime(); + var nvqVecs = nvq.encodeAll(ravv); + long endTime = System.nanoTime(); + double duration = (double) (endTime - startTime) / 1_000_000_000; + System.out.println("\tEncoding took " + duration + " seconds"); + + // compare the encoded similarities to the raw + double distanceError = 0; + for (int i = 0; i < nQueries; i++) { + var q = queries.get(i); + if (VectorUtil.dotProduct(q, q) == 0) { + continue; + } + var f = nvqVecs.scoreFunctionFor(q, vsf); + + for (int j = 0; j < nVectors; j++) { + var v = vectors.get(j); + distanceError += abs(f.similarityTo(j) - vsf.compare(q, v)); + } + } + distanceError /= nQueries * nVectors; + + System.out.println("\t" + vsf + " error " + distanceError); + + + float dummyAccumulator = 0; + + startTime = System.nanoTime(); + for (int i = 0; i < nQueries; i++) { + var q = queries.get(i); + if (VectorUtil.dotProduct(q, q) == 0) { + continue; + } + var f = nvqVecs.scoreFunctionFor(q, vsf); + + for (int j = 0; j < nVectors; j++) { + dummyAccumulator += f.similarityTo(j); + } + } + endTime = System.nanoTime(); + duration = (double) (endTime - startTime) / 1_000_000_000; + System.out.println("\tNVQ Distance computations took " + duration + " seconds"); + + startTime = System.nanoTime(); + for (int i = 0; i < nQueries; i++) { + var q = queries.get(i); + if (VectorUtil.dotProduct(q, q) == 0) { + continue; + } + + for (int j = 0; j < nVectors; j++) { + var v = vectors.get(j); + dummyAccumulator += vsf.compare(q, v); + } + } + endTime = System.nanoTime(); + duration = (double) (endTime - startTime) / 1_000_000_000; + System.out.println("\tFloat Distance computations took " + duration + " seconds"); + + System.out.println("\tdummyAccumulator: " + dummyAccumulator); + System.out.println("--"); + } + + public static void runSIFT() throws IOException { + System.out.println("Running siftsmall"); + + var baseVectors = "siftsmall/siftsmall_base.fvecs"; + var queryVectors = "siftsmall/siftsmall_query.fvecs"; + testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); + } + + public static void runADA() throws IOException { + System.out.println("Running ada_002"); + + var baseVectors = 
"./fvec/wikipedia_squad/100k/ada_002_100000_base_vectors.fvec"; + var queryVectors = "./fvec/wikipedia_squad/100k/ada_002_100000_query_vectors_10000.fvec"; + testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); + } + + public static void runColbert() throws IOException { + System.out.println("Running colbertv2"); + + var baseVectors = "./fvec/wikipedia_squad/1M/colbertv2.0_128_base_vectors_1000000.fvec"; + var queryVectors = "./fvec/wikipedia_squad/1M/colbertv2.0_128_query_vectors_100000.fvec"; + testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); + } + + public static void runOpenai3072() throws IOException { + System.out.println("Running text-embedding-3-large_3072"); + + var baseVectors = "./fvec/wikipedia_squad/100k/text-embedding-3-large_3072_100000_base_vectors.fvec"; + var queryVectors = "./fvec/wikipedia_squad/100k/text-embedding-3-large_3072_100000_base_vectors.fvec"; + testNVQEncodings(baseVectors, queryVectors, VectorSimilarityFunction.COSINE); + } + + public static void main(String[] args) throws IOException { + runSIFT(); + runADA(); + runColbert(); + runOpenai3072(); + } +} \ No newline at end of file diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java index 0a76ca784..4ebbc2280 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/Grid.java @@ -30,16 +30,14 @@ import io.github.jbellis.jvector.graph.disk.FeatureId; import io.github.jbellis.jvector.graph.disk.FusedADC; import io.github.jbellis.jvector.graph.disk.InlineVectors; +import io.github.jbellis.jvector.graph.disk.NVQ; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.graph.disk.OrdinalMapper; import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.pq.CompressedVectors; -import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; -import io.github.jbellis.jvector.pq.VectorCompressor; +import io.github.jbellis.jvector.quantization.*; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.ExplicitThreadLocal; import io.github.jbellis.jvector.util.PhysicalCoreExecutor; @@ -173,10 +171,13 @@ private static Map, GraphIndex> buildOnDisk(List features // no supplier as these will be used for writeInline, when we don't have enough information to fuse neighbors builder.with(new FusedADC(onHeapGraph.maxDegree(), pq)); break; + case NVQ_VECTORS: + var nvq = NVQuantization.compute(floatVectors, 2); + builder.with(new NVQ(nvq)); + suppliers.put(FeatureId.NVQ_VECTORS, ordinal -> new NVQ.State(nvq.encode(floatVectors.getVector(ordinal)))); + break; + } } return new BuilderWithSuppliers(builder, suppliers); diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java index d217cef0b..9e89db81d 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/IPCService.java @@ -29,9 +29,9 @@ import 
io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.pq.CompressedVectors; -import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.CompressedVectors; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.PhysicalCoreExecutor; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index a216a4803..334ef1f1d 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -31,6 +31,7 @@ import io.github.jbellis.jvector.graph.disk.Feature; import io.github.jbellis.jvector.graph.disk.FeatureId; import io.github.jbellis.jvector.graph.disk.InlineVectors; +import io.github.jbellis.jvector.graph.disk.NVQ; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; import io.github.jbellis.jvector.graph.disk.OrdinalMapper; @@ -38,9 +39,10 @@ import io.github.jbellis.jvector.graph.similarity.ScoreFunction.ApproximateScoreFunction; import io.github.jbellis.jvector.graph.similarity.ScoreFunction.ExactScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.pq.MutablePQVectors; -import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.MutablePQVectors; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.ExceptionUtils; import io.github.jbellis.jvector.util.ExplicitThreadLocal; @@ -264,6 +266,67 @@ public static void siftDiskAnnLTM(List> baseVectors, List> baseVectors, List> queryVectors, List> groundTruth) throws IOException { + int originalDimension = baseVectors.get(0).length(); + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(baseVectors, originalDimension); + + // compute the codebook, but don't encode any vectors yet + ProductQuantization pq = ProductQuantization.compute(ravv, 16, 256, true); + + var nvq = NVQuantization.compute(ravv, 2); + + // as we build the index we'll compress the new vectors and add them to this List backing a PQVectors; + // this is used to score the construction searches + var pqv = new MutablePQVectors(pq); + BuildScoreProvider bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqv); + + Path indexPath = Files.createTempFile("siftsmall", ".inline"); + Path pqPath = Files.createTempFile("siftsmall", ".pq"); + // Builder creation looks mostly the same + try (GraphIndexBuilder builder = new GraphIndexBuilder(bsp, ravv.dimension(), 16, 100, 1.2f, 1.2f); + // explicit Writer for the first time, this is what's behind OnDiskGraphIndex.write + OnDiskGraphIndexWriter writer = new 
OnDiskGraphIndexWriter.Builder(builder.getGraph(), indexPath) + .with(new NVQ(nvq)) + .withMapper(new OrdinalMapper.IdentityMapper(baseVectors.size() - 1)) + .build(); + // output for the compressed vectors + DataOutputStream pqOut = new DataOutputStream(new BufferedOutputStream(Files.newOutputStream(pqPath)))) + { + // build the index vector-at-a-time (on disk) + for (int ordinal = 0; ordinal < baseVectors.size(); ordinal++) { + VectorFloat v = baseVectors.get(ordinal); + // compress the new vector and add it to the PQVectors + pqv.encodeAndSet(ordinal, v); + // write the full vector to disk + writer.writeInline(ordinal, Feature.singleState(FeatureId.NVQ_VECTORS, new NVQ.State(nvq.encode(v)))); + // now add it to the graph -- the previous steps must be completed first since the PQVectors + // and InlineVectorValues are both used during the search that runs as part of addGraphNode construction + builder.addGraphNode(ordinal, v); + } + + // cleanup does a final enforcement of maxDegree and handles other scenarios like deleted nodes + // that we don't need to worry about here + builder.cleanup(); + + // finish writing the index (by filling in the edge lists) and write our completed PQVectors + writer.write(Map.of()); + pqv.write(pqOut); + } + + // searching the index does not change + ReaderSupplier rs = new MMapReader.Supplier(indexPath); + OnDiskGraphIndex index = OnDiskGraphIndex.load(rs); + try (RandomAccessReader in = new SimpleMappedReader(pqPath)) { + var pqvSearch = PQVectors.load(in); + Function, SearchScoreProvider> sspFactory = q -> { + ApproximateScoreFunction asf = pqvSearch.precomputedScoreFunctionFor(q, VectorSimilarityFunction.EUCLIDEAN); + ExactScoreFunction reranker = index.getView().rerankerFor(q, VectorSimilarityFunction.EUCLIDEAN); + return new SearchScoreProvider(asf, reranker); + }; + testRecall(index, queryVectors, groundTruth, sspFactory); + } + } + // // Utilities and main() harness // @@ -326,5 +389,6 @@ public static void main(String[] args) throws IOException { siftPersisted(baseVectors, queryVectors, groundTruth); siftDiskAnn(baseVectors, queryVectors, groundTruth); siftDiskAnnLTM(baseVectors, queryVectors, groundTruth); + siftDiskAnnLTMWithNVQ(baseVectors, queryVectors, groundTruth); } } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java index fc748a929..e1ffebb9b 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/CompressorParameters.java @@ -16,9 +16,10 @@ package io.github.jbellis.jvector.example.util; -import io.github.jbellis.jvector.pq.BinaryQuantization; -import io.github.jbellis.jvector.pq.ProductQuantization; -import io.github.jbellis.jvector.pq.VectorCompressor; +import io.github.jbellis.jvector.quantization.BinaryQuantization; +import io.github.jbellis.jvector.quantization.NVQuantization; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.quantization.VectorCompressor; public abstract class CompressorParameters { public static final CompressorParameters NONE = new NoCompressionParameters(); @@ -70,6 +71,29 @@ public VectorCompressor computeCompressor(DataSet ds) { } } + public static class NVQParameters extends CompressorParameters { + private final int nSubVectors; + + public NVQParameters(int nSubVectors) { + 
this.nSubVectors = nSubVectors; + } + + @Override + public VectorCompressor computeCompressor(DataSet ds) { + return NVQuantization.compute(ds.getBaseRavv(), nSubVectors); + } + + @Override + public String idStringFor(DataSet ds) { + return String.format("NVQ_%s_%d_%s", ds.name, nSubVectors); + } + + @Override + public boolean supportsCaching() { + return true; + } + } + private static class NoCompressionParameters extends CompressorParameters { @Override public VectorCompressor computeCompressor(DataSet ds) { diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java index 335c1b1c9..a0c070cfc 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java @@ -77,11 +77,21 @@ public void addInPlace(VectorFloat v1, VectorFloat v2) { VectorSimdOps.addInPlace((MemorySegmentVectorFloat)v1, (MemorySegmentVectorFloat)v2); } + @Override + public void addInPlace(VectorFloat vector, float value) { + VectorSimdOps.addInPlace((MemorySegmentVectorFloat) vector, value); + } + @Override public void subInPlace(VectorFloat v1, VectorFloat v2) { VectorSimdOps.subInPlace((MemorySegmentVectorFloat)v1, (MemorySegmentVectorFloat)v2); } + @Override + public void subInPlace(VectorFloat vector, float value) { + VectorSimdOps.subInPlace((MemorySegmentVectorFloat)vector, value); + } + @Override public VectorFloat sub(VectorFloat a, VectorFloat b) { if (a.length() != b.length()) { @@ -90,6 +100,11 @@ public VectorFloat sub(VectorFloat a, VectorFloat b) { return sub(a, 0, b, 0, a.length()); } + @Override + public VectorFloat sub(VectorFloat a, float value) { + return VectorSimdOps.sub((MemorySegmentVectorFloat) a, 0, value, a.length()); + } + @Override public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length) { return VectorSimdOps.sub((MemorySegmentVectorFloat) a, aOffset, (MemorySegmentVectorFloat) b, bOffset, length); @@ -165,4 +180,49 @@ public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount assert encoded.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. 
Found: " + encoded.offset(); return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude); } + + @Override + public float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + return VectorSimdOps.nvqDotProduct8bit( + (MemorySegmentVectorFloat) vector, (MemorySegmentByteSequence) bytes, + growthRate, midpoint, minValue, maxValue + ); + } + + @Override + public float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue) { + return VectorSimdOps.nvqSquareDistance8bit( + (MemorySegmentVectorFloat) vector, (MemorySegmentByteSequence) bytes, + growthRate, midpoint, minValue, maxValue + ); + } + + @Override + public float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float growthRate, float midpoint, float minValue, float maxValue, VectorFloat centroid) { + return VectorSimdOps.nvqCosine8bit( + (MemorySegmentVectorFloat) vector, (MemorySegmentByteSequence) bytes, + growthRate, midpoint, minValue, maxValue, + (MemorySegmentVectorFloat) centroid + ); + } + + @Override + public void nvqShuffleQueryInPlace8bit(VectorFloat vector) { + VectorSimdOps.nvqShuffleQueryInPlace8bit((MemorySegmentVectorFloat) vector); + } + + @Override + public void nvqQuantize8bit(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, ByteSequence destination) { + VectorSimdOps.nvqQuantize8bit((MemorySegmentVectorFloat) vector, growthRate, midpoint, minValue, maxValue, (MemorySegmentByteSequence) destination); + } + + @Override + public float nvqLoss(VectorFloat vector, float growthRate, float midpoint, float minValue, float maxValue, int nBits) { + return VectorSimdOps.nvqLoss((MemorySegmentVectorFloat) vector, growthRate, midpoint, minValue, maxValue, nBits); + } + + @Override + public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { + return VectorSimdOps.nvqUniformLoss((MemorySegmentVectorFloat) vector, minValue, maxValue, nBits); + } } diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java index 999105c1a..65693a002 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/VectorSimdOps.java @@ -16,11 +16,14 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.util.MathUtil; import io.github.jbellis.jvector.vector.types.VectorFloat; +import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; import jdk.incubator.vector.IntVector; import jdk.incubator.vector.LongVector; import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorOperators; import java.nio.ByteOrder; @@ -447,6 +450,11 @@ static void addInPlace64(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v a.add(b).intoMemorySegment(v1.get(), v1.offset(0), ByteOrder.LITTLE_ENDIAN); } + static void addInPlace64(MemorySegmentVectorFloat v1, float value) { + var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_64, v1.get(), 0, ByteOrder.LITTLE_ENDIAN); + a.add(value).intoMemorySegment(v1.get(), v1.offset(0), ByteOrder.LITTLE_ENDIAN); 
+ } + static void addInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) { if (v1.length() != v2.length()) { throw new IllegalArgumentException("Vectors must have the same length"); @@ -472,6 +480,26 @@ static void addInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) } } + static void addInPlace(MemorySegmentVectorFloat v1, float value) { + if (v1.length() == 2) { + addInPlace64(v1, value); + return; + } + + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length()); + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN); + a.add(value).intoMemorySegment(v1.get(), v1.offset(i), ByteOrder.LITTLE_ENDIAN); + } + + // Process the tail + for (int i = vectorizedLength; i < v1.length(); i++) { + v1.set(i, v1.get(i) + value); + } + } + static VectorFloat sub(MemorySegmentVectorFloat a, int aOffset, MemorySegmentVectorFloat b, int bOffset, int length) { MemorySegmentVectorFloat result = new MemorySegmentVectorFloat(length); int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length); @@ -492,6 +520,25 @@ static VectorFloat sub(MemorySegmentVectorFloat a, int aOffset, MemorySegment return result; } + static VectorFloat sub(MemorySegmentVectorFloat a, int aOffset, float value, int length) { + MemorySegmentVectorFloat result = new MemorySegmentVectorFloat(length); + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length); + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var lhs = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, a.get(), a.offset(aOffset + i), ByteOrder.LITTLE_ENDIAN); + var subResult = lhs.sub(value); + subResult.intoMemorySegment(result.get(), result.offset(i), ByteOrder.LITTLE_ENDIAN); + } + + // Process the tail + for (int i = vectorizedLength; i < length; i++) { + result.set(i, a.get(aOffset + i) - value); + } + + return result; + } + static void subInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) { if (v1.length() != v2.length()) { throw new IllegalArgumentException("Vectors must have the same length"); @@ -512,6 +559,22 @@ static void subInPlace(MemorySegmentVectorFloat v1, MemorySegmentVectorFloat v2) } } + static void subInPlace(MemorySegmentVectorFloat vector, float value) { + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), vector.offset(i), ByteOrder.LITTLE_ENDIAN); + a.sub(value).intoMemorySegment(vector.get(), vector.offset(i), ByteOrder.LITTLE_ENDIAN); + } + + // Process the tail + for (int i = vectorizedLength; i < vector.length(); i++) { + vector.set(i, vector.get(i) - value); + } + + } + public static int hammingDistance(long[] a, long[] b) { var sum = LongVector.zero(LongVector.SPECIES_PREFERRED); int vectorizedLength = LongVector.SPECIES_PREFERRED.loopBound(a.length); @@ -586,4 +649,369 @@ public static void quantizePartials(float delta, MemorySegmentVectorFloat partia } } } + + //--------------------------------------------- + // NVQ quantization instructions start here + //--------------------------------------------- + static final FloatVector const1f = 
FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 1.f); + static final FloatVector const05f = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 0.5f); + + static FloatVector logisticNQT(FloatVector vector, float alpha, float x0) { + FloatVector temp = vector.fma(alpha, -alpha * x0); + VectorMask isPositive = temp.test(VectorOperators.IS_NEGATIVE).not(); + IntVector p = temp.add(1, isPositive) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts(); + FloatVector e = p.convert(VectorOperators.I2F, 0).reinterpretAsFloats(); + IntVector m = temp.sub(e).fma(0.5f, 1).reinterpretAsInts(); + + temp = m.add(p.lanewise(VectorOperators.LSHL, 23)).reinterpretAsFloats(); // temp = m * 2^p + return temp.div(temp.add(1)); + } + + static float logisticNQT(float value, float alpha, float x0) { + float temp = Math.fma(value, alpha, -alpha * x0); + int p = (int) Math.floor(temp + 1); + int m = Float.floatToIntBits(Math.fma(temp - p, 0.5f, 1)); + + temp = Float.intBitsToFloat(m + (p << 23)); // temp = m * 2^p + return temp / (temp + 1); + } + + static FloatVector logitNQT(FloatVector vector, float inverseAlpha, float x0) { + FloatVector z = vector.div(const1f.sub(vector)); + + IntVector temp = z.reinterpretAsInts(); + FloatVector p = temp.and(0x7f800000) + .lanewise(VectorOperators.LSHR, 23).sub(128) + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + FloatVector m = temp.lanewise(VectorOperators.AND, 0x007fffff).add(0x3f800000).reinterpretAsFloats(); + + return m.add(p).fma(inverseAlpha, x0); + } + + static float logitNQT(float value, float inverseAlpha, float x0) { + float z = value / (1 - value); + + int temp = Float.floatToIntBits(z); + int e = temp & 0x7f800000; + float p = (float) ((e >> 23) - 128); + float m = Float.intBitsToFloat((temp & 0x007fffff) + 0x3f800000); + + return Math.fma(m + p, inverseAlpha, x0); + } + + static FloatVector nvqDequantize8bit(ByteVector bytes, float inverseAlpha, float x0, float logisticScale, float logisticBias, int part) { + /* + * We unpack the vector using the FastLanes strategy: + * https://www.vldb.org/pvldb/vol16/p2132-afroozeh.pdf?ref=blog.lancedb.com + * + * We treat the ByteVector bytes as a vector of integers. + * | Int0 | Int1 | ... + * | Byte3 Byte2 Byte1 Byte0 | Byte3 Byte2 Byte1 Byte0 | ... + * + * The argument part indicates which byte we want to extract from each integer. + * With part=0, we extract + * Int0\Byte0, Int1\Byte0, etc. + * With part=1, we shift by 8 bits and then extract + * Int0\Byte1, Int1\Byte1, etc. + * With part=2, we shift by 16 bits and then extract + * Int0\Byte2, Int1\Byte2, etc. + * With part=3, we shift by 24 bits and then extract + * Int0\Byte3, Int1\Byte3, etc. 
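+     *
+     * Concretely, the extraction for a given part is (intLane >>> (8 * part)) & 0xff,
+     * which is exactly the lanewise LSHR/AND pair below; e.g. for intLane = 0x44332211,
+     * part=0 yields 0x11, part=1 yields 0x22, part=2 yields 0x33, and part=3 yields 0x44.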
+ */ + var arr = bytes.reinterpretAsInts() + .lanewise(VectorOperators.LSHR, 8 * part) + .lanewise(VectorOperators.AND, 0xff) + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + + arr = arr.fma(logisticScale, logisticBias); + return logitNQT(arr, inverseAlpha, x0); + } + + static void nvqQuantize8bit(MemorySegmentVectorFloat vector, float alpha, float x0, float minValue, float maxValue, MemorySegmentByteSequence destination) { + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + final var mask = ByteVector.SPECIES_PREFERRED.indexInRange(0, FloatVector.SPECIES_PREFERRED.length()); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var invLogisticScale = 255 / (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias); + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN); + arr = logisticNQT(arr, scaledAlpha, scaledX0); + arr = arr.sub(logisticBias).mul(invLogisticScale); + var bytes = arr.add(const05f) + .convertShape(VectorOperators.F2B, ByteVector.SPECIES_PREFERRED, 0) + .reinterpretAsBytes(); + bytes.intoMemorySegment(destination.get(), i, ByteOrder.LITTLE_ENDIAN, mask); + } + + // Process the tail + for (int d = vectorizedLength; d < vector.length(); d++) { + // Ensure the quantized value is within the 0 to constant range + float value = vector.get(d); + value = logisticNQT(value, scaledAlpha, scaledX0); + value = (value - logisticBias) * invLogisticScale; + int quantizedValue = Math.round(value); + destination.set(d, (byte) quantizedValue); + } + } + + static float nvqLoss(MemorySegmentVectorFloat vector, float alpha, float x0, float minValue, float maxValue, int nBits) { + int constant = (1 << nBits) - 1; + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / constant; + var invLogisticScale = 1 / logisticScale; + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN); + var recArr = logisticNQT(arr, scaledAlpha, scaledX0); + recArr = recArr.sub(logisticBias).mul(invLogisticScale); + recArr = recArr.add(const05f) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts() + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + recArr = recArr.fma(logisticScale, logisticBias); + recArr = logitNQT(recArr, invScaledAlpha, scaledX0); + + var diff = arr.sub(recArr); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value, recValue; + for (int i = vectorizedLength; i < vector.length(); i++) { + value = vector.get(i); + + recValue = logisticNQT(value, scaledAlpha, scaledX0); + recValue = (recValue - logisticBias) * invLogisticScale; + recValue = Math.round(recValue); + recValue = Math.fma(logisticScale, recValue, logisticBias); 
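+            // map the rounded code back through the logit (the inverse of logisticNQT)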
+ recValue = logitNQT(recValue, scaledAlpha, scaledX0); + + squaredSum += MathUtil.square(value - recValue); + } + + return squaredSum; + } + + static float nvqUniformLoss(MemorySegmentVectorFloat vector, float minValue, float maxValue, int nBits) { + float constant = (1 << nBits) - 1; + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i, ByteOrder.LITTLE_ENDIAN); + var recArr = arr.sub(minValue).mul(constant / (maxValue - minValue)); + recArr = recArr.add(const05f) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts() + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + recArr = recArr.fma((maxValue - minValue) / constant, minValue); + + var diff = arr.sub(recArr); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value, recValue; + for (int i = vectorizedLength; i < vector.length(); i++) { + value = vector.get(i); + + recValue = (value - minValue) / (maxValue - minValue); + recValue = Math.round(constant * recValue) / constant; + recValue = recValue / (maxValue - minValue) + minValue; + + squaredSum += MathUtil.square(value - recValue); + } + + return squaredSum; + } + + static float nvqSquareDistance8bit(MemorySegmentVectorFloat vector, MemorySegmentByteSequence quantizedVector, + float alpha, float x0, float minValue, float maxValue) { + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); + int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN); + + for (int j = 0; j < 4; j++) { + var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN); + var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + + var diff = v1.sub(v2); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2, diff; + for (int i = vectorizedLength; i < quantizedVector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = Math.fma(logisticScale, value2, logisticBias); + value2 = logitNQT(value2, scaledAlpha, scaledX0); + diff = vector.get(i) - value2; + squaredSum += MathUtil.square(diff); + } + + return squaredSum; + } + + + static float nvqDotProduct8bit(MemorySegmentVectorFloat vector, MemorySegmentByteSequence quantizedVector, + float alpha, float x0, float minValue, float maxValue) { + FloatVector dotProdVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); 
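+        // A ByteVector has 4x as many lanes as a FloatVector of the same bit width, so each
+        // iteration decodes the packed bytes in four passes (parts 0..3 of nvqDequantize8bit),
+        // one FloatVector-sized block per pass.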
+ int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN); + + for (int j = 0; j < 4; j++) { + var v1 = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN); + var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + dotProdVec = v1.fma(v2, dotProdVec); + } + } + + float dotProd = dotProdVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2; + for (int i = vectorizedLength; i < quantizedVector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = Math.fma(logisticScale, value2, logisticBias); + value2 = logitNQT(value2, scaledAlpha, scaledX0); + dotProd = Math.fma(vector.get(i), value2, dotProd); + } + + return dotProd; + } + + static float[] nvqCosine8bit(MemorySegmentVectorFloat vector, MemorySegmentByteSequence quantizedVector, + float alpha, float x0, float minValue, float maxValue, + MemorySegmentVectorFloat centroid) { + if (vector.length() != centroid.length()) { + throw new IllegalArgumentException("Vectors must have the same length"); + } + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + var vsum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + var vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(vector.length()); + int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromMemorySegment(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i, ByteOrder.LITTLE_ENDIAN); + + for (int j = 0; j < 4; j++) { + var va = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN); + var vb = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + + var vCentroid = FloatVector.fromMemorySegment(FloatVector.SPECIES_PREFERRED, centroid.get(), i + floatStep * j, ByteOrder.LITTLE_ENDIAN); + vb = vb.add(vCentroid); + + vsum = va.fma(vb, vsum); + vbMagnitude = vb.fma(vb, vbMagnitude); + } + } + + float sum = vsum.reduceLanes(VectorOperators.ADD); + float bMagnitude = vbMagnitude.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2; + for (int i = vectorizedLength; i < vector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = Math.fma(logisticScale, value2, logisticBias); + value2 = logitNQT(value2, scaledAlpha, scaledX0) + centroid.get(i); + sum = Math.fma(vector.get(i), value2, sum); + bMagnitude = Math.fma(value2, value2, bMagnitude); + } + + // TODO can we avoid returning a new array? 
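+        // sum is the dot product against the re-centered quantized vector and bMagnitude is that
+        // vector's squared norm; the caller is expected to combine them with the query norm to
+        // finish the cosine computation.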
+ return new float[]{sum, bMagnitude}; + } + + static void transpose(MemorySegmentVectorFloat arr, int first, int last, int nRows) { + final int mn1 = (last - first - 1); + final int n = (last - first) / nRows; + boolean[] visited = new boolean[last - first]; + float temp; + int cycle = first; + while (++cycle != last) { + if (visited[cycle - first]) + continue; + int a = cycle - first; + do { + a = a == mn1 ? mn1 : (n * a) % mn1; + temp = arr.get(first + a); + arr.set(first + a, arr.get(cycle)); + arr.set(cycle, temp); + visited[a] = true; + } while ((first + a) != cycle); + } + } + + static void nvqShuffleQueryInPlace8bit(MemorySegmentVectorFloat vector) { + // To understand this shuffle, see nvqDequantize8bit + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + final int step = FloatVector.SPECIES_PREFERRED.length() * 4; + + for (int i = 0; i + step <= vectorizedLength; i += step) { + transpose(vector, i, i + step, 4); + } + } + + //--------------------------------------------- + // NVQ quantization instructions end here + //--------------------------------------------- } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java index 8b1d1c070..1068c86e1 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/TestUtil.java @@ -28,7 +28,7 @@ import io.github.jbellis.jvector.graph.disk.InlineVectors; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndex; import io.github.jbellis.jvector.graph.disk.OnDiskGraphIndexWriter; -import io.github.jbellis.jvector.pq.PQVectors; +import io.github.jbellis.jvector.quantization.PQVectors; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.vector.VectorUtil; import io.github.jbellis.jvector.vector.VectorizationProvider; @@ -134,6 +134,18 @@ public static List> createRandomVectors(int count, int dimension) return IntStream.range(0, count).mapToObj(i -> TestUtil.randomVector(getRandom(), dimension)).collect(Collectors.toList()); } + public static VectorFloat normalRandomVector(Random random, int dim) { + var vec = vectorTypeSupport.createFloatVector(dim); + for (int i = 0; i < dim; i++) { + vec.set(i, (float) random.nextGaussian()); + } + return vec; + } + + public static List> createNormalRandomVectors(int count, int dimension) { + return IntStream.range(0, count).mapToObj(i -> TestUtil.normalRandomVector(getRandom(), dimension)).collect(Collectors.toList()); + } + public static void writeGraph(GraphIndex graph, RandomAccessVectorValues ravv, Path outputPath) throws IOException { OnDiskGraphIndex.write(graph, ravv, outputPath); } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java index 7822440a6..bceb5247f 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/TestVectorGraph.java @@ -30,8 +30,8 @@ import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; -import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.PQVectors; +import 
io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.util.Bits; import io.github.jbellis.jvector.util.BoundedLongHeap; import io.github.jbellis.jvector.util.FixedBitSet; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java index bc02f5bfb..ac553335a 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/graph/disk/TestOnDiskGraphIndex.java @@ -24,8 +24,8 @@ import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; import io.github.jbellis.jvector.graph.RandomAccessVectorValues; import io.github.jbellis.jvector.graph.TestVectorGraph; -import io.github.jbellis.jvector.pq.PQVectors; -import io.github.jbellis.jvector.pq.ProductQuantization; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; import org.junit.After; import org.junit.Before; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestADCGraphIndex.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java similarity index 99% rename from jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestADCGraphIndex.java rename to jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java index aea0c2761..f301937f2 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestADCGraphIndex.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestADCGraphIndex.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java similarity index 59% rename from jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java rename to jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java index 2bb434df3..52586fa49 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestCompressedVectors.java @@ -14,15 +14,16 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import io.github.jbellis.jvector.TestUtil; import io.github.jbellis.jvector.disk.SimpleMappedReader; import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; -import io.github.jbellis.jvector.graph.similarity.ScoreFunction; import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorUtil; +import io.github.jbellis.jvector.vector.types.VectorFloat; import org.junit.Test; import java.io.DataOutputStream; @@ -31,6 +32,7 @@ import java.util.List; import static io.github.jbellis.jvector.TestUtil.createRandomVectors; +import static io.github.jbellis.jvector.TestUtil.createNormalRandomVectors; import static io.github.jbellis.jvector.TestUtil.nextInt; import static java.lang.Math.abs; import static java.lang.Math.log; @@ -87,7 +89,47 @@ public void testSaveLoadBQ() throws Exception { } } - private void testEncodings(int dimension, int codebooks) { + @Test + public void testSaveLoadNVQ() throws Exception { + + int[][] testsConfigAndResults = { + //Tuples of: nDimensions, nSubvectors, and the expected number of bytes + {64, 1, 96}, + {64, 2, 124}, + {65, 1, 97}, + }; + + for (int[] testConfigAndResult : testsConfigAndResults) { + var nDimensions = testConfigAndResult[0]; + var nSubvectors = testConfigAndResult[1]; + var expectedSize = testConfigAndResult[2]; + + // Generate an NVQ for random vectors + var vectors = createRandomVectors(512, nDimensions); + var ravv = new ListRandomAccessVectorValues(vectors, nDimensions); + + var nvq = NVQuantization.compute(ravv, nSubvectors); + + // Compress the vectors + var cv = nvq.encodeAll(ravv); + assertEquals(nDimensions * Float.BYTES, cv.getOriginalSize()); + assertEquals(expectedSize, cv.getCompressedSize()); + + // Write compressed vectors + File cvFile = File.createTempFile("bqtest", ".cv"); + try (var out = new DataOutputStream(new FileOutputStream(cvFile))) { + cv.write(out); + } + // Read compressed vectors + try (var in = new SimpleMappedReader(cvFile.getAbsolutePath())) { + var cv2 = NVQVectors.load(in, 0); + assertEquals(cv, cv2); + } + } + } + + + private void testPQEncodings(int dimension, int codebooks) { // Generate a PQ for random vectors var vectors = createRandomVectors(512, dimension); var ravv = new ListRandomAccessVectorValues(vectors, dimension); @@ -118,11 +160,70 @@ private void testEncodings(int dimension, int codebooks) { } @Test - public void testEncodings() { + public void testPQEncodings() { // start with i=2 (dimension 4) b/c dimension 2 is an outlier for our error prediction for (int i = 2; i <= 8; i++) { for (int M = 1; M <= i; M++) { - testEncodings(2 * i, M); + testPQEncodings(2 * i, M); + } + } + } + + private void testNVQEncodings(List> vectors, List> queries, int nSubvectors, + boolean learn) { + int dimension = vectors.get(0).length(); + int nQueries = queries.size(); + + // Generate a NVQ for random vectors + var ravv = new ListRandomAccessVectorValues(vectors, dimension); + var nvq = NVQuantization.compute(ravv, nSubvectors); + nvq.learn = learn; + + // Compress the vectors + var cv = nvq.encodeAll(ravv); + + // compare the encoded similarities to the raw + for (var vsf : List.of(VectorSimilarityFunction.EUCLIDEAN, VectorSimilarityFunction.DOT_PRODUCT, VectorSimilarityFunction.COSINE)) { + double error = 0; + for (int i = 0; i < nQueries; 
i++) { + var q = queries.get(i); + VectorUtil.l2normalize(q); + var f = cv.scoreFunctionFor(q, vsf); + for (int j = 0; j < vectors.size(); j++) { + var v = vectors.get(j); + vsf.compare(q, v); + if (vsf == VectorSimilarityFunction.DOT_PRODUCT) { + error += abs(f.similarityTo(j) - vsf.compare(q, v)) / abs(vsf.compare(v, v)); + } else { + error += abs(f.similarityTo(j) - vsf.compare(q, v)); + } + } + } + error /= nQueries * vectors.size(); + + float tolerance = 0.0005f * (dimension / 256.f); + if (vsf == VectorSimilarityFunction.COSINE) { + tolerance *= 10; + } else if (vsf == VectorSimilarityFunction.DOT_PRODUCT) { + tolerance *= 4; + } + System.out.println(vsf + " error " + error + " tolerance " + tolerance); + assert error <= tolerance : String.format("%s > %s for %s with %d dimensions and %d subvectors", error, tolerance, vsf, dimension, nSubvectors); + } + System.out.println("--"); + } + + @Test + public void testNVQEncodings() { + for (int d = 256; d <= 2048; d += 256) { + var vectors = createNormalRandomVectors(512, d); + var queries = createNormalRandomVectors(10, d); + + for (var nSubvectors : List.of(1, 2, 4, 8)) { + for (var learn : List.of(false, true)) { + System.out.println("dimensions: " + d + " subvectors: " + nSubvectors + " learn: " + learn); + testNVQEncodings(vectors, queries, nSubvectors, learn); + } } } } diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java similarity index 97% rename from jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java rename to jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java index 0c9a0dde3..8f7bbb2c8 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestProductQuantization.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/quantization/TestProductQuantization.java @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package io.github.jbellis.jvector.pq; +package io.github.jbellis.jvector.quantization; import com.carrotsearch.randomizedtesting.RandomizedTest; import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; @@ -39,9 +39,9 @@ import static io.github.jbellis.jvector.TestUtil.createRandomVectors; import static io.github.jbellis.jvector.TestUtil.randomVector; -import static io.github.jbellis.jvector.pq.KMeansPlusPlusClusterer.UNWEIGHTED; -import static io.github.jbellis.jvector.pq.ProductQuantization.DEFAULT_CLUSTERS; -import static io.github.jbellis.jvector.pq.ProductQuantization.getSubvectorSizesAndOffsets; +import static io.github.jbellis.jvector.quantization.KMeansPlusPlusClusterer.UNWEIGHTED; +import static io.github.jbellis.jvector.quantization.ProductQuantization.DEFAULT_CLUSTERS; +import static io.github.jbellis.jvector.quantization.ProductQuantization.getSubvectorSizesAndOffsets; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; @@ -286,7 +286,7 @@ public void testPQVectorsChunkCalculation() { assertEquals(1, params[1]); // numChunks assertEquals(1, params[2]); // fullSizeChunks assertEquals(0, params[3]); // remainingVectors - + // Test case requiring multiple chunks int bigVectorCount = Integer.MAX_VALUE - 1; int smallDim = 8; @@ -294,14 +294,14 @@ public void testPQVectorsChunkCalculation() { validateChunkMath(params, bigVectorCount, smallDim); assertTrue(params[0] > 0); assertTrue(params[1] > 1); - + // Test edge case with large dimension int smallVectorCount = 1000; int bigDim = Integer.MAX_VALUE / 2; params = PQVectors.calculateChunkParameters(smallVectorCount, bigDim); validateChunkMath(params, smallVectorCount, bigDim); assertTrue(params[0] > 0); - + // Test invalid inputs assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(-1, 8)); assertThrows(IllegalArgumentException.class, () -> PQVectors.calculateChunkParameters(100, -1)); diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index 2d0d53893..f18266e5f 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -72,11 +72,21 @@ public void addInPlace(VectorFloat v1, VectorFloat v2) { SimdOps.addInPlace((ArrayVectorFloat)v1, (ArrayVectorFloat)v2); } + @Override + public void addInPlace(VectorFloat v1, float value) { + SimdOps.addInPlace((ArrayVectorFloat)v1, value); + } + @Override public void subInPlace(VectorFloat v1, VectorFloat v2) { SimdOps.subInPlace((ArrayVectorFloat) v1, (ArrayVectorFloat) v2); } + @Override + public void subInPlace(VectorFloat vector, float value) { + SimdOps.subInPlace((ArrayVectorFloat) vector, value); + } + @Override public VectorFloat sub(VectorFloat a, VectorFloat b) { if (a.length() != b.length()) { @@ -85,6 +95,11 @@ public VectorFloat sub(VectorFloat a, VectorFloat b) { return SimdOps.sub((ArrayVectorFloat)a, 0, (ArrayVectorFloat)b, 0, a.length()); } + @Override + public VectorFloat sub(VectorFloat a, float value) { + return SimdOps.sub((ArrayVectorFloat)a, 0, value, a.length()); + } + @Override public VectorFloat sub(VectorFloat a, int aOffset, VectorFloat b, int bOffset, int length) { return SimdOps.sub((ArrayVectorFloat) a, aOffset, (ArrayVectorFloat) b, bOffset, length); 
@@ -161,5 +176,48 @@ public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount { return SimdOps.pqDecodedCosineSimilarity((ByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); } + + @Override + public float nvqDotProduct8bit(VectorFloat vector, ByteSequence bytes, float alpha, float x0, float minValue, float maxValue) { + return SimdOps.nvqDotProduct8bit( + (ArrayVectorFloat) vector, (ArrayByteSequence) bytes, + alpha, x0, minValue, maxValue); + } + + @Override + public float nvqSquareL2Distance8bit(VectorFloat vector, ByteSequence bytes, float alpha, float x0, float minValue, float maxValue) { + return SimdOps.nvqSquareDistance8bit( + (ArrayVectorFloat) vector, (ArrayByteSequence) bytes, + alpha, x0, minValue, maxValue); + } + + @Override + public float[] nvqCosine8bit(VectorFloat vector, ByteSequence bytes, float alpha, float x0, float minValue, float maxValue, VectorFloat centroid) { + return SimdOps.nvqCosine8bit( + (ArrayVectorFloat) vector, (ArrayByteSequence) bytes, + alpha, x0, minValue, maxValue, + (ArrayVectorFloat) centroid + ); + } + + @Override + public void nvqShuffleQueryInPlace8bit(VectorFloat vector) { + SimdOps.nvqShuffleQueryInPlace8bit((ArrayVectorFloat) vector); + } + + @Override + public void nvqQuantize8bit(VectorFloat vector, float alpha, float x0, float minValue, float maxValue, ByteSequence destination) { + SimdOps.nvqQuantize8bit((ArrayVectorFloat) vector, alpha, x0, minValue, maxValue,(ArrayByteSequence) destination); + } + + @Override + public float nvqLoss(VectorFloat vector, float alpha, float x0, float minValue, float maxValue, int nBits) { + return SimdOps.nvqLoss((ArrayVectorFloat) vector, alpha, x0, minValue, maxValue, nBits); + } + + @Override + public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { + return SimdOps.nvqUniformLoss((ArrayVectorFloat) vector, minValue, maxValue, nBits); + } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 183b8f763..3c9c1ca7f 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.util.MathUtil; import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; import jdk.incubator.vector.ByteVector; @@ -23,6 +24,7 @@ import jdk.incubator.vector.IntVector; import jdk.incubator.vector.LongVector; import jdk.incubator.vector.ShortVector; +import jdk.incubator.vector.VectorMask; import jdk.incubator.vector.VectorOperators; import java.util.List; @@ -32,11 +34,9 @@ final class SimdOps { static final IntVector BYTE_TO_INT_MASK_512 = IntVector.broadcast(IntVector.SPECIES_512, 0xff); static final IntVector BYTE_TO_INT_MASK_256 = IntVector.broadcast(IntVector.SPECIES_256, 0xff); - static final ThreadLocal scratchInt512 = ThreadLocal.withInitial(() -> new int[IntVector.SPECIES_512.length()]); static final ThreadLocal scratchInt256 = ThreadLocal.withInitial(() -> new int[IntVector.SPECIES_256.length()]); - static float sum(ArrayVectorFloat vector) { var sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); @@ -488,6 +488,11 @@ static void 
addInPlace64(ArrayVectorFloat v1, ArrayVectorFloat v2) { a.add(b).intoArray(v1.get(), 0); } + static void addInPlace64(ArrayVectorFloat v1, float value) { + var a = FloatVector.fromArray(FloatVector.SPECIES_64, v1.get(), 0); + a.add(value).intoArray(v1.get(), 0); + } + static void addInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { if (v1.length() != v2.length()) { throw new IllegalArgumentException("Vectors must have the same length"); @@ -513,6 +518,26 @@ static void addInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { } } + static void addInPlace(ArrayVectorFloat v1, float value) { + if (v1.length() == 2) { + addInPlace64(v1, value); + return; + } + + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(v1.length()); + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, v1.get(), i); + a.add(value).intoArray(v1.get(), i); + } + + // Process the tail + for (int i = vectorizedLength; i < v1.length(); i++) { + v1.set(i, v1.get(i) + value); + } + } + static void subInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { if (v1.length() != v2.length()) { throw new IllegalArgumentException("Vectors must have the same length"); @@ -533,6 +558,21 @@ static void subInPlace(ArrayVectorFloat v1, ArrayVectorFloat v2) { } } + static void subInPlace(ArrayVectorFloat vector, float value) { + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var a = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i); + a.sub(value).intoArray(vector.get(), i); + } + + // Process the tail + for (int i = vectorizedLength; i < vector.length(); i++) { + vector.set(i, vector.get(i) - value); + } + } + static VectorFloat sub(ArrayVectorFloat a, int aOffset, ArrayVectorFloat b, int bOffset, int length) { int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length); float[] res = new float[length]; @@ -553,6 +593,25 @@ static VectorFloat sub(ArrayVectorFloat a, int aOffset, ArrayVectorFloat b, i return new ArrayVectorFloat(res); } + static VectorFloat sub(ArrayVectorFloat a, int aOffset, float value, int length) { + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(length); + float[] res = new float[length]; + + // Process the vectorized part + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var lhs = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, a.get(), aOffset + i); + var subResult = lhs.sub(value); + subResult.intoArray(res, i); + } + + // Process the tail + for (int i = vectorizedLength; i < length; i++) { + res[i] = a.get(aOffset + i) - value; + } + + return new ArrayVectorFloat(res); + } + static float assembleAndSum(float[] data, int dataBase, ByteSequence baseOffsets) { return switch (PREFERRED_BIT_SIZE) { @@ -750,7 +809,7 @@ public static float pqDecodedCosineSimilarity512(ByteSequence baseOffset float sumResult = sum.reduceLanes(VectorOperators.ADD); float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); - for (; i < baseOffsets.length(); i++) { + for (; i < baseOffsets.length(); i++) { int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets.get(i)); sumResult += partialSumsArray[offset]; aMagnitudeResult += aMagnitudeArray[offset]; @@ -811,4 +870,373 @@ public static float pqDecodedCosineSimilarity128(ByteSequence 
baseOffset return (float) (sum / Math.sqrt(aMag * bMagnitude)); } + + //--------------------------------------------- + // NVQ quantization instructions start here + //--------------------------------------------- + + static final FloatVector const1f = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 1.f); + static final FloatVector const05f = FloatVector.broadcast(FloatVector.SPECIES_PREFERRED, 0.5f); + + static FloatVector logisticNQT(FloatVector vector, float alpha, float x0) { + FloatVector temp = vector.fma(alpha, -alpha * x0); + VectorMask isPositive = temp.test(VectorOperators.IS_NEGATIVE).not(); + IntVector p = temp.add(1, isPositive) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts(); + FloatVector e = p.convert(VectorOperators.I2F, 0).reinterpretAsFloats(); + IntVector m = temp.sub(e).fma(0.5f, 1).reinterpretAsInts(); + + temp = m.add(p.lanewise(VectorOperators.LSHL, 23)).reinterpretAsFloats(); // temp = m * 2^p + return temp.div(temp.add(1)); + } + + static float logisticNQT(float value, float alpha, float x0) { + float temp = Math.fma(value, alpha, -alpha * x0); + int p = (int) Math.floor(temp + 1); + int m = Float.floatToIntBits(Math.fma(temp - p, 0.5f, 1)); + + temp = Float.intBitsToFloat(m + (p << 23)); // temp = m * 2^p + return temp / (temp + 1); + } + + static FloatVector logitNQT(FloatVector vector, float inverseAlpha, float x0) { + FloatVector z = vector.div(const1f.sub(vector)); + + IntVector temp = z.reinterpretAsInts(); + FloatVector p = temp.and(0x7f800000) + .lanewise(VectorOperators.LSHR, 23).sub(128) + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + FloatVector m = temp.lanewise(VectorOperators.AND, 0x007fffff).add(0x3f800000).reinterpretAsFloats(); + + return m.add(p).fma(inverseAlpha, x0); + } + + static float logitNQT(float value, float inverseAlpha, float x0) { + float z = value / (1 - value); + + int temp = Float.floatToIntBits(z); + int e = temp & 0x7f800000; + float p = (float) ((e >> 23) - 128); + float m = Float.intBitsToFloat((temp & 0x007fffff) + 0x3f800000); + + return Math.fma(m + p, inverseAlpha, x0); + } + + static FloatVector nvqDequantize8bit(ByteVector bytes, float inverseAlpha, float x0, float logisticScale, float logisticBias, int part) { + /* + * We unpack the vector using the FastLanes strategy: + * https://www.vldb.org/pvldb/vol16/p2132-afroozeh.pdf?ref=blog.lancedb.com + * + * We treat the ByteVector bytes as a vector of integers. + * | Int0 | Int1 | ... + * | Byte3 Byte2 Byte1 Byte0 | Byte3 Byte2 Byte1 Byte0 | ... + * + * The argument part indicates which byte we want to extract from each integer. + * With part=0, we extract + * Int0\Byte0, Int1\Byte0, etc. + * With part=1, we shift by 8 bits and then extract + * Int0\Byte1, Int1\Byte1, etc. + * With part=2, we shift by 16 bits and then extract + * Int0\Byte2, Int1\Byte2, etc. + * With part=3, we shift by 24 bits and then extract + * Int0\Byte3, Int1\Byte3, etc. 
+ */ + var arr = bytes.reinterpretAsInts() + .lanewise(VectorOperators.LSHR, 8 * part) + .lanewise(VectorOperators.AND, 0xff) + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + + arr = arr.fma(logisticScale, logisticBias); + return logitNQT(arr, inverseAlpha, x0); + } + + static void nvqQuantize8bit(ArrayVectorFloat vector, float alpha, float x0, float minValue, float maxValue, ArrayByteSequence destination) { + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + final var mask = ByteVector.SPECIES_PREFERRED.indexInRange(0, FloatVector.SPECIES_PREFERRED.length()); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var invLogisticScale = 255 / (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias); + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i); + arr = logisticNQT(arr, scaledAlpha, scaledX0); + arr = arr.sub(logisticBias).mul(invLogisticScale); + var bytes = arr.add(const05f) + .convertShape(VectorOperators.F2B, ByteVector.SPECIES_PREFERRED, 0) + .reinterpretAsBytes(); + bytes.intoArray(destination.get(), i, mask); + } + + // Process the tail + for (int d = vectorizedLength; d < vector.length(); d++) { + // Ensure the quantized value is within the 0 to constant range + float value = vector.get(d); + value = logisticNQT(value, scaledAlpha, scaledX0); + value = (value - logisticBias) * invLogisticScale; + int quantizedValue = Math.round(value); + destination.set(d, (byte) quantizedValue); + } + } + + static float nvqLoss(ArrayVectorFloat vector, float alpha, float x0, float minValue, float maxValue, int nBits) { + int constant = (1 << nBits) - 1; + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / constant; + var invLogisticScale = 1 / logisticScale; + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i); + var recArr = logisticNQT(arr, scaledAlpha, scaledX0); + recArr = recArr.sub(logisticBias).mul(invLogisticScale); + recArr = recArr.add(const05f) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts() + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + recArr = recArr.fma(logisticScale, logisticBias); + recArr = logitNQT(recArr, invScaledAlpha, scaledX0); + + var diff = arr.sub(recArr); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value, recValue; + for (int i = vectorizedLength; i < vector.length(); i++) { + value = vector.get(i); + + recValue = logisticNQT(value, scaledAlpha, scaledX0); + recValue = (recValue - logisticBias) * invLogisticScale; + recValue = Math.round(recValue); + recValue = Math.fma(logisticScale, recValue, logisticBias); + recValue = logitNQT(recValue, invScaledAlpha, scaledX0); + + squaredSum += MathUtil.square(value - recValue); + } + + 
return squaredSum; + } + + static float nvqUniformLoss(ArrayVectorFloat vector, float minValue, float maxValue, int nBits) { + float constant = (1 << nBits) - 1; + float delta = maxValue - minValue; + + int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int i = 0; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + var arr = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i); + var recArr = arr.sub(minValue).mul(constant / delta); + recArr = recArr.add(const05f) + .convert(VectorOperators.F2I, 0) + .reinterpretAsInts() + .convert(VectorOperators.I2F, 0) + .reinterpretAsFloats(); + recArr = recArr.fma(delta / constant, minValue); + + var diff = arr.sub(recArr); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value, recValue; + for (int i = vectorizedLength; i < vector.length(); i++) { + value = vector.get(i); + + recValue = (value - minValue) / delta; + recValue = Math.round(constant * recValue) / constant; + recValue = recValue * delta + minValue; + + squaredSum += MathUtil.square(value - recValue); + } + + return squaredSum; + } + + static float nvqSquareDistance8bit(ArrayVectorFloat vector, ArrayByteSequence quantizedVector, + float alpha, float x0, float minValue, float maxValue) { + FloatVector squaredSumVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); + int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i); + + for (int j = 0; j < 4; j++) { + var v1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j); + var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + + var diff = v1.sub(v2); + squaredSumVec = diff.fma(diff, squaredSumVec); + } + } + + float squaredSum = squaredSumVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2, diff; + for (int i = vectorizedLength; i < quantizedVector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = logisticScale * value2 + logisticBias; + value2 = logitNQT(value2, invScaledAlpha, scaledX0); + diff = vector.get(i) - value2; + squaredSum += MathUtil.square(diff); + } + + return squaredSum; + } + + static float nvqDotProduct8bit(ArrayVectorFloat vector, ArrayByteSequence quantizedVector, + float alpha, float x0, float minValue, float maxValue) { + FloatVector dotProdVec = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(quantizedVector.length()); + int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = 
(logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i); + + for (int j = 0; j < 4; j++) { + var v1 = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j); + var v2 = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + dotProdVec = v1.fma(v2, dotProdVec); + } + } + + float dotProd = dotProdVec.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2; + for (int i = vectorizedLength; i < quantizedVector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = Math.fma(logisticScale, value2, logisticBias); + value2 = logitNQT(value2, invScaledAlpha, scaledX0); + dotProd = Math.fma(vector.get(i), value2, dotProd); + } + + return dotProd; + } + + static float[] nvqCosine8bit(ArrayVectorFloat vector, + ArrayByteSequence quantizedVector, float alpha, float x0, float minValue, float maxValue, + ArrayVectorFloat centroid) { + if (vector.length() != centroid.length()) { + throw new IllegalArgumentException("Vectors must have the same length"); + } + + var delta = maxValue - minValue; + var scaledAlpha = alpha / delta; + var invScaledAlpha = 1 / scaledAlpha; + var scaledX0 = x0 * delta; + var logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0); + var logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255; + + var vsum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + var vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + int vectorizedLength = ByteVector.SPECIES_PREFERRED.loopBound(vector.length()); + int floatStep = FloatVector.SPECIES_PREFERRED.length(); + + for (int i = 0; i < vectorizedLength; i += ByteVector.SPECIES_PREFERRED.length()) { + var byteArr = ByteVector.fromArray(ByteVector.SPECIES_PREFERRED, quantizedVector.get(), i); + + for (int j = 0; j < 4; j++) { + var va = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, vector.get(), i + floatStep * j); + var vb = nvqDequantize8bit(byteArr, invScaledAlpha, scaledX0, logisticScale, logisticBias, j); + + var vCentroid = FloatVector.fromArray(FloatVector.SPECIES_PREFERRED, centroid.get(), i + floatStep * j); + vb = vb.add(vCentroid); + + vsum = va.fma(vb, vsum); + vbMagnitude = vb.fma(vb, vbMagnitude); + } + } + + float sum = vsum.reduceLanes(VectorOperators.ADD); + float bMagnitude = vbMagnitude.reduceLanes(VectorOperators.ADD); + + // Process the tail + float value2; + for (int i = vectorizedLength; i < vector.length(); i++) { + value2 = Byte.toUnsignedInt(quantizedVector.get(i)); + value2 = Math.fma(logisticScale, value2, logisticBias); + value2 = logitNQT(value2, invScaledAlpha, scaledX0) + centroid.get(i); + sum = Math.fma(vector.get(i), value2, sum); + bMagnitude = Math.fma(value2, value2, bMagnitude); + } + + // TODO can we avoid returning a new array? + return new float[]{sum, bMagnitude}; + } + + static void transpose(float[] arr, int first, int last, int nRows) { + final int mn1 = (last - first - 1); + final int n = (last - first) / nRows; + boolean[] visited = new boolean[last - first]; + float temp; + int cycle = first; + while (++cycle != last) { + if (visited[cycle - first]) + continue; + int a = cycle - first; + do { + a = a == mn1 ? 
mn1 : (n * a) % mn1; + temp = arr[first + a]; + arr[first + a] = arr[cycle]; + arr[cycle] = temp; + visited[a] = true; + } while ((first + a) != cycle); + } + } + + static void nvqShuffleQueryInPlace8bit(ArrayVectorFloat vector) { + // To understand this shuffle, see nvqDequantize8bit + var arr = vector.get(); + + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length()); + final int step = FloatVector.SPECIES_PREFERRED.length() * 4; + + for (int i = 0; i + step <= vectorizedLength; i += step) { + transpose(arr, i, i + step, 4); + } + } + + //--------------------------------------------- + // NVQ instructions end here + //--------------------------------------------- }
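
For reviewers who want to sanity-check the arithmetic outside the SIMD kernels, the sketch below walks through the scalar 8-bit NVQ round trip: a value is squashed through `logisticNQT`, affinely mapped onto 0..255, and reconstructed with the inverse affine map followed by `logitNQT`, mirroring the tail loops in `SimdOps`. This is a standalone illustration, not part of the patch; the class name, the toy vector, and the `alpha`/`x0` values are placeholders, whereas in JVector those parameters come from `NVQuantization`'s per-vector fit.

```java
// Standalone sketch (not part of the patch): scalar 8-bit NVQ quantize/dequantize
// round trip built from the scalar logisticNQT/logitNQT helpers in SimdOps.
public class NvqScalarRoundTrip {

    // Fast logistic approximation: 2^x is formed by splicing bits into the float exponent.
    static float logisticNQT(float value, float alpha, float x0) {
        float temp = Math.fma(value, alpha, -alpha * x0);
        int p = (int) Math.floor(temp + 1);
        int m = Float.floatToIntBits(Math.fma(temp - p, 0.5f, 1));
        temp = Float.intBitsToFloat(m + (p << 23)); // temp = m * 2^p
        return temp / (temp + 1);
    }

    // Inverse of logisticNQT when called with inverseAlpha = 1 / alpha.
    static float logitNQT(float value, float inverseAlpha, float x0) {
        float z = value / (1 - value);
        int bits = Float.floatToIntBits(z);
        float p = ((bits & 0x7f800000) >> 23) - 128;
        float m = Float.intBitsToFloat((bits & 0x007fffff) + 0x3f800000);
        return Math.fma(m + p, inverseAlpha, x0);
    }

    public static void main(String[] args) {
        float[] v = {-0.8f, -0.1f, 0.05f, 0.3f, 0.9f}; // toy (sub)vector
        float alpha = 2.0f, x0 = 0.0f;                  // illustrative nonlinearity parameters
        float minValue = -0.8f, maxValue = 0.9f;

        // Same rescaling as the kernels above.
        float delta = maxValue - minValue;
        float scaledAlpha = alpha / delta;
        float scaledX0 = x0 * delta;
        float logisticBias = logisticNQT(minValue, scaledAlpha, scaledX0);
        float logisticScale = (logisticNQT(maxValue, scaledAlpha, scaledX0) - logisticBias) / 255;

        double squaredError = 0;
        for (float value : v) {
            // Quantize: squash through the logistic, then map onto the 0..255 range.
            int q = Math.round((logisticNQT(value, scaledAlpha, scaledX0) - logisticBias) / logisticScale);
            // Dequantize: undo the affine map, then invert the logistic via logitNQT.
            float rec = logitNQT(Math.fma(logisticScale, q, logisticBias), 1 / scaledAlpha, scaledX0);
            squaredError += (value - rec) * (value - rec);
        }
        System.out.printf("8-bit NVQ round-trip squared error: %.3e%n", squaredError);
    }
}
```

The exponent-splicing in `logisticNQT`/`logitNQT` stands in for `exp`/`log`, which is what keeps dequantization cheap enough to run inside the distance kernels.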