Fix regression in assembleAndSum PQ decoder performance #379

Merged Dec 31, 2024 (2 commits)
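
For context: the removed HAS_AVX512 flag routed every machine without 512-bit vectors to the 256-bit implementations, even when the JVM's preferred float species is only 128 bits wide; the new code dispatches on the preferred width instead and falls back to plain scalar loops for the 128-bit case. A minimal, standalone sketch (not part of this PR) of how that width can be probed with the Panama Vector API, assuming JDK 17+ with --add-modules jdk.incubator.vector:

    // Standalone probe: FloatVector.SPECIES_PREFERRED is fixed per JVM/CPU,
    // so a switch on its bit size can pick the 512-bit, 256-bit, or scalar
    // (128-bit) code path once, up front.
    import jdk.incubator.vector.FloatVector;

    public class PreferredWidthProbe {
        public static void main(String[] args) {
            int bits = FloatVector.SPECIES_PREFERRED.vectorBitSize();
            String path = switch (bits) {
                case 512 -> "512-bit SIMD path (AVX-512 class hardware)";
                case 256 -> "256-bit SIMD path (e.g. AVX2)";
                case 128 -> "scalar fallback path";
                default  -> "unsupported width";
            };
            System.out.println("Preferred float vector width: " + bits + " bits -> " + path);
        }
    }
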
Changes from all commits
@@ -30,8 +30,6 @@
  * Support class for vector operations using a mix of native and Panama SIMD.
  */
 final class VectorSimdOps {
-    static final boolean HAS_AVX512 = IntVector.SPECIES_PREFERRED == IntVector.SPECIES_512;
-
     static float sum(MemorySegmentVectorFloat vector) {
         var sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
         int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());
@@ -28,14 +28,15 @@
 import java.util.List;
 
 final class SimdOps {
-    static final boolean HAS_AVX512 = IntVector.SPECIES_PREFERRED == IntVector.SPECIES_512;
+
+    static final int PREFERRED_BIT_SIZE = FloatVector.SPECIES_PREFERRED.vectorBitSize();
     static final IntVector BYTE_TO_INT_MASK_512 = IntVector.broadcast(IntVector.SPECIES_512, 0xff);
     static final IntVector BYTE_TO_INT_MASK_256 = IntVector.broadcast(IntVector.SPECIES_256, 0xff);
 
 
     static final ThreadLocal<int[]> scratchInt512 = ThreadLocal.withInitial(() -> new int[IntVector.SPECIES_512.length()]);
     static final ThreadLocal<int[]> scratchInt256 = ThreadLocal.withInitial(() -> new int[IntVector.SPECIES_256.length()]);
 
     static float sum(ArrayVectorFloat vector) {
         var sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
         int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(vector.length());
@@ -517,8 +518,13 @@ static VectorFloat<?> sub(ArrayVectorFloat a, int aOffset, ArrayVectorFloat b, i
     }
 
     static float assembleAndSum(float[] data, int dataBase, ByteSequence<byte[]> baseOffsets) {
-        return HAS_AVX512 ? assembleAndSum512(data, dataBase, baseOffsets)
-                : assembleAndSum256(data, dataBase, baseOffsets);
+        return switch (PREFERRED_BIT_SIZE)
+        {
+            case 512 -> assembleAndSum512(data, dataBase, baseOffsets);
+            case 256 -> assembleAndSum256(data, dataBase, baseOffsets);
+            case 128 -> assembleAndSum128(data, dataBase, baseOffsets);
+            default -> throw new IllegalStateException("Unsupported vector width: " + PREFERRED_BIT_SIZE);
+        };
     }
 
     static float assembleAndSum512(float[] data, int dataBase, ByteSequence<byte[]> baseOffsets) {
@@ -578,6 +584,15 @@ static float assembleAndSum256(float[] data, int dataBase, ByteSequence<byte[]>
         return res;
     }
 
+    static float assembleAndSum128(float[] data, int dataBase, ByteSequence<byte[]> baseOffsets) {
+        // benchmarking a 128-bit SIMD implementation showed it performed worse than scalar
+        float sum = 0f;
+        for (int i = 0; i < baseOffsets.length(); i++) {
+            sum += data[dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i))];
+        }
+        return sum;
+    }
+
     /**
      * Vectorized calculation of Hamming distance for two arrays of long integers.
      * Both arrays should have the same length.
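
To make the indexing in the new scalar assembleAndSum128 fallback concrete, here is a tiny, self-contained example (not from the PR; the layout assumption, that dataBase is the per-subspace stride of the precomputed partial-distance table, is inferred from the scalar loop above):

    // Illustrative only: 2 subspaces, 4 table entries per subspace (stride = 4).
    // Each encoded byte selects one precomputed partial distance per subspace.
    public class AssembleAndSumExample {
        public static void main(String[] args) {
            int stride = 4;
            float[] data = {
                0.1f, 0.2f, 0.3f, 0.4f,   // subspace 0
                1.0f, 2.0f, 3.0f, 4.0f    // subspace 1
            };
            byte[] codes = {2, 1};        // centroid index chosen in each subspace

            float sum = 0f;
            for (int i = 0; i < codes.length; i++) {
                sum += data[stride * i + Byte.toUnsignedInt(codes[i])];
            }
            System.out.println(sum);      // 0.3 + 2.0 = 2.3
        }
    }
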
@@ -662,9 +677,12 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
     }
 
     public static float pqDecodedCosineSimilarity(ByteSequence<byte[]> encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
-        return HAS_AVX512
-                ? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
-                : pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
+        return switch (PREFERRED_BIT_SIZE) {
+            case 512 -> pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
+            case 256 -> pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
+            case 128 -> pqDecodedCosineSimilarity128(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
+            default -> throw new IllegalStateException("Unsupported vector width: " + PREFERRED_BIT_SIZE);
+        };
     }
 
     public static float pqDecodedCosineSimilarity512(ByteSequence<byte[]> baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
@@ -742,4 +760,19 @@ public static float pqDecodedCosineSimilarity256(ByteSequence<byte[]> baseOffset
 
         return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
     }
+
+    public static float pqDecodedCosineSimilarity128(ByteSequence<byte[]> baseOffsets, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
+        // benchmarking showed that a 128-bit SIMD implementation performed worse than scalar
+        float sum = 0.0f;
+        float aMag = 0.0f;
+
+        for (int m = 0; m < baseOffsets.length(); ++m) {
+            int centroidIndex = Byte.toUnsignedInt(baseOffsets.get(m));
+            var index = m * clusterCount + centroidIndex;
+            sum += partialSums.get(index);
+            aMag += aMagnitude.get(index);
+        }
+
+        return (float) (sum / Math.sqrt(aMag * bMagnitude));
+    }
 }
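
Both 128-bit cases above justify their scalar loops with benchmark results. A minimal JMH-style harness of the kind that could produce such a comparison might look like the sketch below; class, field, and parameter names are hypothetical and not taken from this repository's benchmark module:

    // Hypothetical JMH sketch: measures the scalar assemble-and-sum loop so it can
    // be compared against any alternative implementation. Sizes are illustrative.
    import java.util.Random;
    import java.util.concurrent.TimeUnit;
    import org.openjdk.jmh.annotations.*;

    @BenchmarkMode(Mode.AverageTime)
    @OutputTimeUnit(TimeUnit.NANOSECONDS)
    @State(Scope.Benchmark)
    public class AssembleAndSumBench {
        @Param({"16", "64"})
        int subspaces;                    // number of encoded bytes per vector

        final int clusters = 256;         // stride: table entries per subspace

        float[] data;
        byte[] codes;

        @Setup
        public void setup() {
            Random r = new Random(42);
            data = new float[subspaces * clusters];
            for (int i = 0; i < data.length; i++) data[i] = r.nextFloat();
            codes = new byte[subspaces];
            r.nextBytes(codes);
        }

        @Benchmark
        public float scalarAssembleAndSum() {
            float sum = 0f;
            for (int i = 0; i < codes.length; i++) {
                sum += data[clusters * i + Byte.toUnsignedInt(codes[i])];
            }
            return sum;
        }
    }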