-
Notifications
You must be signed in to change notification settings - Fork 85
/
Copy pathkgraph.h
306 lines (283 loc) · 11 KB
/
kgraph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
// Copyright (C) 2013-2015 Wei Dong <[email protected]>. All Rights Reserved.
//
// \mainpage KGraph: A Library for Efficient K-NN Search
// \author Wei Dong \f$ [email protected] \f$
// \author 2013-2015
//
#ifndef WDONG_KGRAPH
#define WDONG_KGRAPH
#include <stdexcept>
#include <utility>
namespace kgraph {
// Default values used by the KGraph::IndexParams and KGraph::SearchParams
// constructors. Namespace-scope `static` gives each one internal linkage.
static constexpr unsigned default_iterations = 30;
static constexpr unsigned default_L = 100;
static constexpr unsigned default_K = 25;
static constexpr unsigned default_P = 100;
static constexpr unsigned default_M = 0;
static constexpr unsigned default_T = 1;
static constexpr unsigned default_S = 10;
static constexpr unsigned default_R = 100;
static constexpr unsigned default_controls = 100;
static constexpr unsigned default_seed = 1998;
static constexpr float default_delta = 0.002;
static constexpr float default_recall = 0.99;
static constexpr float default_epsilon = 1e30;
static constexpr unsigned default_verbosity = 1;

/// Pruning levels accepted by KGraph::prune (level 0 means no pruning).
enum {
    PRUNE_LEVEL_1 = 1,  ///< Only reduces index size; fast.
    PRUNE_LEVEL_2 = 2   ///< Aims to improve online search speed; slow.
};

/// Named values for IndexParams::reverse (the default is REVERSE_NONE).
enum {
    REVERSE_AUTO = -1,
    REVERSE_NONE = 0
};

static constexpr unsigned default_prune = 0;
static constexpr int default_reverse = REVERSE_NONE;

/// Verbosity control
/** Set verbosity = 0 to disable information output to stderr.
 *  Only declared here (extern); the definition lives in another translation unit.
 */
extern unsigned verbosity;
/// Index oracle
/** The index oracle is the user-supplied plugin that computes
 * the distance between two arbitrary objects in the dataset.
 * It is used for offline k-NN graph construction.
 */
class IndexOracle {
public:
    /// Returns the size of the dataset.
    virtual unsigned size () const = 0;
    /// Computes distance.
    /**
     * 0 <= i, j < size() are the indices of two objects in the dataset.
     * This method returns the distance between objects i and j.
     */
    virtual float operator () (unsigned i, unsigned j) const = 0;
    /// Virtual destructor, so an oracle may be deleted through an IndexOracle pointer.
    // Declared after the existing virtual functions so their relative vtable
    // order is unchanged on ABIs that assign slots in declaration order.
    virtual ~IndexOracle () {
    }
};
/// Search oracle
/** The search oracle is the user-supplied plugin that computes
 * the distance between the query and an arbitrary object in the dataset.
 * It is used for online k-NN search.
 */
class SearchOracle {
public:
    /// Returns the size of the dataset.
    virtual unsigned size () const = 0;
    /// Computes distance.
    /**
     * 0 <= i < size() is the index of an object in the dataset.
     * This method returns the distance between the query and object i.
     */
    virtual float operator () (unsigned i) const = 0;
    /// Search by brute force.
    /**
     * Search results are guaranteed to be ranked in ascending order of distance.
     *
     * @param K Return at most K nearest neighbors.
     * @param epsilon Only returns nearest neighbors within distance epsilon.
     * @param ids Pointer to the memory where neighbor IDs are returned (room for K values).
     * @param dists Pointer to the memory where distance values are returned, can be nullptr.
     * @return NOTE(review): presumably the number of neighbors actually written (<= K);
     *         confirm against the implementation, which is not in this header.
     */
    unsigned search (unsigned K, float epsilon, unsigned *ids, float *dists = nullptr) const;
    /// Virtual destructor, so an oracle may be deleted through a SearchOracle pointer.
    // Declared after the existing virtual functions so their relative vtable
    // order is unchanged on ABIs that assign slots in declaration order.
    virtual ~SearchOracle () {
    }
};
/// The KGraph index.
/** This is an abstract base class. Use KGraph::create to create an instance.
*/
class KGraph {
public:
/// Indexing parameters.
struct IndexParams {
unsigned iterations;
unsigned L;
unsigned K;
unsigned S;
unsigned R;
unsigned controls;
unsigned seed;
float delta;
float recall;
unsigned prune;
int reverse;
/// Construct with default values.
IndexParams (): iterations(default_iterations), L(default_L), K(default_K), S(default_S), R(default_R), controls(default_controls), seed(default_seed), delta(default_delta), recall(default_recall), prune(default_prune), reverse(default_reverse) {
}
};
/// Search parameters.
struct SearchParams {
unsigned K;
unsigned M;
unsigned P;
unsigned S;
unsigned T;
float epsilon;
unsigned seed;
unsigned init;
/// Construct with default values.
SearchParams (): K(default_K), M(default_M), P(default_P), S(default_S), T(default_T), epsilon(default_epsilon), seed(1998), init(0) {
}
};
enum {
FORMAT_DEFAULT = 0,
FORMAT_NO_DIST = 1,
FORMAT_TEXT = 128
};
/// Information and statistics of the indexing algorithm.
struct IndexInfo {
enum StopCondition {
ITERATION = 0,
DELTA,
RECALL
} stop_condition;
unsigned iterations;
float cost;
float recall;
float accuracy;
float delta;
float M;
};
/// Information and statistics of the search algorithm.
struct SearchInfo {
float cost;
unsigned updates;
};
virtual ~KGraph () {
}
/// Load index from file.
/**
* @param path Path to the index file.
*/
virtual void load (char const *path) = 0;
/// Save index to file.
/**
* @param path Path to the index file.
*/
virtual void save (char const *path, int format = FORMAT_DEFAULT) const = 0; // save to file
/// Build the index
virtual void build (IndexOracle const &oracle, IndexParams const ¶ms, IndexInfo *info = 0) = 0;
/// Prune the index
/**
* Pruning makes the index smaller to save memory, and makes online search on the pruned index faster.
* (The cost parameters of online search must be enlarged so accuracy is not reduced.)
*
* Currently only two pruning levels are supported:
* - PRUNE_LEVEL_1 = 1: Only reduces index size, fast.
* - PRUNE_LEVEL_2 = 2: For improve online search speed, slow.
*
* No pruning is done if level = 0.
*/
virtual void prune (IndexOracle const &oracle, unsigned level) = 0;
/// Online k-NN search.
/**
* Search results are guaranteed to be ranked in ascending order of distance.
*
* @param ids Pointer to the memory where neighbor IDs are stored, must have space to save params.K ids.
*/
unsigned search (SearchOracle const &oracle, SearchParams const ¶ms, unsigned *ids, SearchInfo *info = 0) const {
return search(oracle, params, ids, nullptr, info);
}
/// Online k-NN search.
/**
* Search results are guaranteed to be ranked in ascending order of distance.
*
* @param ids Pointer to the memory where neighbor IDs are stored, must have space to save params.K values.
* @param dists Pointer to the memory where distances are stored, must have space to save params.K values.
*/
virtual unsigned search (SearchOracle const &oracle, SearchParams const ¶ms, unsigned *ids, float *dists, SearchInfo *info) const = 0;
/// Constructor.
static KGraph *create ();
/// Returns version string.
static char const* version ();
/// Get offline computed k-NNs of a given object.
/**
* See the full version of get_nn.
*/
virtual void get_nn (unsigned id, unsigned *nns, unsigned *M, unsigned *L) const {
get_nn(id, nns, nullptr, M, L);
}
/// Get offline computed k-NNs of a given object.
/**
* The user must provide space to save IndexParams::L values.
* The actually returned L could be smaller than IndexParams::L, and
* M <= L is the number of neighbors KGraph thinks
* could be most useful for online search, and is usually < L.
* If the index has been pruned, the returned L could be smaller than
* IndexParams::L used to construct the index.
*
* @params id Object ID whose neighbor information are returned.
* @params nns Neighbor IDs, must have space to save IndexParams::L values.
* @params dists Distance values, must have space to save IndexParams::L values.
* @params M Useful number of neighbors, output only.
* @params L Actually returned number of neighbors, output only.
*/
virtual void get_nn (unsigned id, unsigned *nns, float *dists, unsigned *M, unsigned *L) const = 0;
virtual void reverse (int) = 0;
};
}
#if __cplusplus > 199711L
#include <functional>
namespace kgraph {
/// Oracle adapter for datasets stored in a vector-like container.
/**
 * If the dataset is stored in a container of CONTAINER_TYPE that supports
 * - a size() method that returns the number of objects, and
 * - a [] operator that returns a const reference to an object,
 * then this class can be used as a wrapper to facilitate creating
 * the index and search oracles.
 *
 * The user must provide a callback function that takes in two
 * const references to objects and returns a distance value.
 *
 * Lifetime: this oracle, and every search oracle returned by query(), keeps a
 * reference to the container, so the container must outlive them. The query
 * object itself is copied into the search oracle.
 */
template <typename CONTAINER_TYPE, typename OBJECT_TYPE>
class VectorOracle: public IndexOracle {
public:
    /// Distance callback: dist(a, b) returns the distance between objects a and b.
    typedef std::function<float(OBJECT_TYPE const &, OBJECT_TYPE const &)> METRIC_TYPE;
private:
    CONTAINER_TYPE const &data;  // not owned; must outlive this oracle
    METRIC_TYPE dist;
public:
    /// Search oracle bound to a single (copied) query object.
    class VectorSearchOracle: public SearchOracle {
        CONTAINER_TYPE const &data;  // not owned; must outlive this oracle
        OBJECT_TYPE const query;     // private copy of the query object
        METRIC_TYPE dist;
    public:
        // m is taken by value, so move it into the member instead of copying
        // the std::function (and any state it holds) a second time.
        VectorSearchOracle (CONTAINER_TYPE const &p, OBJECT_TYPE const &q, METRIC_TYPE m): data(p), query(q), dist(std::move(m)) {
        }
        /// Number of objects in the container (converted to unsigned).
        virtual unsigned size () const {
            return data.size();
        }
        /// Distance between object i and the stored query.
        virtual float operator () (unsigned i) const {
            return dist(data[i], query);
        }
    };
    /// Constructor.
    /**
     * @param d The container that holds the dataset (referenced, not copied).
     * @param m A callback function for distance computation. m(d[i], d[j]) must be
     *          a valid expression to compute distance.
     */
    VectorOracle (CONTAINER_TYPE const &d, METRIC_TYPE m): data(d), dist(std::move(m)) {
    }
    /// Number of objects in the container (converted to unsigned).
    virtual unsigned size () const {
        return data.size();
    }
    /// Distance between objects i and j of the container.
    virtual float operator () (unsigned i, unsigned j) const {
        return dist(data[i], data[j]);
    }
    /// Constructs a search oracle for query object q.
    /** The returned oracle copies q and the metric, and references this oracle's container. */
    VectorSearchOracle query (OBJECT_TYPE const &q) const {
        return VectorSearchOracle(data, q, dist);
    }
};
/// Invalid-argument error raised by kgraph.
/** Derives from std::invalid_argument and inherits its constructors, so callers
 *  may catch either this type or the standard base.
 */
class invalid_argument: public std::invalid_argument {
public:
    using std::invalid_argument::invalid_argument;  // inherit the (message) constructors
};

/// Runtime error raised by kgraph; a std::runtime_error with inherited constructors.
class runtime_error: public std::runtime_error {
public:
    using std::runtime_error::runtime_error;  // inherit the (message) constructors
};

/// I/O failure. Derives from the runtime_error declared just above
/// (and therefore also from std::runtime_error).
class io_error: public runtime_error {
public:
    using runtime_error::runtime_error;  // reuse the same constructors
};
}
#endif
#endif