-
Notifications
You must be signed in to change notification settings - Fork 114
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Non-uniform vector quantization #374
base: main
Are you sure you want to change the base?
Changes from all commits
20642a5
4166400
28764d0
5e9ecbb
c83bd2d
eabfad3
1c0608b
aa6dfdd
b16b0d6
751634a
3a80c1f
8ccf14f
1141303
a421f1d
98c29c5
5f986b6
69c745e
a1f94bc
81445e4
cab64ba
114d60c
836f289
b3811f5
c47d2b3
93ef1d3
e4e82c5
4250711
fa295a6
8be166b
a096142
09b6a93
91abcae
7a59b2e
e3b1141
88ee4f9
da86dc0
2b49c2b
c022a5d
0fa7259
a902916
46d8837
04f828a
a598ed3
3d52e74
2d6f27e
898ff3f
599b6f9
6f48de0
f6e1943
62c441e
fa0cc17
7b7fe8d
684a47a
18436aa
30c2d66
9f3485d
7b43382
7ac54fe
1dbd769
dc0f333
3cb1b5d
b0b0a42
95f793d
2387cd3
03b57e4
22a109a
6bd8e59
411e9af
21076b8
d5dd527
a2ca4ff
3e1fcdd
57128ba
fc8fe14
879b027
7a9643a
590d03d
07fb340
9fe93c4
ba8df54
c50c982
e0098fc
4266b66
f56c899
30adcdf
6f8addc
e2a9704
09e25af
85be405
c545926
b667a5b
64ccdf0
52dde49
2cea709
5860bdb
ccb9d92
647dc3f
db765f2
cb6c360
6a3b0a7
d5c2a82
7ef54d2
8e6e47e
244e354
7541448
b8d5f03
55f1f62
780c7a7
ba9d50a
c11ded7
0fcce64
cae0da5
bbac862
447ab3d
cbda16a
7d9936b
a346f66
dfbd475
be85a72
a05c942
6a0d926
6907e62
2e39599
721e44a
3c95d10
84d233f
0d7fe86
c53317a
58b3030
5b92789
78fa8ef
dd07de3
edbc76d
b34345c
184198a
b133a69
1f1676a
cba3336
8512ce4
d38c591
16497d4
2238c3a
966d5f3
9bda7cd
7495341
099880f
c551431
70b88ce
852e770
79f8469
8188d1e
832173a
afbf274
5b3c16a
bcc5c0a
9cb9d8e
2287b24
512c899
062834b
5e9694a
f814efb
3d07f98
53ecb05
38e54c2
db927c4
64d10e8
9aa25f1
ed3d6ec
5599773
46bd177
d9bc9b0
aedda1a
649f59e
ab128f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
/* | ||
* Copyright DataStax, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package io.github.jbellis.jvector.graph.disk; | ||
|
||
import io.github.jbellis.jvector.disk.RandomAccessReader; | ||
import io.github.jbellis.jvector.graph.similarity.ScoreFunction; | ||
import io.github.jbellis.jvector.quantization.NVQScorer; | ||
import io.github.jbellis.jvector.quantization.NVQuantization; | ||
import io.github.jbellis.jvector.quantization.NVQuantization.QuantizedVector; | ||
import io.github.jbellis.jvector.vector.VectorSimilarityFunction; | ||
import io.github.jbellis.jvector.vector.types.VectorFloat; | ||
|
||
import java.io.DataOutput; | ||
import java.io.IOException; | ||
import java.io.UncheckedIOException; | ||
|
||
/** | ||
* Implements the storage of NuVeQ vectors in an on-disk graph index. These can be used for reranking. | ||
*/ | ||
public class NVQ implements Feature { | ||
private final NVQuantization nvq; | ||
private final NVQScorer scorer; | ||
|
||
public NVQ(NVQuantization nvq) { | ||
this.nvq = nvq; | ||
scorer = new NVQScorer(this.nvq); | ||
} | ||
|
||
@Override | ||
public FeatureId id() { | ||
return FeatureId.NVQ_VECTORS; | ||
} | ||
|
||
@Override | ||
public int headerSize() { | ||
return nvq.compressorSize(); | ||
} | ||
|
||
@Override | ||
public int inlineSize() { return nvq.compressedVectorSize();} | ||
|
||
public int dimension() { | ||
return nvq.globalMean.length(); | ||
} | ||
|
||
static NVQ load(CommonHeader header, RandomAccessReader reader) { | ||
try { | ||
return new NVQ(NVQuantization.load(reader)); | ||
} catch (IOException e) { | ||
throw new UncheckedIOException(e); | ||
} | ||
} | ||
|
||
@Override | ||
public void writeHeader(DataOutput out) throws IOException { | ||
nvq.write(out, OnDiskGraphIndex.CURRENT_VERSION); | ||
} | ||
|
||
@Override | ||
public void writeInline(DataOutput out, Feature.State state_) throws IOException { | ||
var state = (NVQ.State) state_; | ||
state.vector.write(out); | ||
} | ||
|
||
public static class State implements Feature.State { | ||
public final QuantizedVector vector; | ||
|
||
public State(QuantizedVector vector) { | ||
this.vector = vector; | ||
} | ||
} | ||
|
||
ScoreFunction.ExactScoreFunction rerankerFor(VectorFloat<?> queryVector, | ||
VectorSimilarityFunction vsf, | ||
FeatureSource source) { | ||
var function = scorer.scoreFunctionFor(queryVector, vsf); | ||
|
||
return new ScoreFunction.ExactScoreFunction() { | ||
private final QuantizedVector scratch = NVQuantization.QuantizedVector.createEmpty(nvq.subvectorSizesAndOffsets, nvq.bitsPerDimension); | ||
|
||
@Override | ||
public float similarityTo(int node2) { | ||
try { | ||
var reader = source.inlineReaderForNode(node2, FeatureId.NVQ_VECTORS); | ||
QuantizedVector.loadInto(reader, scratch); | ||
} catch (IOException e) { | ||
throw new RuntimeException(e); | ||
} | ||
return function.similarityTo(scratch); | ||
} | ||
}; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -324,6 +324,8 @@ public OnDiskGraphIndexWriter build() throws IOException { | |
int dimension; | ||
if (features.containsKey(FeatureId.INLINE_VECTORS)) { | ||
dimension = ((InlineVectors) features.get(FeatureId.INLINE_VECTORS)).dimension(); | ||
} else if (features.containsKey(FeatureId.NVQ_VECTORS)) { | ||
dimension = ((NVQ) features.get(FeatureId.NVQ_VECTORS)).dimension(); | ||
} else { | ||
throw new IllegalArgumentException("Inline vectors must be provided."); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should update this error message to indicate that either inline vectors or NVQ vectors much be provided. |
||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
/* | ||
* Copyright DataStax, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
package io.github.jbellis.jvector.optimization; | ||
|
||
import java.util.Arrays; | ||
import java.util.stream.IntStream; | ||
|
||
/** | ||
* Class that models a loss function that maps a multidimensional vector onto a real number. | ||
*/ | ||
public abstract class LossFunction { | ||
// The number of dimensions | ||
final private int nDims; | ||
|
||
// The box constraints that define the feasible set. | ||
private float[] minBounds; | ||
private float[] maxBounds; | ||
|
||
/** | ||
* Constructs a LossFunction acting on vectors of the specified number of dimensions. | ||
* @param nDims the number of dimensions | ||
*/ | ||
public LossFunction(int nDims) { | ||
if (nDims <= 0) { | ||
throw new IllegalArgumentException("The standard deviation initSigma must be positive"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. bad error message copy/paste |
||
} | ||
minBounds = new float[nDims]; | ||
maxBounds = new float[nDims]; | ||
for (int d = 0; d < nDims; d++) { | ||
minBounds[d] = Float.NEGATIVE_INFINITY; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could use Arrays.fill here |
||
maxBounds[d] = Float.POSITIVE_INFINITY; | ||
} | ||
|
||
this.nDims = nDims; | ||
} | ||
|
||
/** | ||
* Computes the loss function. It assumes that input is within the feasible set | ||
* @param x the input vector | ||
* @return the loss | ||
*/ | ||
public abstract float compute(float[] x); | ||
|
||
/** | ||
* Computes the loss function and projects the input in-place onto the feasible set | ||
* @param x the input vector | ||
* @return the loss | ||
*/ | ||
public float projectCompute(float[] x) { | ||
project(x); | ||
return compute(x); | ||
} | ||
|
||
/** | ||
* Sets the minimum values of the box constraints. | ||
* @param bounds the specified minimum bound | ||
*/ | ||
public void setMinBounds(float[] bounds) { | ||
if (nDims != bounds.length) { | ||
throw new IllegalArgumentException("The length of bounds should match the number of dimensions"); | ||
} | ||
minBounds = bounds; | ||
} | ||
|
||
/** | ||
* Gets the minimum values of the box constraints. | ||
* @return the minimum bound | ||
*/ | ||
public float[] getMinBounds() { | ||
return minBounds; | ||
} | ||
|
||
/** | ||
* Sets the maximum values of the box constraints. | ||
* @param bounds the specified maximum bound | ||
*/ | ||
public void setMaxBounds(float[] bounds) { | ||
if (nDims != bounds.length) { | ||
throw new IllegalArgumentException("The length of bounds should match the number of dimensions"); | ||
} | ||
maxBounds = bounds; | ||
} | ||
|
||
/** | ||
* Gets the maximum values of the box constraints. | ||
* @return the maximum bound | ||
*/ | ||
public float[] getMaxBounds() { | ||
return maxBounds; | ||
} | ||
|
||
/** | ||
* Projects the input onto the feasible set. If in-place, the input array is modified; | ||
* otherwise, a copy is created and then projected. | ||
* @param x the input vector | ||
* @param inPlace If true, the input array is modified; otherwise, a copy is created and then projected. | ||
* @return the projected vector | ||
*/ | ||
public float[] project(float[] x, boolean inPlace) { | ||
float[] copy; | ||
if (inPlace) { | ||
copy = x; | ||
} | ||
else { | ||
copy = Arrays.copyOf(x, x.length); | ||
} | ||
for (int d = 0; d < nDims; d++) { | ||
copy[d] = Math.min(Math.max(x[d], minBounds[d]), maxBounds[d]); | ||
} | ||
return copy; | ||
} | ||
|
||
/** | ||
* Projects the input in-place onto the feasible set. | ||
* @param x the input vector | ||
*/ | ||
public void project(float[] x) { | ||
project(x, true); | ||
} | ||
|
||
public boolean minimumGoalAchieved(float lossValue) { | ||
return false; | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In most other places like this, we use a reusable thread-local scratch. I'm not sure if it's worth it, so consider this a possible deferred optimization.