enhance: add search iterator v2
Signed-off-by: Patrick Weizhi Xu <[email protected]>
PwzXxm committed Dec 26, 2024
1 parent a94b1a2 commit bf04252
Showing 5 changed files with 230 additions and 2 deletions.
4 changes: 4 additions & 0 deletions pymilvus/client/abstract.py
@@ -504,11 +504,15 @@ def __init__(
)
nq_thres += topk
self._session_ts = session_ts
self._search_iterator_v2_results = res.search_iterator_v2_results
super().__init__(data)

def get_session_ts(self):
return self._session_ts

def get_search_iterator_v2_results_info(self):
return self._search_iterator_v2_results

def get_fields_by_range(
self, start: int, end: int, all_fields_data: List[schema_pb2.FieldData]
) -> Dict[str, Tuple[List[Any], schema_pb2.FieldData]]:
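Not part of the commit: a short sketch of how the new accessor added above is consumed later in search_iterator.py. Here res is assumed to be the result returned by a connection-level search; the field names follow their use in that file.

# illustrative only -- mirrors the usage in pymilvus/client/search_iterator.py below
info = res.get_search_iterator_v2_results_info()
info.token       # server-issued iterator id, captured on the first call
info.last_bound  # bound fed back as search_iter_last_bound on the next call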
4 changes: 4 additions & 0 deletions pymilvus/client/constants.py
@@ -16,6 +16,10 @@
STRICT_GROUP_SIZE = "strict_group_size"
ITERATOR_FIELD = "iterator"
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
ITER_SEARCH_V2_KEY = "search_iter_v2"
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"
ITER_SEARCH_LAST_BOUND_KEY = "search_iter_last_bound"
ITER_SEARCH_ID_KEY = "search_iter_id"
PAGE_RETAIN_ORDER_FIELD = "page_retain_order"
HINTS = "hints"

20 changes: 20 additions & 0 deletions pymilvus/client/prepare.py
@@ -20,6 +20,10 @@
GROUP_BY_FIELD,
GROUP_SIZE,
HINTS,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
PAGE_RETAIN_ORDER_FIELD,
RANK_GROUP_SCORER,
@@ -958,6 +962,22 @@ def search_requests_with_expr(
if is_iterator is not None:
search_params[ITERATOR_FIELD] = is_iterator

is_search_iter_v2 = kwargs.get(ITER_SEARCH_V2_KEY)
if is_search_iter_v2 is not None:
search_params[ITER_SEARCH_V2_KEY] = is_search_iter_v2

search_iter_batch_size = kwargs.get(ITER_SEARCH_BATCH_SIZE_KEY)
if search_iter_batch_size is not None:
search_params[ITER_SEARCH_BATCH_SIZE_KEY] = search_iter_batch_size

search_iter_last_bound = kwargs.get(ITER_SEARCH_LAST_BOUND_KEY)
if search_iter_last_bound is not None:
search_params[ITER_SEARCH_LAST_BOUND_KEY] = search_iter_last_bound

search_iter_id = kwargs.get(ITER_SEARCH_ID_KEY)
if search_iter_id is not None:
search_params[ITER_SEARCH_ID_KEY] = search_iter_id

group_by_field = kwargs.get(GROUP_BY_FIELD)
if group_by_field is not None:
search_params[GROUP_BY_FIELD] = group_by_field
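Not part of the commit: an illustrative view of the search_params dict these additions produce when a V2 iterator drives the request. The key strings come from constants.py above; the values shown are hypothetical.

# illustrative only -- key names from pymilvus/client/constants.py, values hypothetical
search_params = {
    "iterator": True,                       # ITERATOR_FIELD
    "search_iter_v2": True,                 # ITER_SEARCH_V2_KEY
    "search_iter_batch_size": 500,          # ITER_SEARCH_BATCH_SIZE_KEY
    "search_iter_last_bound": 0.73,         # ITER_SEARCH_LAST_BOUND_KEY, set after the first page
    "search_iter_id": "token-from-server",  # ITER_SEARCH_ID_KEY, issued by the server
}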
135 changes: 135 additions & 0 deletions pymilvus/client/search_iterator.py
@@ -0,0 +1,135 @@
import logging
from copy import deepcopy
from typing import Dict, List, Optional, Union

from pymilvus.client import entity_helper, utils
from pymilvus.client.constants import (
GUARANTEE_TIMESTAMP,
ITER_SEARCH_BATCH_SIZE_KEY,
ITER_SEARCH_ID_KEY,
ITER_SEARCH_LAST_BOUND_KEY,
ITER_SEARCH_V2_KEY,
ITERATOR_FIELD,
)
from pymilvus.exceptions import MilvusException, ParamError
from pymilvus.orm.connections import Connections
from pymilvus.orm.constants import MAX_BATCH_SIZE, MILVUS_LIMIT, OFFSET
from pymilvus.orm.iterator import SearchPage, fall_back_to_latest_session_ts

logger = logging.getLogger(__name__)


class SearchIteratorV2:
_NOT_SUPPORT_V2_MSG = """
The server does not support Search Iterator V2.
Please upgrade your Milvus server, or create a search_iterator (v1) instead.
"""

# for compatibility, save the first result during init
_saved_first_res = None
_is_saved = False

def __init__(
self,
connection: Connections,
collection_name: str,
data: Union[List, utils.SparseMatrixInputType],
batch_size: int = 1000,
filter: Optional[str] = None,
output_fields: Optional[List[str]] = None,
search_params: Optional[Dict] = None,
timeout: Optional[float] = None,
partition_names: Optional[List[str]] = None,
anns_field: Optional[str] = None,
round_decimal: Optional[int] = -1,
**kwargs,
):
self._check_params(batch_size, data, kwargs)

# delete limit from incoming for compatibility
if MILVUS_LIMIT in kwargs:
del kwargs[MILVUS_LIMIT]

self._conn = connection
self._params = {
"collection_name": collection_name,
"data": data,
"anns_field": anns_field,
"param": deepcopy(search_params),
"limit": batch_size,
"expression": filter,
"partition_names": partition_names,
"output_fields": output_fields,
"timeout": timeout,
"round_decimal": round_decimal,
ITERATOR_FIELD: True,
ITER_SEARCH_V2_KEY: True,
ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
GUARANTEE_TIMESTAMP: 0,
**kwargs,
}
# this raises MilvusException if the server does not support V2
self._saved_first_res = self.next()
self._is_saved = True

def next(self):
# for compatibility
if self._is_saved:
self._is_saved = False
return self._saved_first_res
self._saved_first_res = None

res = self._conn.search(**self._params)
iter_info = res.get_search_iterator_v2_results_info()
self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

# patch token and guarantee timestamp for the first next() call
if ITER_SEARCH_ID_KEY not in self._params:
if iter_info.token is not None and iter_info.token != "":
self._params[ITER_SEARCH_ID_KEY] = iter_info.token
else:
raise MilvusException(message=self._NOT_SUPPORT_V2_MSG)
if self._params[GUARANTEE_TIMESTAMP] <= 0:
if res.get_session_ts() > 0:
self._params[GUARANTEE_TIMESTAMP] = res.get_session_ts()
else:
logger.warning(
"failed to set up mvccTs from milvus server, use client-side ts instead"
)
self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

# return SearchPage for compatibility
if len(res) > 0:
return SearchPage(res[0])
return SearchPage(None)

def close(self):
pass

def _check_params(
self,
batch_size: int,
data: Union[List, utils.SparseMatrixInputType],
kwargs: Dict,
):
# metric_type can be empty, deduced at server side
# anns_field can be empty, deduced at server side

# check batch size
if batch_size < 0:
raise ParamError(message="batch size cannot be less than zero")
if batch_size > MAX_BATCH_SIZE:
raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")

# check offset
if kwargs.get(OFFSET, 0) != 0:
raise ParamError(message="Offset is not supported for search_iterator_v2")

# check num queries, heavy to check at server side
rows = entity_helper.get_input_num_rows(data)
if rows > 1:
raise ParamError(
message="search_iterator_v2 does not support processing multiple vectors simultaneously"
)
if rows == 0:
raise ParamError(message="The vector data for search cannot be empty")
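Not part of the commit: a minimal usage sketch of the new iterator through MilvusClient.search_iterator (wired up in milvus_client.py below). It assumes a running Milvus server at localhost:19530 and an existing collection named my_collection with a 2-dimensional float vector field; next() returns a SearchPage, and an empty page signals exhaustion.

from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")  # hypothetical deployment

iterator = client.search_iterator(
    collection_name="my_collection",  # hypothetical collection
    data=[[0.1, 0.2]],                # V2 accepts exactly one query vector
    batch_size=100,
    output_fields=["id"],
)

while True:
    page = iterator.next()
    if len(page) == 0:  # an empty SearchPage means the iterator is exhausted
        iterator.close()
        break
    for hit in page:
        print(hit)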
69 changes: 67 additions & 2 deletions pymilvus/milvus_client/milvus_client.py
@@ -6,6 +6,7 @@

from pymilvus.client.abstract import AnnSearchRequest, BaseRanker
from pymilvus.client.constants import DEFAULT_CONSISTENCY_LEVEL
from pymilvus.client.search_iterator import SearchIteratorV2
from pymilvus.client.types import (
ExceptionsMessage,
ExtraList,
@@ -533,11 +534,75 @@ def search_iterator(
anns_field: Optional[str] = None,
round_decimal: int = -1,
**kwargs,
):
) -> Union[SearchIteratorV2, SearchIterator]:
"""Creates an iterator for searching vectors in batches.
This method returns an iterator that performs vector similarity search in batches,
which is useful when dealing with large result sets. It automatically attempts to use
Search Iterator V2 if supported by the server, otherwise falls back to V1.
Args:
collection_name (str): Name of the collection to search in.
data (Union[List[list], list]): Vector data to search with. For V2, only single vector
search is supported.
batch_size (int, optional): Number of results to fetch per batch. Defaults to 1000.
Must be between 1 and MAX_BATCH_SIZE.
filter (str, optional): Filtering expression to filter the results. Defaults to None.
limit (int, optional): Total number of results to return. Defaults to UNLIMITED.
V2 ignores this parameter.
output_fields (List[str], optional): Fields to return in the results.
search_params (dict, optional): Parameters for the search operation.
timeout (float, optional): Timeout in seconds for each RPC call.
partition_names (List[str], optional): Names of partitions to search in.
anns_field (str, optional): Name of the vector field to search. Can be empty when
there is only one vector field in the collection.
round_decimal (int, optional): Number of decimal places for distance values.
Defaults to -1 (no rounding).
**kwargs: Additional arguments to pass to the search operation.
Returns:
Union[SearchIteratorV2, SearchIterator]: An iterator object that yields search results in batches.
Raises:
MilvusException: If the search operation fails.
ParamError: If the input parameters are invalid (e.g., invalid batch_size or multiple
vectors in data when using V2).
Examples:
>>> # Search with iterator
>>> iterator = client.search_iterator(
... collection_name="my_collection",
... data=[[0.1, 0.2]],
... batch_size=100
... )
"""

conn = self._get_connection()

# compatibility logic, change this when support get version from server
try:
return SearchIteratorV2(
connection=conn,
collection_name=collection_name,
data=data,
batch_size=batch_size,
filter=filter,
output_fields=output_fields,
search_params=search_params or {},
timeout=timeout,
partition_names=partition_names,
anns_field=anns_field or "",
round_decimal=round_decimal,
**kwargs,
)
except MilvusException as ex:
if ex.message != SearchIteratorV2._NOT_SUPPORT_V2_MSG:
raise ex from ex

# following is the old code for search_iterator V1
if filter is not None and not isinstance(filter, str):
raise DataTypeNotMatchException(message=ExceptionsMessage.ExprType % type(filter))

conn = self._get_connection()
# set up schema for iterator
try:
schema_dict = conn.describe_collection(collection_name, timeout=timeout, **kwargs)