-
Notifications
You must be signed in to change notification settings - Fork 339
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Signed-off-by: Patrick Weizhi Xu <[email protected]>
- Loading branch information
Showing
5 changed files
with
220 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
import logging | ||
from copy import deepcopy | ||
from typing import Dict, List, Optional, Union | ||
from pymilvus.client import entity_helper, utils | ||
from pymilvus.client.constants import ( | ||
GUARANTEE_TIMESTAMP, | ||
ITER_SEARCH_BATCH_SIZE_KEY, | ||
ITER_SEARCH_ID_KEY, | ||
ITER_SEARCH_LAST_BOUND_KEY, | ||
ITER_SEARCH_V2_KEY, | ||
ITERATOR_FIELD, | ||
) | ||
from pymilvus.exceptions import MilvusException, ParamError | ||
from pymilvus.orm.connections import Connections | ||
from pymilvus.orm.constants import MAX_BATCH_SIZE, METRIC_TYPE, MILVUS_LIMIT, OFFSET | ||
from pymilvus.orm.iterator import fall_back_to_latest_session_ts, SearchPage | ||
|
||
# Module-level logger named after this module; SearchIteratorV2.next() uses it
# to warn when it has to fall back to a client-side timestamp.
logger = logging.getLogger(__name__) | ||
|
||
|
||
class SearchIteratorV2:
    """Iterate over search results one batch at a time.

    Each call to :meth:`next` issues one ``search`` request through the given
    connection, reusing the same parameter dict and updating it with the
    iteration state (last bound, iterator token, guarantee timestamp) that
    came back from the previous response.
    """

    def __init__(
        self,
        connection: Connections,
        collection_name: str,
        data: Union[List, utils.SparseMatrixInputType],
        batch_size: int = 1000,
        filter: Optional[str] = None,
        output_fields: Optional[List[str]] = None,
        search_params: Optional[Dict] = None,
        timeout: Optional[float] = None,
        partition_names: Optional[List[str]] = None,
        anns_field: Optional[str] = None,
        round_decimal: int = -1,
        **kwargs,
    ):
        """Validate the arguments and build the reusable search request.

        Args:
            connection: Object whose ``search(**params)`` method is called by
                :meth:`next`.
            collection_name: Forwarded as ``collection_name``.
            data: The query vector data; must contain exactly one row.
            batch_size: Number of results requested per :meth:`next` call.
                Sent both as ``limit`` and under ``ITER_SEARCH_BATCH_SIZE_KEY``.
            filter: Forwarded as ``expression``.
            output_fields: Forwarded as ``output_fields``.
            search_params: Forwarded as ``param``. A deep copy is stored, so
                later changes to the caller's dict do not affect the iterator.
            timeout: Forwarded as ``timeout``.
            partition_names: Forwarded as ``partition_names``.
            anns_field: Forwarded as ``anns_field``; may be left empty.
            round_decimal: Forwarded as ``round_decimal``.
            **kwargs: Extra search arguments. They are merged into the request
                last, so they take precedence over the entries built here.
                Any ``MILVUS_LIMIT`` entry is discarded.

        Raises:
            ParamError: If ``batch_size`` is negative or larger than
                ``MAX_BATCH_SIZE``, if a non-zero offset is supplied, or if
                ``data`` does not contain exactly one row.
        """
        self.__check_params(batch_size, data, search_params, kwargs)

        # delete limit from incoming for compatibility
        kwargs.pop(MILVUS_LIMIT, None)

        self._conn = connection
        self._params = {
            "collection_name": collection_name,
            "data": data,
            "anns_field": anns_field,
            "param": deepcopy(search_params),
            "limit": batch_size,
            "expression": filter,
            "partition_names": partition_names,
            "output_fields": output_fields,
            "round_decimal": round_decimal,
            "timeout": timeout,
            ITERATOR_FIELD: True,
            ITER_SEARCH_V2_KEY: True,
            ITER_SEARCH_BATCH_SIZE_KEY: batch_size,
            # 0 means "not decided yet"; next() replaces it with a real
            # timestamp once a response is available.
            GUARANTEE_TIMESTAMP: 0,
            **kwargs,
        }

    def next(self):
        """Fetch the next batch of results.

        Returns:
            SearchPage: Wraps the first element of the search response, or
            ``None`` when the response is empty.

        Raises:
            MilvusException: If no iterator token has been stored yet and the
                response carries none (``None`` or empty string), which is
                treated as the server not supporting Search Iterator V2.
        """
        res = self._conn.search(**self._params)
        iter_info = res.get_search_iterator_v2_results_info()
        self._params[ITER_SEARCH_LAST_BOUND_KEY] = iter_info.last_bound

        # patch token and guarantee timestamp for the first next() call
        if ITER_SEARCH_ID_KEY not in self._params:
            token = iter_info.token
            if token is None or token == "":
                raise MilvusException(
                    message="The server does not support Search Iterator V2. Please upgrade your Milvus server, or create a search_iterator (v1) instead"
                )
            self._params[ITER_SEARCH_ID_KEY] = token

        if self._params[GUARANTEE_TIMESTAMP] <= 0:
            # Read the session timestamp once and reuse it for the check and
            # the assignment.
            session_ts = res.get_session_ts()
            if session_ts > 0:
                self._params[GUARANTEE_TIMESTAMP] = session_ts
            else:
                logger.warning(
                    "failed to set up mvccTs from milvus server, use client-side ts instead"
                )
                self._params[GUARANTEE_TIMESTAMP] = fall_back_to_latest_session_ts()

        # return SearchPage for compatibility
        if len(res) > 0:
            return SearchPage(res[0])
        return SearchPage(None)

    def close(self):
        """Do nothing; this iterator holds nothing that needs releasing."""

    def __check_params(
        self,
        batch_size: int,
        data: Union[List, utils.SparseMatrixInputType],
        param: Dict,
        kwargs: Dict,
    ):
        """Reject argument combinations the iterator cannot handle.

        Args:
            batch_size: Requested batch size; must lie in
                ``[0, MAX_BATCH_SIZE]``.
            data: Query vector data; must contain exactly one row.
            param: Search parameters. Currently not inspected here.
            kwargs: Extra keyword arguments given to the constructor; only
                the ``OFFSET`` entry is examined.

        Raises:
            ParamError: On any violated constraint.
        """
        # metric_type can be empty, deduced at server side
        # anns_field can be empty, deduced at server side

        # check batch size
        if batch_size < 0:
            raise ParamError(message="batch size cannot be less than zero")
        if batch_size > MAX_BATCH_SIZE:
            raise ParamError(message=f"batch size cannot be larger than {MAX_BATCH_SIZE}")

        # check offset
        if kwargs.get(OFFSET, 0) != 0:
            raise ParamError(message="Offset is not supported for search_iterator_v2")

        # check num queries, heavy to check at server side
        rows = entity_helper.get_input_num_rows(data)
        if rows > 1:
            raise ParamError(
                message="search_iterator_v2 does not support processing multiple vectors simultaneously"
            )
        if rows == 0:
            raise ParamError(message="The vector data for search cannot be empty")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters