Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-uniform vector quantization #374

Open
wants to merge 174 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
174 commits
Select commit Hold shift + click to select a range
20642a5
Adding initial files for Non-uniform Vector Quantization
marianotepper Oct 24, 2024
4166400
Adding methods to get the number of rows and columns.
marianotepper Oct 28, 2024
28764d0
Adding method to compute the square of a double number
marianotepper Oct 28, 2024
5e9ecbb
Implementation of the Exponential Natural Evolution Strategies optimi…
marianotepper Oct 28, 2024
c83bd2d
Minor code cleanup
marianotepper Oct 28, 2024
eabfad3
Revert "Adding methods to get the number of rows and columns."
marianotepper Oct 28, 2024
1c0608b
Minor code cleanup
marianotepper Oct 28, 2024
aa6dfdd
Making the constructor of NESOptimizer public. Now optimize returns a…
marianotepper Oct 28, 2024
b16b0d6
Adding tests.
marianotepper Oct 28, 2024
751634a
Remove unused import
marianotepper Oct 28, 2024
3a80c1f
Remove unused variable
marianotepper Oct 28, 2024
8ccf14f
Increase default number of samples in NESOptimizer
marianotepper Oct 28, 2024
1141303
Add a few more tests for NESOptimizer
marianotepper Oct 28, 2024
a421f1d
Add a few more tests for NESOptimizer
marianotepper Oct 28, 2024
98c29c5
Add javadocs
marianotepper Oct 28, 2024
5f986b6
Add one test for the case without box constraints
marianotepper Oct 28, 2024
69c745e
Put access modifier first
marianotepper Oct 28, 2024
a1f94bc
Put access modifier first
marianotepper Oct 28, 2024
81445e4
Improve javadocs
marianotepper Oct 28, 2024
cab64ba
Remove blank line
marianotepper Oct 28, 2024
114d60c
Initial test for the non-uniform quantizer
marianotepper Oct 28, 2024
836f289
Add an overload for the subInPlace vector function, that subtracts a …
marianotepper Oct 29, 2024
b3811f5
Completed TestNonUniformQuantization.testGaussian with a passing crit…
marianotepper Oct 29, 2024
c47d2b3
Add vectorized operations pow, constantMinusExponentiatedVector, and …
marianotepper Oct 29, 2024
93ef1d3
Add vectorized operations to TestNonUniformQuantization and increase …
marianotepper Oct 29, 2024
e4e82c5
Add missing blank space
marianotepper Oct 29, 2024
4250711
Use VectorUtil.squareL2Distance for computing the loss
marianotepper Oct 29, 2024
fa295a6
Remove unused import
marianotepper Oct 29, 2024
8be166b
Revert "Remove unused import"
marianotepper Oct 29, 2024
a096142
Remove unused import
marianotepper Oct 29, 2024
09b6a93
Commit with a stub of the functionality in place. Still non-functional
marianotepper Oct 31, 2024
91abcae
Commit with a more complete stub of the functionality in place. Still…
marianotepper Nov 1, 2024
7a59b2e
Commit vectorized nvqDotProduct.
marianotepper Nov 1, 2024
e3b1141
Implemented getSubVectors for NVQ. Still non-functional
tlwillke Nov 4, 2024
88ee4f9
Fixed getSubVectors
tlwillke Nov 4, 2024
da86dc0
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Nov 4, 2024
2b49c2b
Commit vectorized nvqCosine.
marianotepper Nov 4, 2024
c022a5d
Delete wrong normalization
marianotepper Nov 4, 2024
0fa7259
Fix access modifier of the fields in QuantizedSubVector
marianotepper Nov 4, 2024
a902916
Remove unused import
marianotepper Nov 4, 2024
46d8837
Implementation of nvqDotProduct and nvqCosine in NativeVectorUtilSupport
marianotepper Nov 4, 2024
04f828a
Remove unused import
marianotepper Nov 4, 2024
a598ed3
Remove unnecessary semicolons
marianotepper Nov 4, 2024
3d52e74
Replace switches with enhanced switches
marianotepper Nov 4, 2024
2d6f27e
Add stub implementations of nvqShuffleQueryInPlace
marianotepper Nov 4, 2024
898ff3f
Updates to some NVQ comments
tlwillke Nov 4, 2024
599b6f9
Merge branch 'nuveq' of https://github.com/jbellis/jvector into nuveq
tlwillke Nov 4, 2024
6f48de0
Add implementations of nvqShuffleQueryInPlace
marianotepper Nov 5, 2024
f6e1943
First implementation of NVQ Euclidean distance calcs. Includes Defau…
tlwillke Nov 5, 2024
62c441e
Merge branch 'nuveq' of https://github.com/jbellis/jvector into nuveq
tlwillke Nov 5, 2024
fa0cc17
Add query shuffling similarity computation in NVQVectors
marianotepper Nov 5, 2024
7b7fe8d
Replace traditional for loops with enhanced loops
marianotepper Nov 5, 2024
684a47a
Completed first native and panama implementations for NVQ square L2 d…
tlwillke Nov 5, 2024
18436aa
Merge branch 'nuveq' of https://github.com/jbellis/jvector into nuveq
tlwillke Nov 5, 2024
30c2d66
Fix a bunch of minor errors
marianotepper Nov 6, 2024
9f3485d
First test for NVQuantization and NVQVectors
marianotepper Nov 6, 2024
7b43382
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Nov 6, 2024
7ac54fe
Fix bug when dealing with an odd number of dimensions. Expose nvqDequ…
marianotepper Nov 6, 2024
1dbd769
Add TestUtil.normalRandomVector and TestUtil.createNormalRandomVectors
marianotepper Nov 7, 2024
dc0f333
Add VectorUtil.sub to subtract a float from a vector not in-place
marianotepper Nov 7, 2024
3cb1b5d
WIP test for distances in NVQuantization
marianotepper Nov 7, 2024
b0b0a42
Test for distances in NVQuantization completed. Passes for non-vector…
marianotepper Nov 8, 2024
95f793d
Apply shuffling to the centroid as well for cosine similarity
marianotepper Nov 8, 2024
2387cd3
Fix vectorized dequantization for 4-bit NVQ shuffling
marianotepper Nov 8, 2024
03b57e4
Major NVQ refactoring mostly completed.
marianotepper Nov 8, 2024
22a109a
Fix abstraction-breaking functions in VectorUtil. NESOptimizer now us…
marianotepper Nov 13, 2024
6bd8e59
Remove commented code
marianotepper Nov 13, 2024
411e9af
Lower the number of samples in NESOptimizer. A few other code improve…
marianotepper Nov 15, 2024
21076b8
Added NVQ Feature ID. Initial implementation of NVQ for graph indexi…
tlwillke Nov 17, 2024
d5dd527
Merging changes for NVQ graph indexing.
tlwillke Nov 17, 2024
a2ca4ff
Implement vectorized quantization. SIMD vectorized instructions clean…
marianotepper Nov 18, 2024
3e1fcdd
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Nov 18, 2024
57128ba
Fix order of imports to prevent coupling between graph/disk and pq
marianotepper Nov 18, 2024
fc8fe14
Remove commented code
marianotepper Nov 18, 2024
879b027
Remove redundant interface
marianotepper Nov 18, 2024
7a9643a
Remove blank line
marianotepper Nov 18, 2024
590d03d
Put a maximum number of attempts in NES solver, just in case.
marianotepper Nov 18, 2024
07fb340
Reinstate commented code for further investigation
marianotepper Nov 18, 2024
9fe93c4
Updates to get NVQ working with Bench and for reranking. Changes to …
tlwillke Nov 19, 2024
ba8df54
Wire in NVQ feature-driven reranker in ODGI. Fix serialization sizing…
jkni Nov 19, 2024
c50c982
Reduce allocations for NVQ.
marianotepper Nov 22, 2024
e0098fc
Fix bug in computation of the mean
marianotepper Nov 22, 2024
4266b66
Add missing parenthesis in comment
marianotepper Nov 22, 2024
f56c899
Eliminate unneeded memory allocations and improve performance. Add an…
marianotepper Nov 25, 2024
30adcdf
Fix type in comment
marianotepper Nov 25, 2024
6f8addc
Removed temporary dataset truncation for Bench.
tlwillke Nov 25, 2024
e2a9704
Code cleanup. Implement exp/log exponentiationtrick
marianotepper Nov 26, 2024
09e25af
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Nov 26, 2024
85be405
Minor code improvements in similarity computations for the default en…
marianotepper Nov 26, 2024
c545926
Eliminate memory allocations and interactions in similarity computati…
marianotepper Nov 26, 2024
b667a5b
Replace manual implementation of the inverse Kumaraswamy with a call …
marianotepper Nov 26, 2024
64ccdf0
Quick experiments with approximations to the Kumaraswamy inverse CDF
marianotepper Nov 26, 2024
52dde49
Accelerated implementation of natural log and exp that lead to accele…
marianotepper Nov 26, 2024
2cea709
Accelerated computation of the Kumaraswamy CDF
marianotepper Nov 26, 2024
5860bdb
Vectorized acceleration of forwardKumaraswamy and inverseKumaraswamy …
marianotepper Nov 26, 2024
ccb9d92
Minor update in the signature of SIMDOps.fastLog
marianotepper Nov 27, 2024
647dc3f
Small bench for assessing speed of distance calculations with NVQ
marianotepper Nov 27, 2024
db765f2
Minor fixes to bench for assessing speed of distance calculations wit…
marianotepper Nov 27, 2024
cb6c360
Create an additional bench for SIMD investigations of Kumaraswamy
marianotepper Nov 27, 2024
6a3b0a7
Moving kumaraswamyApproximationScalar and kumaraswamyApproximationSIM…
marianotepper Nov 27, 2024
d5c2a82
Make import list in distancesNVQ tighter
marianotepper Nov 28, 2024
7ef54d2
Measure number of scalar instead of vector operations
marianotepper Nov 28, 2024
8e6e47e
Refactor kumaraswamy benchmarks into their own package
marianotepper Nov 28, 2024
244e354
Remove unneeded parameter from forwardKumaraswamy
marianotepper Nov 28, 2024
7541448
Remove dummy variable used for debugging
marianotepper Nov 28, 2024
b8d5f03
Temp fixes for NVQ species and inverseKumaraswamy.
tlwillke Dec 1, 2024
55f1f62
Merge changes to species selection and inverseKumaraswamy.
tlwillke Dec 1, 2024
780c7a7
Set species for ByteVector based on FloatVector.SPECIES_PREFERRED.
tlwillke Dec 1, 2024
ba9d50a
Updates in kumaraswamyApproximationScalar to make the experiment more…
marianotepper Dec 1, 2024
c11ded7
Minor updates in kumaraswamyApproximationSIMD
marianotepper Dec 1, 2024
0fcce64
Minor refactor the version of nvqDequantize8bit that returns a vector
marianotepper Dec 1, 2024
cae0da5
Refactoring that allows the code to run on any platform. Implementati…
marianotepper Dec 1, 2024
bbac862
Refactoring of the NVQ SIMD code to always use SPECIES_PREFERRED. Now…
marianotepper Dec 2, 2024
447ab3d
New code with that does the query shuffling for 8-bit NVQ. Optimizati…
marianotepper Dec 12, 2024
cbda16a
Working implementation with logistic optimization
marianotepper Dec 13, 2024
7d9936b
Revert changes in distancesNVQ
marianotepper Dec 13, 2024
a346f66
Reduced the number of calls to logistic_function for each distance co…
marianotepper Dec 13, 2024
dfbd475
Added NVQ logistic and logit methods. Renamed Replaced parameters a,…
tlwillke Dec 13, 2024
be85a72
Merging changes for NVQ logistic methods.
tlwillke Dec 13, 2024
a05c942
Add FMAs where possible in default backend.
marianotepper Dec 13, 2024
6a0d926
Add FMAs where possible in MathUtil::fastExp and improve its comment.
marianotepper Dec 13, 2024
6907e62
Change function parameter names in PanamaVectorUtilSupport to the new…
marianotepper Dec 13, 2024
2e39599
Add better fastExp in SimdOps
marianotepper Dec 13, 2024
721e44a
Added scaling and bias for logistic functions. Changes impact quanti…
tlwillke Dec 15, 2024
3c95d10
Add dummy accumulator in distancesNVQ
marianotepper Dec 16, 2024
84d233f
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Dec 16, 2024
0d7fe86
Fuse division by 255 and logistic scaling
marianotepper Dec 16, 2024
c53317a
Replace a couple of .mul.add patterns by .fma
marianotepper Dec 16, 2024
58b3030
Fuse division by 255 and logistic scaling
marianotepper Dec 16, 2024
5b92789
Remove extra scaling step not needed anymore with the logistic nonlin…
marianotepper Dec 17, 2024
78fa8ef
Remove extra scaling step not needed anymore with the logistic nonlin…
marianotepper Dec 17, 2024
dd07de3
Remove unused functions from VectorUtil
marianotepper Dec 17, 2024
edbc76d
Remove old test
marianotepper Dec 17, 2024
b34345c
Set learning to true in SiftSmall
marianotepper Dec 17, 2024
184198a
Set number of subsegments to 2 in SiftSmall
marianotepper Dec 17, 2024
b133a69
Fix vectorizedLength in nvqCosine8bit. Fix tail computation in all th…
marianotepper Dec 17, 2024
1f1676a
Set default number fo subvectors to 2 in distancesNVQNVQ
marianotepper Dec 17, 2024
cba3336
Cleanup documentation and comments
marianotepper Dec 17, 2024
8512ce4
Cleanup documentation and comments
marianotepper Dec 17, 2024
d38c591
Remove unnecessary code
marianotepper Dec 17, 2024
16497d4
Remove unnecessary imports
marianotepper Dec 17, 2024
2238c3a
Remove unnecessary variable
marianotepper Dec 17, 2024
966d5f3
Remove unnecessary pow function from VectorUtil
marianotepper Dec 17, 2024
9bda7cd
Fix hyperlinks in documentation
marianotepper Dec 17, 2024
7495341
Remove unused variables and old documentation
marianotepper Dec 17, 2024
099880f
Do the optimization using NonuniformQuantizationLossFunction. Fix com…
marianotepper Dec 17, 2024
c551431
Remove NESSolver calls from NVQuantization
marianotepper Dec 17, 2024
70b88ce
Add versioning code back to NVQuantization
marianotepper Dec 17, 2024
852e770
Remove 4-bit NVQ code
marianotepper Dec 17, 2024
79f8469
Remove Kumaraswamy code from SimdOps
marianotepper Dec 17, 2024
8188d1e
Remove unused static variable from SimdOps
marianotepper Dec 17, 2024
832173a
Add new NVQ code to native backend
marianotepper Dec 18, 2024
afbf274
Remove comments
marianotepper Dec 18, 2024
5b3c16a
Rename directory "pq" as "quantization"
marianotepper Dec 18, 2024
bcc5c0a
Do not expose the number of bits in the API of NVQuantization
marianotepper Dec 18, 2024
9cb9d8e
Remove unused variables and apply fix to hash function in NVQVectors
marianotepper Dec 18, 2024
2287b24
Remove NVQuantization.BitsPerDimension from CompressorParameters
marianotepper Dec 18, 2024
512c899
Adjust expected results in testSaveLoadNVQ to accommodate the new for…
marianotepper Dec 18, 2024
062834b
Remove bench files for experimenting with the Kumaraswamy approximation
marianotepper Dec 18, 2024
5e9694a
Cleaning up distancesNVQ
marianotepper Dec 18, 2024
f814efb
Cleaning up distancesNVQ
marianotepper Dec 18, 2024
3d07f98
Cleaned up a few NVQ comments.
tlwillke Dec 18, 2024
53ecb05
Merge remote-tracking branch 'origin/nuveq' into nuveq
marianotepper Dec 18, 2024
38e54c2
Cleaning up distancesNVQ
marianotepper Dec 18, 2024
db927c4
Remove unused import
marianotepper Dec 18, 2024
64d10e8
Remove unused import
marianotepper Dec 18, 2024
9aa25f1
Restore the FUSED_ADC feature in Bench
marianotepper Dec 18, 2024
ed3d6ec
Merge remote-tracking branch 'origin/main' into nuveq
marianotepper Dec 18, 2024
5599773
Merge changes from main
marianotepper Dec 19, 2024
46bd177
Cosmetic changes to distancesNVQ
marianotepper Dec 19, 2024
d9bc9b0
Conform to the new interface that uses encodeTo
marianotepper Dec 19, 2024
aedda1a
Remove unused code to declutter. The dequantization path wsa useful f…
marianotepper Dec 19, 2024
649f59e
Replaced the logistic/logit apir with their NQT variants
marianotepper Dec 20, 2024
ab128f8
Replaced one occurrence of snake case
marianotepper Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
*/
public enum FeatureId {
INLINE_VECTORS(InlineVectors::load),
FUSED_ADC(FusedADC::load);
FUSED_ADC(FusedADC::load),
NVQ_VECTORS(NVQ::load);

public static final Set<FeatureId> ALL = Collections.unmodifiableSet(EnumSet.allOf(FeatureId.class));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@
import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.GraphIndex;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.pq.FusedADCPQDecoder;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.pq.ProductQuantization;
import io.github.jbellis.jvector.quantization.FusedADCPQDecoder;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.quantization.ProductQuantization;
import io.github.jbellis.jvector.util.ExplicitThreadLocal;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorizationProvider;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.graph.disk;

import io.github.jbellis.jvector.disk.RandomAccessReader;
import io.github.jbellis.jvector.graph.similarity.ScoreFunction;
import io.github.jbellis.jvector.quantization.NVQScorer;
import io.github.jbellis.jvector.quantization.NVQuantization;
import io.github.jbellis.jvector.quantization.NVQuantization.QuantizedVector;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.types.VectorFloat;

import java.io.DataOutput;
import java.io.IOException;
import java.io.UncheckedIOException;

/**
* Implements the storage of NuVeQ vectors in an on-disk graph index. These can be used for reranking.
*/
public class NVQ implements Feature {
private final NVQuantization nvq;
private final NVQScorer scorer;

public NVQ(NVQuantization nvq) {
this.nvq = nvq;
scorer = new NVQScorer(this.nvq);
}

@Override
public FeatureId id() {
return FeatureId.NVQ_VECTORS;
}

@Override
public int headerSize() {
return nvq.compressorSize();
}

@Override
public int inlineSize() { return nvq.compressedVectorSize();}

public int dimension() {
return nvq.globalMean.length();
}

static NVQ load(CommonHeader header, RandomAccessReader reader) {
try {
return new NVQ(NVQuantization.load(reader));
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

@Override
public void writeHeader(DataOutput out) throws IOException {
nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION);
}

@Override
public void writeInline(DataOutput out, Feature.State state_) throws IOException {
var state = (NVQ.State) state_;
state.vector.write(out);
}

public static class State implements Feature.State {
public final QuantizedVector vector;

public State(QuantizedVector vector) {
this.vector = vector;
}
}

ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector,
VectorSimilarityFunction vsf,
FeatureSource source) {
var function = scorer.scoreFunctionFor(queryVector, vsf);

return new ScoreFunction.ExactScoreFunction() {
private final QuantizedVector scratch = NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In most other places like this, we use a reusable thread-local scratch. I'm not sure if it's worth it, so consider this a possible deferred optimization.


@Override
public float similarityTo(int node2) {
try {
var reader = source.inlineReaderForNode(node2, FeatureId.NVQ_VECTORS);
QuantizedVector.loadInto(reader, scratch);
} catch (IOException e) {
throw new RuntimeException(e);
}
return function.similarityTo(scratch);
}
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,16 @@ public void close() throws IOException {

@Override
public ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, VectorSimilarityFunction vsf) {
if (!features.containsKey(FeatureId.INLINE_VECTORS)) {
throw new UnsupportedOperationException("No inline vectors in this graph");

if (features.containsKey(FeatureId.INLINE_VECTORS))
{
return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf);
} else if (features.containsKey(FeatureId.NVQ_VECTORS))
{
return ((NVQ) features.get(FeatureId.NVQ_VECTORS)).rerankerFor(queryVector, vsf, this);
} else {
throw new UnsupportedOperationException("No reranker available for this graph");
}
return RandomAccessVectorValues.super.rerankerFor(queryVector, vsf);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,8 @@ public OnDiskGraphIndexWriter build() throws IOException {
int dimension;
if (features.containsKey(FeatureId.INLINE_VECTORS)) {
dimension = ((InlineVectors) features.get(FeatureId.INLINE_VECTORS)).dimension();
} else if (features.containsKey(FeatureId.NVQ_VECTORS)) {
dimension = ((NVQ) features.get(FeatureId.NVQ_VECTORS)).dimension();
} else {
throw new IllegalArgumentException("Inline vectors must be provided.");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should update this error message to indicate that either inline vectors or NVQ vectors much be provided.

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package io.github.jbellis.jvector.graph.similarity;

import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
import io.github.jbellis.jvector.vector.VectorUtil;
import io.github.jbellis.jvector.vector.VectorizationProvider;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package io.github.jbellis.jvector.graph.similarity;

import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
import io.github.jbellis.jvector.pq.PQVectors;
import io.github.jbellis.jvector.quantization.PQVectors;
import io.github.jbellis.jvector.vector.types.VectorFloat;
import org.agrona.collections.Int2ObjectHashMap;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.github.jbellis.jvector.optimization;

import java.util.Arrays;
import java.util.stream.IntStream;

/**
* Class that models a loss function that maps a multidimensional vector onto a real number.
*/
public abstract class LossFunction {
// The number of dimensions
final private int nDims;

// The box constraints that define the feasible set.
private float[] minBounds;
private float[] maxBounds;

/**
* Constructs a LossFunction acting on vectors of the specified number of dimensions.
* @param nDims the number of dimensions
*/
public LossFunction(int nDims) {
if (nDims <= 0) {
throw new IllegalArgumentException("The standard deviation initSigma must be positive");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad error message copy/paste

}
minBounds = new float[nDims];
maxBounds = new float[nDims];
for (int d = 0; d < nDims; d++) {
minBounds[d] = Float.NEGATIVE_INFINITY;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could use Arrays.fill here

maxBounds[d] = Float.POSITIVE_INFINITY;
}

this.nDims = nDims;
}

/**
* Computes the loss function. It assumes that input is within the feasible set
* @param x the input vector
* @return the loss
*/
public abstract float compute(float[] x);

/**
* Computes the loss function and projects the input in-place onto the feasible set
* @param x the input vector
* @return the loss
*/
public float projectCompute(float[] x) {
project(x);
return compute(x);
}

/**
* Sets the minimum values of the box constraints.
* @param bounds the specified minimum bound
*/
public void setMinBounds(float[] bounds) {
if (nDims != bounds.length) {
throw new IllegalArgumentException("The length of bounds should match the number of dimensions");
}
minBounds = bounds;
}

/**
* Gets the minimum values of the box constraints.
* @return the minimum bound
*/
public float[] getMinBounds() {
return minBounds;
}

/**
* Sets the maximum values of the box constraints.
* @param bounds the specified maximum bound
*/
public void setMaxBounds(float[] bounds) {
if (nDims != bounds.length) {
throw new IllegalArgumentException("The length of bounds should match the number of dimensions");
}
maxBounds = bounds;
}

/**
* Gets the maximum values of the box constraints.
* @return the maximum bound
*/
public float[] getMaxBounds() {
return maxBounds;
}

/**
* Projects the input onto the feasible set. If in-place, the input array is modified;
* otherwise, a copy is created and then projected.
* @param x the input vector
* @param inPlace If true, the input array is modified; otherwise, a copy is created and then projected.
* @return the projected vector
*/
public float[] project(float[] x, boolean inPlace) {
float[] copy;
if (inPlace) {
copy = x;
}
else {
copy = Arrays.copyOf(x, x.length);
}
for (int d = 0; d < nDims; d++) {
copy[d] = Math.min(Math.max(x[d], minBounds[d]), maxBounds[d]);
}
return copy;
}

/**
* Projects the input in-place onto the feasible set.
* @param x the input vector
*/
public void project(float[] x) {
project(x, true);
}

public boolean minimumGoalAchieved(float lossValue) {
return false;
}
}
Loading
Loading