Skip to content

Commit

Permalink
enhance: add search iterator v2
Browse files Browse the repository at this point in the history
Signed-off-by: Patrick Weizhi Xu <[email protected]>
  • Loading branch information
PwzXxm committed Dec 25, 2024
1 parent a94b1a2 commit ca9cd43
Show file tree
Hide file tree
Showing 5 changed files with 220 additions and 1 deletion.
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,11 +504,15 @@ def __init__(
)
nq_thres += topk
self._session_ts = session_ts
self._search_iterator_v2_results = res.search_iterator_v2_results
super().__init__(data)

def get_session_ts(self):
return self._session_ts

def get_search_iterator_v2_results_info(self):
return self._search_iterator_v2_results

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
Expand Down
4 changes: 4 additions & 0 deletions pymilvus/client/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"
HINTS = "hints"

Expand Down
20 changes: 20 additions & 0 deletions pymilvus/client/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
GROUP_BY_FIELD,
GROUP_SIZE,
HINTS,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
Expand Down Expand Up @@ -958,6 +962,22 @@ def search_requests_with_expr(
if is_iterator is not None:
search_params[ITERATOR_FIELD] = is_iterator

is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
if is_search_iter_v2 is not None:
search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2

search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY)
if search_iter_batch_size is not None:
search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size

search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY)
if search_iter_last_bound is not None:
search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound

search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY)
if search_iter_id is not None:
search_params[ITER_SEARCH_ID_KEY] = search_iter_id

group_by_field = kwargs.get(GROUP_BY_FIELD)
if group_by_field is not None:
search_params[GROUP_BY_FIELD] = group_by_field
Expand Down
120 changes: 120 additions & 0 deletions pymilvus/client/search_iterator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
import logging
from copy import deepcopy
from typing import Dict, List, Optional, Union

from pymilvus.client import entity_helper, utils
from pymilvus.client.constants import (
GUARANTEE_TIMESTAMP,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
)
from pymilvus.exceptions import MilvusException, ParamError
from pymilvus.orm.connections import Connections
from pymilvus.orm.constants import MAX_BATCH_SIZE, MILVUS_LIMIT, OFFSET
from pymilvus.orm.iterator import SearchPage, fall_back_to_latest_session_ts

logger = logging.getLogger(__name__)


class SearchIteratorV2:
def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
batch_size: int = 1000,
filter: Optional[str] = None,
output_fields: Optional[List[str]] = None,
search_params: Optional[Dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
round_decimal: int = -1,
**kwargs,
):
self.__check_params(batch_size, data, search_params, kwargs)

# delete limit from incoming for compatibility
if MILVUS_LIMIT in kwargs:
del kwargs[MILVUS_LIMIT]

self._conn = connection
self._params = {
"collection_name": collection_name,
"data": data,
"anns_field": anns_field,
"param": deepcopy(search_params),
"limit": batch_size,
"expression": filter,
"partition_names": partition_names,
"output_fields": output_fields,
"round_decimal": round_decimal,
"timeout": timeout,
ITERATOR_FIELD: True,
ITER_SEARCH_V2_KEY: True,
ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
GUARANTEE_TIMESTAMP: 0,
**kwargs,
}

def next(self):
res = self._conn.search(**self._params)
iter_info = res.get_search_iterator_v2_results_info()
self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

# patch token and guarantee timestamp for the first next() call
if ITER_SEARCH_ID_KEY not in self._params:
if iter_info.token is not None and iter_info.token != "":
self._params[ITER_SEARCH_ID_KEY] = iter_info.token
else:
raise MilvusException(
message="The server does not support Search Iterator V2. Please upgrade your Milvus server, or create a search_iterator (v1) instead"
)
if self._params[GUARANTEE_TIMESTAMP] <= 0:
if res.get_session_ts() > 0:
self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
else:
logger.warning(
"failed to set up mvccTs from milvus server, use client-side ts instead"
)
self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

# return SearchPage for compability
if len(res) > 0:
return SearchPage(res[0])
return SearchPage(None)

def close(self):
pass

def __check_params(
self,
batch_size: int,
data: Union[List, utils.SparseMatrixInputType],
param: Dict,
kwargs: Dict,
):
# metric_type can be empty, deduced at server side
# anns_field can be empty, deduced at server side

# check batch size
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")

# check offset
if kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Offset is not supported for search_iterator_v2")

# check num queries, heavy to check at server side
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
raise ParamError(
message="search_iterator_v2 does not support processing multiple vectors simultaneously"
)
if rows == 0:
raise ParamError(message="The vector data for search cannot be empty")
73 changes: 72 additions & 1 deletion pymilvus/milvus_client/milvus_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@

from pymilvus.client.abstract import AnnSearchRequest, BaseRanker
from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
from pymilvus.client.search_iterator import SearchIteratorV2
from pymilvus.client.types import (
ExceptionsMessage,
ExtraList,
LoadState,
OmitZeroDict,
construct_cost_extra,
)
from pymilvus.client.utils import is_vector_type
from pymilvus.client.utils import SparseMatrixInputType, is_vector_type
from pymilvus.exceptions import (
DataTypeNotMatchException,
ErrorCode,
Expand Down Expand Up @@ -604,6 +605,76 @@ def search_iterator(
**kwargs,
)

def search_iterator_v2(
self,
collection_name: str,
data: Union[List, SparseMatrixInputType],
batch_size: int = 1000,
filter: Optional[str] = None,
output_fields: Optional[List[str]] = None,
search_params: Optional[Dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
round_decimal: int = -1,
**kwargs,
):
"""Creates a search iterator for batch processing of search operations.
This method returns a SearchIteratorV2 object that allows for efficient batch processing
of search operations, particularly useful when dealing with large datasets.
Args:
collection_name (str): The name of the collection to search in.
data (Union[List, SparseMatrixInputType]): The query data to search with. Can be either
a list of floats or a sparse matrix input.
batch_size (int, optional): The number of queries to process in each batch.
Defaults to 1000.
filter (str, optional): The filter expression to apply during search.
Defaults to None.
output_fields (List[str], optional): A list of fields to return in the search
results.
search_params (Dict, optional): Parameters to configure the search behavior.
Defaults to empty dict.
timeout (float, optional): An optional duration of time in seconds to allow for the RPC.
partition_names (List[str], optional): List of partition names to search in.
If None, searches the entire collection.
anns_field (str, optional): The vector field name to search on. It can be empty
if there is only one vector field in the collection.
round_decimal (int, optional): The number of decimal places to round the distance
values to. Defaults to -1 (no rounding).
**kwargs: Optional keyword arguments to pass to the search operation.
Returns:
SearchIteratorV2: An iterator object that yields search results in batches.
Raises:
MilvusException: If anything goes wrong during the search operation.
ParamError: If the input parameters are invalid.
Examples:
>>> # Basic usage
>>> search_iterator = client.search_iterator_v2(
... collection_name="my_collection",
... data=[[1.0, 2.0, 3.0]],
... batch_size=100,
... )
"""
return SearchIteratorV2(
connection=self._get_connection(),
collection_name=collection_name,
data=data,
batch_size=batch_size,
filter=filter,
output_fields=output_fields,
search_params=search_params or {},
timeout=timeout,
partition_names=partition_names,
anns_field=anns_field or "",
round_decimal=round_decimal,
**kwargs,
)

def get(
self,
collection_name: str,
Expand Down

0 comments on commit ca9cd43

Please sign in to comment.