From bf042523a1b9d6802a90b434fddc88f469b3015d Mon Sep 17 00:00:00 2001 From: Patrick Weizhi Xu Date: Wed, 4 Dec 2024 15:03:53 +0800 Subject: [PATCH] enhance: add search iterator v2 Signed-off-by: Patrick Weizhi Xu --- pymilvus/client/abstract.py | 4 + pymilvus/client/constants.py | 4 + pymilvus/client/prepare.py | 20 ++++ pymilvus/client/search_iterator.py | 135 ++++++++++++++++++++++++ pymilvus/milvus_client/milvus_client.py | 69 +++++++++++- 5 files changed, 230 insertions(+), 2 deletions(-) create mode 100644 pymilvus/client/search_iterator.py diff --git a/pymilvus/client/abstract.py b/pymilvus/client/abstract.py index 7cbff180a..88d23f683 100644 --- a/pymilvus/client/abstract.py +++ b/pymilvus/client/abstract.py @@ -504,11 +504,15 @@ def __init__( ) nq_thres += topk self._session_ts = session_ts + self._search_iterator_v2_results = res.search_iterator_v2_results super().__init__(data) def get_session_ts(self): return self._session_ts + def get_search_iterator_v2_results_info(self): + return self._search_iterator_v2_results + def get_fields_by_range( self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData] ) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]: diff --git a/pymilvus/client/constants.py b/pymilvus/client/constants.py index 1d1dac08b..0d31c7af1 100644 --- a/pymilvus/client/constants.py +++ b/pymilvus/client/constants.py @@ -16,6 +16,10 @@ STRICT_GROUP_SIZE = "strict_group_size" ITERATOR_FIELD = "iterator" ITERATOR_SESSION_TS_FIELD = "iterator_session_ts" +ITER_SEARCH_V2_KEY = "search_iter_v2" +ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size" +ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound" +ITER_SEARCH_ID_KEY = "search_iter_id" PAGE_RETAIN_ORDER_FIELD = "page_retain_order" HINTS = "hints" diff --git a/pymilvus/client/prepare.py b/pymilvus/client/prepare.py index 586d7e5ae..921ed8257 100644 --- a/pymilvus/client/prepare.py +++ b/pymilvus/client/prepare.py @@ -20,6 +20,10 @@ GROUP_BY_FIELD, GROUP_SIZE, HINTS, + ITER_SEARCH_BATCH_SIZE_KEY, + ITER_SEARCH_ID_KEY, + ITER_SEARCH_LAST_BOUND_KEY, + ITER_SEARCH_V2_KEY, ITERATOR_FIELD, PAGE_RETAIN_ORDER_FIELD, RANK_GROUP_SCORER, @@ -958,6 +962,22 @@ def search_requests_with_expr( if is_iterator is not None: search_params[ITERATOR_FIELD] = is_iterator + is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY) + if is_search_iter_v2 is not None: + search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2 + + search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY) + if search_iter_batch_size is not None: + search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size + + search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY) + if search_iter_last_bound is not None: + search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound + + search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY) + if search_iter_id is not None: + search_params[ITER_SEARCH_ID_KEY] = search_iter_id + group_by_field = kwargs.get(GROUP_BY_FIELD) if group_by_field is not None: search_params[GROUP_BY_FIELD] = group_by_field diff --git a/pymilvus/client/search_iterator.py b/pymilvus/client/search_iterator.py new file mode 100644 index 000000000..b6850473c --- /dev/null +++ b/pymilvus/client/search_iterator.py @@ -0,0 +1,135 @@ +import logging +from copy import deepcopy +from typing import Dict, List, Optional, Union + +from pymilvus.client import entity_helper, utils +from pymilvus.client.constants import ( + GUARANTEE_TIMESTAMP, + ITER_SEARCH_BATCH_SIZE_KEY, + ITER_SEARCH_ID_KEY, + ITER_SEARCH_LAST_BOUND_KEY, + ITER_SEARCH_V2_KEY, + ITERATOR_FIELD, +) +from pymilvus.exceptions import MilvusException, ParamError +from pymilvus.orm.connections import Connections +from pymilvus.orm.constants import MAX_BATCH_SIZE, MILVUS_LIMIT, OFFSET +from pymilvus.orm.iterator import SearchPage, fall_back_to_latest_session_ts + +logger = logging.getLogger(__name__) + + +class SearchIteratorV2: + _NOT_SUPPORT_V2_MSG = """ + The server does not support Search Iterator V2. + Please upgrade your Milvus server, or create a search_iterator (v1) instead. + """ + + # for compatibility, save the first result during init + _saved_first_res = None + _is_saved = False + + def __init__( + self, + connection: Connections, + collection_name: str, + data: Union[List, utils.SparseMatrixInputType], + batch_size: int = 1000, + filter: Optional[str] = None, + output_fields: Optional[List[str]] = None, + search_params: Optional[Dict] = None, + timeout: Optional[float] = None, + partition_names: Optional[List[str]] = None, + anns_field: Optional[str] = None, + round_decimal: Optional[int] = -1, + **kwargs, + ): + self._check_params(batch_size, data, kwargs) + + # delete limit from incoming for compatibility + if MILVUS_LIMIT in kwargs: + del kwargs[MILVUS_LIMIT] + + self._conn = connection + self._params = { + "collection_name": collection_name, + "data": data, + "anns_field": anns_field, + "param": deepcopy(search_params), + "limit": batch_size, + "expression": filter, + "partition_names": partition_names, + "output_fields": output_fields, + "timeout": timeout, + "round_decimal": round_decimal, + ITERATOR_FIELD: True, + ITER_SEARCH_V2_KEY: True, + ITER_SEARCH_BATCH_SIZE_KEY: batch_size, + GUARANTEE_TIMESTAMP: 0, + **kwargs, + } + # this raises MilvusException if the server does not support V2 + self._saved_first_res = self.next() + self._is_saved = True + + def next(self): + # for compatibility + if self._is_saved: + self._is_saved = False + return self._saved_first_res + self._saved_first_res = None + + res = self._conn.search(**self._params) + iter_info = res.get_search_iterator_v2_results_info() + self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound + + # patch token and guarantee timestamp for the first next() call + if ITER_SEARCH_ID_KEY not in self._params: + if iter_info.token is not None and iter_info.token != "": + self._params[ITER_SEARCH_ID_KEY] = iter_info.token + else: + raise MilvusException(message=self.NOT_SUPPORT_V2_MSG) + if self._params[GUARANTEE_TIMESTAMP] <= 0: + if res.get_session_ts() > 0: + self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts() + else: + logger.warning( + "failed to set up mvccTs from milvus server, use client-side ts instead" + ) + self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts() + + # return SearchPage for compability + if len(res) > 0: + return SearchPage(res[0]) + return SearchPage(None) + + def close(self): + pass + + def _check_params( + self, + batch_size: int, + data: Union[List, utils.SparseMatrixInputType], + kwargs: Dict, + ): + # metric_type can be empty, deduced at server side + # anns_field can be empty, deduced at server side + + # check batch size + if batch_size < 0: + raise ParamError(message="batch size cannot be less than zero") + if batch_size > MAX_BATCH_SIZE: + raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}") + + # check offset + if kwargs.get(OFFSET, 0) != 0: + raise ParamError(message="Offset is not supported for search_iterator_v2") + + # check num queries, heavy to check at server side + rows = entity_helper.get_input_num_rows(data) + if rows > 1: + raise ParamError( + message="search_iterator_v2 does not support processing multiple vectors simultaneously" + ) + if rows == 0: + raise ParamError(message="The vector data for search cannot be empty") diff --git a/pymilvus/milvus_client/milvus_client.py b/pymilvus/milvus_client/milvus_client.py index d58f0f77a..3393288e4 100644 --- a/pymilvus/milvus_client/milvus_client.py +++ b/pymilvus/milvus_client/milvus_client.py @@ -6,6 +6,7 @@ from pymilvus.client.abstract import AnnSearchRequest, BaseRanker from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL +from pymilvus.client.search_iterator import SearchIteratorV2 from pymilvus.client.types import ( ExceptionsMessage, ExtraList, @@ -533,11 +534,75 @@ def search_iterator( anns_field: Optional[str] = None, round_decimal: int = -1, **kwargs, - ): + ) -> Union[SearchIteratorV2, SearchIterator]: + """Creates an iterator for searching vectors in batches. + + This method returns an iterator that performs vector similarity search in batches, + which is useful when dealing with large result sets. It automatically attempts to use + Search Iterator V2 if supported by the server, otherwise falls back to V1. + + Args: + collection_name (str): Name of the collection to search in. + data (Union[List[list], list]): Vector data to search with. For V2, only single vector + search is supported. + batch_size (int, optional): Number of results to fetch per batch. Defaults to 1000. + Must be between 1 and MAX_BATCH_SIZE. + filter (str, optional): Filtering expression to filter the results. Defaults to None. + limit (int, optional): Total number of results to return. Defaults to UNLIMITED. + V2 ignores this parameter. + output_fields (List[str], optional): Fields to return in the results. + search_params (dict, optional): Parameters for the search operation. + timeout (float, optional): Timeout in seconds for each RPC call. + partition_names (List[str], optional): Names of partitions to search in. + anns_field (str, optional): Name of the vector field to search. Can be empty when + there is only one vector field in the collection. + round_decimal (int, optional): Number of decimal places for distance values. + Defaults to -1 (no rounding). + **kwargs: Additional arguments to pass to the search operation. + + Returns: + SearchIterator: An iterator object that yields search results in batches. + + Raises: + MilvusException: If the search operation fails. + ParamError: If the input parameters are invalid (e.g., invalid batch_size or multiple + vectors in data when using V2). + + Examples: + >>> # Search with iterator + >>> iterator = client.search_iterator( + ... collection_name="my_collection", + ... data=[[0.1, 0.2]], + ... batch_size=100 + ... ) + """ + + conn = self._get_connection() + + # compatibility logic, change this when support get version from server + try: + return SearchIteratorV2( + connection=conn, + collection_name=collection_name, + data=data, + batch_size=batch_size, + filter=filter, + output_fields=output_fields, + search_params=search_params or {}, + timeout=timeout, + partition_names=partition_names, + anns_field=anns_field or "", + round_decimal=round_decimal, + **kwargs, + ) + except MilvusException as ex: + if ex.message != SearchIteratorV2._NOT_SUPPORT_V2_MSG: + raise ex from ex + + # following is the old code for search_iterator V1 if filter is not None and not isinstance(filter, str): raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter)) - conn = self._get_connection() # set up schema for iterator try: schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)