From 808a10530cbaca4574a30421cb581bff91a019f8 Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Wed, 4 Dec 2024 15:03:53 +0800 Subject: [PATCH] enhance: add search iterator v2 Signed-off-by: Patrick Weizhi Xu --- pymilvus/client/abstract.py | 4 + pymilvus/client/constants.py | 4 + pymilvus/client/prepare.py | 20 ++++ pymilvus/client/search_iterator.py | 120 ++++++++++++++++++++++++ pymilvus/milvus_client/milvus_client.py | 73 +++++++++++++- 5 files changed, 220 insertions(+), 1 deletion(-) create mode 100644 pymilvus/client/search_iterator.py diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 7cbff180a..88d23f683 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -504,11 +504,15 @@ def __init__( ) nq_thres += topk self._session_ts = session_ts + self._search_iterator_v2_results = res.search_iterator_v2_results super().__init__(data) def get_session_ts(self): return self._session_ts + def get_search_iterator_v2_results_info(self): + return self._search_iterator_v2_results + def get_fields_by_range( self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData] ) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]: diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index 1d1dac08b..0d31c7af1 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -16,6 +16,10 @@ STRICT_GROUP_SIZE = "strict_group_size" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +ITER_SEARCH_V2_KEY = "search_iter_v2" +ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" +ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" +ITER_SEARCH_ID_KEY = "search_iter_id" PAGE_RETAIN_ORDER_FIELD = "page_retain_order" HINTS = "hints" diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 586d7e5ae..921ed8257 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -20,6 +20,10 @@ GROUP_BY_FIELD, GROUP_SIZE, HINTS, + 
logger = logging.getLogger(__name__)


class SearchIteratorV2:
    """Server-driven search iterator (v2).

    Repeatedly issues search RPCs against a collection, letting the server
    track iteration state through a token (``ITER_SEARCH_ID_KEY``) and a
    per-batch "last bound" value. Each call to :meth:`next` returns one batch
    of results wrapped in :class:`SearchPage` for compatibility with the v1
    iterator interface.
    """

    def __init__(
        self,
        connection: Connections,
        collection_name: str,
        data: Union[List, utils.SparseMatrixInputType],
        batch_size: int = 1000,
        filter: Optional[str] = None,
        output_fields: Optional[List[str]] = None,
        search_params: Optional[Dict] = None,
        timeout: Optional[float] = None,
        partition_names: Optional[List[str]] = None,
        anns_field: Optional[str] = None,
        round_decimal: int = -1,
        **kwargs,
    ) -> None:
        """Validate the inputs and stash the request parameters for next().

        Args:
            connection: Open connection object used to issue search RPCs.
            collection_name: Name of the collection to search.
            data: A single query vector (dense list or sparse matrix input);
                multiple query vectors are rejected by ``__check_params``.
            batch_size: Number of hits fetched per :meth:`next` call; must be
                within ``[0, MAX_BATCH_SIZE]``.
            filter: Optional boolean filter expression.
            output_fields: Field names to return alongside each hit.
            search_params: Extra search parameters; deep-copied so later
                mutation by the caller cannot affect in-flight iteration.
            timeout: Optional per-RPC timeout in seconds.
            partition_names: Restrict the search to these partitions.
            anns_field: Vector field to search on; may be empty when the
                collection has a single vector field (deduced server side).
            round_decimal: Decimal places to round distances to; -1 disables.
            **kwargs: Forwarded verbatim into the search request parameters.

        Raises:
            ParamError: If batch size, offset, or the number of query vectors
                is invalid.
        """
        self.__check_params(batch_size, data, search_params, kwargs)

        # Drop any caller-supplied limit for compatibility with v1 call
        # sites: the effective limit of every request is batch_size.
        kwargs.pop(MILVUS_LIMIT, None)

        self._conn = connection
        self._params = {
            "collection_name": collection_name,
            "data": data,
            "anns_field": anns_field,
            "param": deepcopy(search_params),
            "limit": batch_size,
            "expression": filter,
            "partition_names": partition_names,
            "output_fields": output_fields,
            "round_decimal": round_decimal,
            "timeout": timeout,
            ITERATOR_FIELD: True,
            ITER_SEARCH_V2_KEY: True,
            ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
            GUARANTEE_TIMESTAMP: 0,
            **kwargs,
        }

    def next(self):
        """Fetch the next batch of results from the server.

        Returns:
            SearchPage: The next batch of hits; an empty page once the
            iteration is exhausted.

        Raises:
            MilvusException: If the server did not return an iterator token,
                i.e. it does not support Search Iterator V2.
        """
        res = self._conn.search(**self._params)
        iter_info = res.get_search_iterator_v2_results_info()
        # The server's last_bound is echoed back on the following request so
        # iteration resumes where the previous batch ended.
        self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

        # patch token and guarantee timestamp for the first next() call
        if ITER_SEARCH_ID_KEY not in self._params:
            if iter_info.token is not None and iter_info.token != "":
                self._params[ITER_SEARCH_ID_KEY] = iter_info.token
            else:
                raise MilvusException(
                    message="The server does not support Search Iterator V2. "
                    "Please upgrade your Milvus server, or create a search_iterator (v1) instead"
                )
        if self._params[GUARANTEE_TIMESTAMP] <= 0:
            if res.get_session_ts() > 0:
                self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
            else:
                logger.warning(
                    "failed to set up mvccTs from milvus server, use client-side ts instead"
                )
                self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

        # return SearchPage for compatibility with the v1 iterator interface
        if len(res) > 0:
            return SearchPage(res[0])
        return SearchPage(None)

    def close(self):
        """No-op: iteration state lives server-side; nothing to release here."""

    def __check_params(
        self,
        batch_size: int,
        data: Union[List, utils.SparseMatrixInputType],
        param: Dict,
        kwargs: Dict,
    ):
        """Validate constructor arguments, raising ParamError on misuse."""
        # metric_type can be empty, deduced at server side
        # anns_field can be empty, deduced at server side

        # check batch size
        if batch_size < 0:
            raise ParamError(message="batch size cannot be less than zero")
        if batch_size > MAX_BATCH_SIZE:
            raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")

        # check offset
        if kwargs.get(OFFSET, 0) != 0:
            raise ParamError(message="Offset is not supported for search_iterator_v2")

        # check num queries, heavy to check at server side
        rows = entity_helper.get_input_num_rows(data)
        if rows > 1:
            raise ParamError(
                message="search_iterator_v2 does not support processing multiple vectors simultaneously"
            )
        if rows == 0:
            raise ParamError(message="The vector data for search cannot be empty")
OmitZeroDict, construct_cost_extra, ) -from pymilvus.client.utils import is_vector_type +from pymilvus.client.utils import is_vector_type, SparseMatrixInputType from pymilvus.exceptions import ( DataTypeNotMatchException, ErrorCode, @@ -604,6 +605,76 @@ def search_iterator( **kwargs, ) + def search_iterator_v2( + self, + collection_name: str, + data: Union[List, SparseMatrixInputType], + batch_size: int = 1000, + filter: Optional[str] = None, + output_fields: Optional[List[str]] = None, + search_params: Optional[Dict] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, + round_decimal: int = -1, + **kwargs, + ): + """Creates a search iterator for batch processing of search operations. + + This method returns a SearchIteratorV2 object that allows for efficient batch processing + of search operations, particularly useful when dealing with large datasets. + + Args: + collection_name (str): The name of the collection to search in. + data (Union[List, SparseMatrixInputType]): The query data to search with. Can be either + a list of floats or a sparse matrix input. + batch_size (int, optional): The number of queries to process in each batch. + Defaults to 1000. + filter (str, optional): The filter expression to apply during search. + Defaults to None. + output_fields (List[str], optional): A list of fields to return in the search + results. + search_params (Dict, optional): Parameters to configure the search behavior. + Defaults to empty dict. + timeout (float, optional): An optional duration of time in seconds to allow for the RPC. + partition_names (List[str], optional): List of partition names to search in. + If None, searches the entire collection. + anns_field (str, optional): The vector field name to search on. It can be empty + if there is only one vector field in the collection. + round_decimal (int, optional): The number of decimal places to round the distance + values to. 
Defaults to -1 (no rounding). + **kwargs: Optional keyword arguments to pass to the search operation. + + Returns: + SearchIteratorV2: An iterator object that yields search results in batches. + + Raises: + MilvusException: If anything goes wrong during the search operation. + ParamError: If the input parameters are invalid. + + Examples: + >>> # Basic usage + >>> search_iterator = client.search_iterator_v2( + ... collection_name="my_collection", + ... data=[[1.0, 2.0, 3.0]], + ... batch_size=100, + ... ) + """ + return SearchIteratorV2( + connection=self._get_connection(), + collection_name=collection_name, + data=data, + batch_size=batch_size, + filter=filter, + output_fields=output_fields, + search_params=search_params or {}, + timeout=timeout, + partition_names=partition_names, + anns_field=anns_field or "", + round_decimal=round_decimal, + **kwargs, + ) + def get( self, collection_name: str,