Skip to content

Commit 0b8f7e8

Browse files
committed
Initial version of SQDataset
1 parent a8221ad commit 0b8f7e8

File tree

5 files changed

+446
-0
lines changed

5 files changed

+446
-0
lines changed

Diff for: examples/cpp/CMakeLists.txt

+6
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,12 @@ create_simple_example(custom_thread_pool test_custom_thread_pool custom_thread_p
4848
configure_file(../../data/test_dataset/data_f32.fvecs . COPYONLY)
4949
configure_file(../../data/test_dataset/queries_f32.fvecs . COPYONLY)
5050
configure_file(../../data/test_dataset/groundtruth_euclidean.ivecs . COPYONLY)
51+
52+
# tmp executable for sqdataset
53+
add_executable(sqdataset sqdataset.cpp)
54+
target_include_directories(sqdataset PRIVATE ${CMAKE_CURRENT_LIST_DIR})
55+
target_link_libraries(sqdataset ${SVS_LIB} svs_compile_options svs_native_options)
56+
5157
# The vamana test executable.
5258
add_executable(vamana vamana.cpp)
5359
target_include_directories(vamana PRIVATE ${CMAKE_CURRENT_LIST_DIR})

Diff for: include/svs/quantization/scalar/impl/scalar_impl.h

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/*
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "svs/quantization/scalar/scalar.h"
20+
21+
#include <cstddef>
22+
23+
namespace svs {
24+
namespace quantization {
25+
namespace scalar {
26+
27+
template <size_t Extent, typename Alloc>
28+
SQDataset<Extent, Alloc>::SQDataset(size_t size, size_t dims)
29+
: data_{size, dims} {}
30+
31+
template <size_t Extent, typename Alloc>
32+
SQDataset<Extent, Alloc>::SQDataset(data_type data, float scale, float bias)
33+
: scale_(scale)
34+
, bias_(bias)
35+
, data_{std::move(data)} {}
36+
37+
template <size_t Extent, typename Alloc> size_t SQDataset<Extent, Alloc>::size() const {
38+
return data_.size();
39+
}
40+
41+
template <size_t Extent, typename Alloc>
42+
size_t SQDataset<Extent, Alloc>::dimensions() const {
43+
return data_.dimensions();
44+
}
45+
46+
template <size_t Extent, typename Alloc>
47+
typename SQDataset<Extent, Alloc>::const_value_type
48+
SQDataset<Extent, Alloc>::get_datum(size_t i) const {
49+
// return data_.get_datum(i);
50+
// decompress data
51+
auto result = std::vector<float>(dimensions());
52+
compressed_value_type compressed = data_.get_datum(i);
53+
for (size_t j = 0; j < dimensions(); ++j) {
54+
auto val = static_cast<float>(compressed[j]);
55+
result[j] = scale_ * val + bias_;
56+
}
57+
58+
return result;
59+
}
60+
61+
template <size_t Extent, typename Alloc>
62+
template <typename QueryType, size_t N>
63+
void SQDataset<Extent, Alloc>::set_datum(size_t i, std::span<QueryType, N> datum) {
64+
auto dims = dimensions();
65+
assert(datum.size() == dims);
66+
67+
// Compression range extrema
68+
static constexpr std::int8_t MIN = std::numeric_limits<std::int8_t>::min();
69+
static constexpr std::int8_t MAX = std::numeric_limits<std::int8_t>::max();
70+
71+
// Uniform scalar quantization function
72+
auto scalar = [&](float v) -> std::int8_t {
73+
return std::clamp<float>(std::round((v - bias_) / scale_), MIN, MAX);
74+
};
75+
76+
// Prepare compressed elements
77+
std::vector<std::int8_t> buffer(dims);
78+
for (size_t j = 0; j < dims; ++j) {
79+
// Apply scalar quantization to element
80+
buffer[j] = scalar(datum[j]);
81+
}
82+
data_.set_datum(i, buffer);
83+
84+
// TODO: Float16 truncation check? (see codec.h, line 114)
85+
}
86+
87+
template <size_t Extent, typename Alloc>
88+
template <data::ImmutableMemoryDataset Dataset>
89+
SQDataset<Extent, Alloc>
90+
SQDataset<Extent, Alloc>::compress(const Dataset& data, const allocator_type& allocator) {
91+
return compress(data, 1, allocator);
92+
}
93+
94+
template <size_t Extent, typename Alloc>
95+
template <data::ImmutableMemoryDataset Dataset>
96+
SQDataset<Extent, Alloc> SQDataset<Extent, Alloc>::compress(
97+
const Dataset& data, size_t num_threads, const allocator_type& allocator
98+
) {
99+
auto pool = threads::DefaultThreadPool{num_threads};
100+
return compress(data, pool, allocator);
101+
}
102+
103+
template <size_t Extent, typename Alloc>
104+
template <data::ImmutableMemoryDataset Dataset, threads::ThreadPool Pool>
105+
SQDataset<Extent, Alloc> SQDataset<Extent, Alloc>::compress(
106+
const Dataset& data, Pool& threadpool, const allocator_type& allocator
107+
) {
108+
if (Extent != Dynamic && data.dimensions() != Extent) {
109+
throw ANNEXCEPTION("Dimension mismatch!");
110+
}
111+
112+
static constexpr size_t batch_size = 512;
113+
114+
// Helper struct to collect values
115+
struct Accumulator {
116+
double min = 0.0;
117+
double max = 0.0;
118+
119+
void accumulate(double val) {
120+
min = std::min(min, val);
121+
max = std::max(max, val);
122+
}
123+
124+
void merge(const Accumulator& other) {
125+
min = std::min(min, other.min);
126+
max = std::max(max, other.max);
127+
}
128+
};
129+
130+
// Thread-local accumulators
131+
std::vector<Accumulator> tls(threadpool.size());
132+
133+
// Compute mean and squared sum
134+
threads::parallel_for(
135+
threadpool,
136+
threads::DynamicPartition(data.size(), batch_size),
137+
[&](const auto& indices, uint64_t tid) {
138+
threads::UnitRange range{indices};
139+
Accumulator local;
140+
141+
for (size_t i = range.start(); i < range.stop(); ++i) {
142+
const auto& datum = data.get_datum(i);
143+
for (size_t d = 0; d < data.dimensions(); ++d) {
144+
local.accumulate(datum[d]);
145+
}
146+
}
147+
148+
tls.at(tid).merge(local);
149+
}
150+
);
151+
152+
// Reduce
153+
Accumulator global;
154+
for (const auto& partial : tls) {
155+
global.merge(partial);
156+
}
157+
158+
// Compress the scaled and biased values
159+
// TODO: Templated compression bits
160+
// static constexpr size_t bits = 8;
161+
162+
// Compression range extrema
163+
static constexpr std::int8_t MIN = std::numeric_limits<std::int8_t>::min();
164+
static constexpr std::int8_t MAX = std::numeric_limits<std::int8_t>::max();
165+
166+
// Compute scale and bias
167+
float scale = (global.max - global.min) / (MAX - MIN);
168+
float bias = global.min - MIN * scale;
169+
170+
// Uniform scalar quantization function
171+
auto scalar = [&](float v) -> std::int8_t {
172+
return std::clamp<float>(std::round((v - bias) / scale), MIN, MAX);
173+
};
174+
175+
data_type compressed{data.size(), data.dimensions(), allocator};
176+
177+
threads::parallel_for(
178+
threadpool,
179+
threads::DynamicPartition(data.size(), batch_size),
180+
[&](const auto& indices, uint64_t /*tid*/) {
181+
threads::UnitRange range{indices};
182+
for (size_t i = range.start(); i < range.stop(); ++i) {
183+
// Load original row
184+
auto original = data.get_datum(i);
185+
186+
// Allocate temporary buffer for transformed data
187+
std::vector<std::int8_t> transformed(original.size());
188+
189+
for (size_t d = 0; d < original.size(); ++d) {
190+
float val = static_cast<float>(original[d]);
191+
transformed[d] = scalar(val);
192+
}
193+
194+
// Store normalized data back (set_datum will do narrowing if needed)
195+
compressed.set_datum(i, transformed);
196+
}
197+
}
198+
);
199+
200+
return SQDataset<Extent, Alloc>{std::move(compressed), scale, bias};
201+
}
202+
203+
template <size_t Extent, typename Alloc>
204+
lib::SaveTable SQDataset<Extent, Alloc>::save(const lib::SaveContext& ctx) const {
205+
return lib::SaveTable(
206+
serialization_schema,
207+
save_version,
208+
{SVS_LIST_SAVE_(data, ctx),
209+
{"scale", lib::save(scale_, ctx)},
210+
{"bias", lib::save(bias_, ctx)}}
211+
);
212+
}
213+
214+
template <size_t Extent, typename Alloc>
215+
SQDataset<Extent, Alloc> SQDataset<Extent, Alloc>::load(
216+
const lib::LoadTable& table, const allocator_type& allocator
217+
) {
218+
return SQDataset<Extent, Alloc>{
219+
SVS_LOAD_MEMBER_AT_(table, data, allocator),
220+
lib::load_at<float>(table, "scale"),
221+
lib::load_at<float>(table, "bias")};
222+
}
223+
224+
} // namespace scalar
225+
} // namespace quantization
226+
} // namespace svs

Diff for: include/svs/quantization/scalar/scalar.h

+100
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
/*
2+
* Copyright 2025 Intel Corporation
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
// svs
20+
#include "svs/core/data/simple.h"
21+
#include "svs/lib/memory.h"
22+
#include "svs/lib/static.h"
23+
#include "svs/lib/version.h"
24+
25+
// stl
26+
#include <memory>
27+
28+
namespace svs {
29+
namespace quantization {
30+
namespace scalar {
31+
32+
inline constexpr std::string_view scalar_quantization_serialization_schema =
33+
"scalar_quantization_dataset";
34+
inline constexpr lib::Version scalar_quantization_save_version = lib::Version(0, 0, 0);
35+
36+
// Scalar Quantization Dataset
37+
// This class provides a globally quantized (scale & bias) dataset.
38+
template <size_t Extent = svs::Dynamic, typename Alloc = lib::Allocator<std::int8_t>>
39+
class SQDataset {
40+
public:
41+
constexpr static size_t extent = Extent;
42+
43+
using allocator_type = Alloc;
44+
// TODO: replace int8 with template
45+
using data_type = data::SimpleData<std::int8_t, Extent, allocator_type>;
46+
47+
// TODO: get_datum will return this type, other classes would return compressed data
48+
// while we return uncompressed data for simplicity. Maybe this needs to change
49+
// using const_value_type = std::span<const std::int8_t, Extent>;
50+
// using value_type = const_value_type;
51+
// TODO: This is potentially a performance bottleneck. Other datasets simply return a
52+
// view, but because we are manipulating the values before return, they must go into a
53+
// vector
54+
using compressed_value_type = std::span<const std::int8_t, Extent>;
55+
using const_value_type = std::vector<float>;
56+
using value_type = const_value_type;
57+
58+
private:
59+
float scale_;
60+
float bias_;
61+
data_type data_;
62+
63+
public:
64+
SQDataset(size_t size, size_t dims);
65+
SQDataset(data_type data, float scale, float bias);
66+
67+
size_t size() const;
68+
size_t dimensions() const;
69+
70+
float get_scale() const { return scale_; }
71+
float get_bias() const { return bias_; }
72+
73+
const_value_type get_datum(size_t i) const;
74+
75+
template <typename QueryType, size_t N>
76+
void set_datum(size_t i, std::span<QueryType, N> datum);
77+
78+
template <data::ImmutableMemoryDataset Dataset>
79+
static SQDataset compress(const Dataset& data, const allocator_type& allocator = {});
80+
81+
template <data::ImmutableMemoryDataset Dataset>
82+
static SQDataset
83+
compress(const Dataset& data, size_t num_threads, const allocator_type& allocator = {});
84+
85+
template <data::ImmutableMemoryDataset Dataset, threads::ThreadPool Pool>
86+
static SQDataset
87+
compress(const Dataset& data, Pool& threadpool, const allocator_type& allocator = {});
88+
89+
static constexpr lib::Version save_version = scalar_quantization_save_version;
90+
static constexpr std::string_view serialization_schema =
91+
scalar_quantization_serialization_schema;
92+
lib::SaveTable save(const lib::SaveContext& ctx) const;
93+
94+
static SQDataset
95+
load(const lib::LoadTable& table, const allocator_type& allocator = {});
96+
};
97+
98+
} // namespace scalar
99+
} // namespace quantization
100+
} // namespace svs

Diff for: tests/CMakeLists.txt

+3
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,9 @@ set(TEST_SOURCES
139139
# Inverted
140140
${TEST_DIR}/svs/index/inverted/clustering.cpp
141141

142+
# Global scalar quantization
143+
${TEST_DIR}/svs/quantization/scalar/scalar.cpp
144+
142145
# # ${TEST_DIR}/svs/index/vamana/dynamic_index.cpp
143146
)
144147

0 commit comments

Comments
 (0)