Skip to content

Commit b2f88f0

Browse files
authored
feat: support specifying the ef search param (lancedb#1844)
Signed-off-by: BubbleCal <[email protected]>
1 parent f2e3989 commit b2f88f0

File tree

10 files changed

+165
-0
lines changed

10 files changed

+165
-0
lines changed

nodejs/__test__/table.test.ts

+48
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,54 @@ describe("When creating an index", () => {
477477
expect(rst.numRows).toBe(1);
478478
});
479479

480+
it("should create and search IVF_HNSW indices", async () => {
481+
await tbl.createIndex("vec", {
482+
config: Index.hnswSq(),
483+
});
484+
485+
// check index directory
486+
const indexDir = path.join(tmpDir.name, "test.lance", "_indices");
487+
expect(fs.readdirSync(indexDir)).toHaveLength(1);
488+
const indices = await tbl.listIndices();
489+
expect(indices.length).toBe(1);
490+
expect(indices[0]).toEqual({
491+
name: "vec_idx",
492+
indexType: "IvfHnswSq",
493+
columns: ["vec"],
494+
});
495+
496+
// Search without specifying the column
497+
let rst = await tbl
498+
.query()
499+
.limit(2)
500+
.nearestTo(queryVec)
501+
.distanceType("dot")
502+
.toArrow();
503+
expect(rst.numRows).toBe(2);
504+
505+
// Search using `vectorSearch`
506+
rst = await tbl.vectorSearch(queryVec).limit(2).toArrow();
507+
expect(rst.numRows).toBe(2);
508+
509+
// Search with specifying the column
510+
const rst2 = await tbl
511+
.query()
512+
.limit(2)
513+
.nearestTo(queryVec)
514+
.column("vec")
515+
.toArrow();
516+
expect(rst2.numRows).toBe(2);
517+
expect(rst.toString()).toEqual(rst2.toString());
518+
519+
// test offset
520+
rst = await tbl.query().limit(2).offset(1).nearestTo(queryVec).toArrow();
521+
expect(rst.numRows).toBe(1);
522+
523+
// test ef
524+
rst = await tbl.query().limit(2).nearestTo(queryVec).ef(100).toArrow();
525+
expect(rst.numRows).toBe(2);
526+
});
527+
480528
it("should be able to query unindexed data", async () => {
481529
await tbl.createIndex("vec");
482530
await tbl.add([

nodejs/lancedb/query.ts

+14
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,20 @@ export class VectorQuery extends QueryBase<NativeVectorQuery> {
385385
return this;
386386
}
387387

388+
/**
389+
* Set the number of candidates to consider during the search
390+
*
391+
* This argument is only used when the vector column has an HNSW index.
392+
* If there is no index then this value is ignored.
393+
*
394+
* Increasing this value will increase the recall of your query but will
395+
* also increase the latency of your query. The default value is 1.5*limit.
396+
*/
397+
ef(ef: number): VectorQuery {
398+
super.doCall((inner) => inner.ef(ef));
399+
return this;
400+
}
401+
388402
/**
389403
* Set the vector column to query
390404
*

nodejs/src/query.rs

+5
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ impl VectorQuery {
167167
self.inner = self.inner.clone().nprobes(nprobe as usize);
168168
}
169169

170+
#[napi]
171+
pub fn ef(&mut self, ef: u32) {
172+
self.inner = self.inner.clone().ef(ef as usize);
173+
}
174+
170175
#[napi]
171176
pub fn bypass_vector_index(&mut self) {
172177
self.inner = self.inner.clone().bypass_vector_index()

python/python/lancedb/query.py

+66
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ class Query(pydantic.BaseModel):
131131

132132
fast_search: bool = False
133133

134+
ef: Optional[int] = None
135+
134136

135137
class LanceQueryBuilder(ABC):
136138
"""An abstract query builder. Subclasses are defined for vector search,
@@ -257,6 +259,7 @@ def __init__(self, table: "Table"):
257259
self._with_row_id = False
258260
self._vector = None
259261
self._text = None
262+
self._ef = None
260263

261264
@deprecation.deprecated(
262265
deprecated_in="0.3.1",
@@ -638,6 +641,28 @@ def nprobes(self, nprobes: int) -> LanceVectorQueryBuilder:
638641
self._nprobes = nprobes
639642
return self
640643

644+
def ef(self, ef: int) -> LanceVectorQueryBuilder:
645+
"""Set the number of candidates to consider during search.
646+
647+
Higher values will yield better recall (more likely to find vectors if
648+
they exist) at the expense of latency.
649+
650+
This only applies to HNSW-based indices.
651+
The default value is 1.5 * limit.
652+
653+
Parameters
654+
----------
655+
ef: int
656+
The number of candidates to consider during search.
657+
658+
Returns
659+
-------
660+
LanceVectorQueryBuilder
661+
The LanceQueryBuilder object.
662+
"""
663+
self._ef = ef
664+
return self
665+
641666
def refine_factor(self, refine_factor: int) -> LanceVectorQueryBuilder:
642667
"""Set the refine factor to use, increasing the number of vectors sampled.
643668
@@ -700,6 +725,7 @@ def to_batches(self, /, batch_size: Optional[int] = None) -> pa.RecordBatchReade
700725
with_row_id=self._with_row_id,
701726
offset=self._offset,
702727
fast_search=self._fast_search,
728+
ef=self._ef,
703729
)
704730
result_set = self._table._execute_query(query, batch_size)
705731
if self._reranker is not None:
@@ -1071,6 +1097,8 @@ def to_arrow(self) -> pa.Table:
10711097
self._vector_query.nprobes(self._nprobes)
10721098
if self._refine_factor:
10731099
self._vector_query.refine_factor(self._refine_factor)
1100+
if self._ef:
1101+
self._vector_query.ef(self._ef)
10741102

10751103
with ThreadPoolExecutor() as executor:
10761104
fts_future = executor.submit(self._fts_query.with_row_id(True).to_arrow)
@@ -1197,6 +1225,29 @@ def nprobes(self, nprobes: int) -> LanceHybridQueryBuilder:
11971225
self._nprobes = nprobes
11981226
return self
11991227

1228+
def ef(self, ef: int) -> LanceHybridQueryBuilder:
1229+
"""
1230+
Set the number of candidates to consider during search.
1231+
1232+
Higher values will yield better recall (more likely to find vectors if
1233+
they exist) at the expense of latency.
1234+
1235+
This only applies to HNSW-based indices.
1236+
The default value is 1.5 * limit.
1237+
1238+
Parameters
1239+
----------
1240+
ef: int
1241+
The number of candidates to consider during search.
1242+
1243+
Returns
1244+
-------
1245+
LanceHybridQueryBuilder
1246+
The LanceHybridQueryBuilder object.
1247+
"""
1248+
self._ef = ef
1249+
return self
1250+
12001251
def metric(self, metric: Literal["L2", "cosine", "dot"]) -> LanceHybridQueryBuilder:
12011252
"""Set the distance metric to use.
12021253
@@ -1644,6 +1695,21 @@ def nprobes(self, nprobes: int) -> AsyncVectorQuery:
16441695
self._inner.nprobes(nprobes)
16451696
return self
16461697

1698+
def ef(self, ef: int) -> AsyncVectorQuery:
1699+
"""
1700+
Set the number of candidates to consider during search
1701+
1702+
This argument is only used when the vector column has an HNSW index.
1703+
If there is no index then this value is ignored.
1704+
1705+
Increasing this value will increase the recall of your query but will also
1706+
increase the latency of your query. The default value is 1.5 * limit. This
1707+
default is good for many cases but the best value to use will depend on your
1708+
data and the recall that you need to achieve.
1709+
"""
1710+
self._inner.ef(ef)
1711+
return self
1712+
16471713
def refine_factor(self, refine_factor: int) -> AsyncVectorQuery:
16481714
"""
16491715
A multiplier to control how many additional rows are taken during the refine

python/python/lancedb/table.py

+3
Original file line numberDiff line numberDiff line change
@@ -1959,6 +1959,7 @@ def _execute_query(
19591959
"metric": query.metric,
19601960
"nprobes": query.nprobes,
19611961
"refine_factor": query.refine_factor,
1962+
"ef": query.ef,
19621963
}
19631964
return ds.scanner(
19641965
columns=query.columns,
@@ -2736,6 +2737,8 @@ async def _execute_query(
27362737
async_query = async_query.refine_factor(query.refine_factor)
27372738
if query.vector_column:
27382739
async_query = async_query.column(query.vector_column)
2740+
if query.ef:
2741+
async_query = async_query.ef(query.ef)
27392742

27402743
if not query.prefilter:
27412744
async_query = async_query.postfilter()

python/python/tests/test_remote_db.py

+3
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ def handler(body):
185185
"k": 10,
186186
"prefilter": False,
187187
"refine_factor": None,
188+
"ef": None,
188189
"vector": [1.0, 2.0, 3.0],
189190
"nprobes": 20,
190191
}
@@ -223,6 +224,7 @@ def handler(body):
223224
"refine_factor": 10,
224225
"vector": [1.0, 2.0, 3.0],
225226
"nprobes": 5,
227+
"ef": None,
226228
"filter": "id > 0",
227229
"columns": ["id", "name"],
228230
"vector_column": "vector2",
@@ -318,6 +320,7 @@ def handler(body):
318320
"refine_factor": None,
319321
"vector": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
320322
"nprobes": 20,
323+
"ef": None,
321324
"with_row_id": True,
322325
}
323326
return pa.table({"_rowid": [1, 2, 3], "_distance": [0.1, 0.2, 0.3]})

python/src/query.rs

+4
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,10 @@ impl VectorQuery {
195195
self.inner = self.inner.clone().nprobes(nprobe as usize);
196196
}
197197

198+
pub fn ef(&mut self, ef: u32) {
199+
self.inner = self.inner.clone().ef(ef as usize);
200+
}
201+
198202
pub fn bypass_vector_index(&mut self) {
199203
self.inner = self.inner.clone().bypass_vector_index()
200204
}

rust/lancedb/src/query.rs

+16
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,9 @@ pub struct VectorQuery {
704704
// IVF PQ - ANN search.
705705
pub(crate) query_vector: Vec<Arc<dyn Array>>,
706706
pub(crate) nprobes: usize,
707+
// The number of candidates to consider during an HNSW index search,
708+
// defaults to 1.5 * limit.
709+
pub(crate) ef: Option<usize>,
707710
pub(crate) refine_factor: Option<u32>,
708711
pub(crate) distance_type: Option<DistanceType>,
709712
/// Default is true. Set to false to enforce a brute force search.
@@ -717,6 +720,7 @@ impl VectorQuery {
717720
column: None,
718721
query_vector: Vec::new(),
719722
nprobes: 20,
723+
ef: None,
720724
refine_factor: None,
721725
distance_type: None,
722726
use_index: true,
@@ -776,6 +780,18 @@ impl VectorQuery {
776780
self
777781
}
778782

783+
/// Set the number of candidates to consider during an HNSW index search
784+
///
785+
/// This argument is only used when the vector column has an HNSW index.
786+
/// If there is no index then this value is ignored.
787+
///
788+
/// Increasing this value will increase the recall of your query but will
789+
/// also increase the latency of your query. The default value is 1.5*limit.
790+
pub fn ef(mut self, ef: usize) -> Self {
791+
self.ef = Some(ef);
792+
self
793+
}
794+
779795
/// A multiplier to control how many additional rows are taken during the refine step
780796
///
781797
/// This argument is only used when the vector column has an IVF PQ index.

rust/lancedb/src/remote/table.rs

+3
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ impl<S: HttpSend> RemoteTable<S> {
196196
body["prefilter"] = query.base.prefilter.into();
197197
body["distance_type"] = serde_json::json!(query.distance_type.unwrap_or_default());
198198
body["nprobes"] = query.nprobes.into();
199+
body["ef"] = query.ef.into();
199200
body["refine_factor"] = query.refine_factor.into();
200201
if let Some(vector_column) = query.column.as_ref() {
201202
body["vector_column"] = serde_json::Value::String(vector_column.clone());
@@ -1121,6 +1122,7 @@ mod tests {
11211122
"prefilter": true,
11221123
"distance_type": "l2",
11231124
"nprobes": 20,
1125+
"ef": Option::<usize>::None,
11241126
"refine_factor": null,
11251127
});
11261128
// Pass vector separately to make sure it matches f32 precision.
@@ -1166,6 +1168,7 @@ mod tests {
11661168
"bypass_vector_index": true,
11671169
"columns": ["a", "b"],
11681170
"nprobes": 12,
1171+
"ef": Option::<usize>::None,
11691172
"refine_factor": 2,
11701173
});
11711174
// Pass vector separately to make sure it matches f32 precision.

rust/lancedb/src/table.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1904,6 +1904,9 @@ impl TableInternal for NativeTable {
19041904
query.base.offset.map(|offset| offset as i64),
19051905
)?;
19061906
scanner.nprobs(query.nprobes);
1907+
if let Some(ef) = query.ef {
1908+
scanner.ef(ef);
1909+
}
19071910
scanner.use_index(query.use_index);
19081911
scanner.prefilter(query.base.prefilter);
19091912
match query.base.select {

0 commit comments

Comments
 (0)