-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
964 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
import pathlib | ||
import pprint | ||
|
||
from yandex_cloud_ml_sdk import AsyncYCloudML | ||
from yandex_cloud_ml_sdk.search_indexes import ( | ||
HybridSearchIndexType, ReciprocalRankFusionIndexCombinationStrategy, StaticIndexChunkingStrategy, | ||
TextSearchIndexType, VectorSearchIndexType | ||
) | ||
|
||
|
||
def local_path(path: str) -> pathlib.Path:
    """Return ``path`` joined onto the directory that contains this script file."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir.joinpath(path)
|
||
|
||
async def main() -> None:
    """Create two hybrid search indexes, print them, then delete everything created.

    The two example text files are uploaded first. One index is then built with
    all default settings and a second one with every setting spelled out
    explicitly. Finally the uploaded files and both indexes are deleted.
    """
    sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')

    # Create all upload coroutines up front and await them together.
    example_names = ['turkey_example.txt', 'maldives_example.txt']
    uploads = [
        sdk.files.upload(
            local_path(name),
            ttl_days=5,
            expiration_policy="static",
        )
        for name in example_names
    ]
    files = await asyncio.gather(*uploads)

    # A hybrid search index with every setting left at its default:
    default_operation = await sdk.search_indexes.create_deferred(
        files,
        index_type=HybridSearchIndexType(),
    )
    default_search_index = await default_operation.wait()
    print("new hybrid search index with default settings:")
    pprint.pprint(default_search_index)

    # Any of the defaults can be overridden explicitly:
    custom_index_type = HybridSearchIndexType(
        chunking_strategy=StaticIndexChunkingStrategy(
            max_chunk_size_tokens=700,
            chunk_overlap_tokens=300,
        ),
        # settings of the underlying text/vector indexes can be overridden too
        text_search_index=TextSearchIndexType(),
        vector_search_index=VectorSearchIndexType(),
        normalization_strategy='L2',
        # leave `k` alone unless you really know what you are doing
        combination_strategy=ReciprocalRankFusionIndexCombinationStrategy(k=60),
    )
    custom_operation = await sdk.search_indexes.create_deferred(
        files,
        index_type=custom_index_type,
    )
    custom_search_index = await custom_operation.wait()
    print("new hybrid search index with overridden settings:")
    pprint.pprint(custom_search_index)

    # See the example file "assistant_with_search_index.py" for how to use an index.
    # Apart from creation, a hybrid index is used exactly like any other index.

    # Clean up the resources created above:
    for uploaded in files:
        await uploaded.delete()

    # The loop variable name is part of the printed f-string output, so keep it.
    for search_index in (default_search_index, custom_search_index):
        print(f"delete {search_index.id=}")
        await search_index.delete()
|
||
|
||
# Run the async example on a fresh event loop only when executed as a script.
if __name__ == '__main__':
    asyncio.run(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from __future__ import annotations | ||
|
||
import pathlib | ||
import pprint | ||
|
||
from yandex_cloud_ml_sdk import YCloudML | ||
from yandex_cloud_ml_sdk.search_indexes import ( | ||
HybridSearchIndexType, ReciprocalRankFusionIndexCombinationStrategy, StaticIndexChunkingStrategy, | ||
TextSearchIndexType, VectorSearchIndexType | ||
) | ||
|
||
|
||
def local_path(path: str) -> pathlib.Path:
    """Return ``path`` joined onto the directory that contains this script file."""
    script_dir = pathlib.Path(__file__).parent
    return script_dir.joinpath(path)
|
||
|
||
def main() -> None:
    """Create two hybrid search indexes, print them, then delete everything created.

    The two example text files are uploaded first. One index is then built with
    all default settings and a second one with every setting spelled out
    explicitly. Finally the uploaded files and both indexes are deleted.
    """
    sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')

    # Upload the example files one after another, in order.
    example_names = ['turkey_example.txt', 'maldives_example.txt']
    files = [
        sdk.files.upload(
            local_path(name),
            ttl_days=5,
            expiration_policy="static",
        )
        for name in example_names
    ]

    # A hybrid search index with every setting left at its default:
    default_operation = sdk.search_indexes.create_deferred(
        files,
        index_type=HybridSearchIndexType(),
    )
    default_search_index = default_operation.wait()
    print("new hybrid search index with default settings:")
    pprint.pprint(default_search_index)

    # Any of the defaults can be overridden explicitly:
    custom_index_type = HybridSearchIndexType(
        chunking_strategy=StaticIndexChunkingStrategy(
            max_chunk_size_tokens=700,
            chunk_overlap_tokens=300,
        ),
        # settings of the underlying text/vector indexes can be overridden too
        text_search_index=TextSearchIndexType(),
        vector_search_index=VectorSearchIndexType(),
        normalization_strategy='L2',
        # leave `k` alone unless you really know what you are doing
        combination_strategy=ReciprocalRankFusionIndexCombinationStrategy(k=60),
    )
    custom_operation = sdk.search_indexes.create_deferred(
        files,
        index_type=custom_index_type,
    )
    custom_search_index = custom_operation.wait()
    print("new hybrid search index with overridden settings:")
    pprint.pprint(custom_search_index)

    # See the example file "assistant_with_search_index.py" for how to use an index.
    # Apart from creation, a hybrid index is used exactly like any other index.

    # Clean up the resources created above:
    for uploaded in files:
        uploaded.delete()

    # The loop variable name is part of the printed f-string output, so keep it.
    for search_index in (default_search_index, custom_search_index):
        print(f"delete {search_index.id=}")
        search_index.delete()
|
||
|
||
# Run the example only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
src/yandex_cloud_ml_sdk/_search_indexes/combination_strategy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
# pylint: disable=no-name-in-module,protected-access | ||
from __future__ import annotations | ||
|
||
import abc | ||
import enum | ||
from dataclasses import dataclass | ||
from typing import TYPE_CHECKING, Any, Collection | ||
|
||
from google.protobuf.wrappers_pb2 import Int64Value | ||
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import CombinationStrategy as ProtoCombinationStrategy | ||
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import MeanCombinationStrategy as ProtoMeanCombinationStrategy | ||
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import ( | ||
ReciprocalRankFusionCombinationStrategy as ProtoReciprocalRankFusionCombinationStrategy | ||
) | ||
|
||
if TYPE_CHECKING: | ||
from yandex_cloud_ml_sdk._sdk import BaseSDK | ||
|
||
|
||
class BaseIndexCombinationStrategy(abc.ABC):
    """Abstract base class for index combination strategies.

    Concrete subclasses know how to convert themselves to and from the
    protobuf ``CombinationStrategy`` message.
    """

    @classmethod
    @abc.abstractmethod
    def _from_proto(cls, proto: Any, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
        """Build a strategy from its strategy-specific protobuf message."""

    @abc.abstractmethod
    def _to_proto(self) -> ProtoCombinationStrategy:
        """Serialize this strategy into a ``CombinationStrategy`` protobuf message."""

    @classmethod
    def _from_upper_proto(cls, proto: ProtoCombinationStrategy, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
        """Dispatch a ``CombinationStrategy`` message to the matching subclass.

        :param proto: the wrapper message; whichever of its ``mean_combination``
            or ``rrf_combination`` fields is set selects the subclass.
        :param sdk: SDK instance forwarded unchanged to the subclass ``_from_proto``.
        :raises NotImplementedError: if neither of the known fields is set.
        """
        if proto.HasField('mean_combination'):
            return MeanIndexCombinationStrategy._from_proto(
                proto=proto.mean_combination,
                sdk=sdk
            )
        if proto.HasField('rrf_combination'):
            return ReciprocalRankFusionIndexCombinationStrategy._from_proto(
                proto=proto.rrf_combination,
                sdk=sdk
            )
        raise NotImplementedError(
            'combination strategies other than Mean and RRF are not supported in this SDK version'
        )
|
||
|
||
# Shorthand for the protobuf enum whose values the SDK enum below reuses.
_orig = ProtoMeanCombinationStrategy.MeanEvaluationTechnique


class MeanIndexEvaluationTechnique(enum.IntEnum):
    """SDK-side counterpart of the protobuf ``MeanEvaluationTechnique`` enum."""

    MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED = _orig.MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED
    ARITHMETIC = _orig.ARITHMETIC
    GEOMETRIC = _orig.GEOMETRIC
    HARMONIC = _orig.HARMONIC

    @classmethod
    def _coerce(cls, technique: str | int) -> MeanIndexEvaluationTechnique:
        """Convert a member name (any letter case) or a numeric value to a member.

        Strings are upper-cased and looked up by name in the protobuf enum;
        anything else is passed straight to the enum constructor.
        """
        if not isinstance(technique, str):
            return cls(technique)
        return cls(_orig.Value(technique.upper()))
|
||
|
||
@dataclass(frozen=True)
class MeanIndexCombinationStrategy(BaseIndexCombinationStrategy):
    """Combination strategy backed by the ``MeanCombinationStrategy`` protobuf message."""

    # Technique to put into the proto; left out when None or falsy (UNSPECIFIED).
    mean_evaluation_technique: MeanIndexEvaluationTechnique | None
    # Weights to put into the proto; left out when None.
    weights: Collection[float] | None

    @classmethod
    # pylint: disable=unused-argument
    def _from_proto(cls, proto: ProtoMeanCombinationStrategy, sdk: BaseSDK) -> MeanIndexCombinationStrategy:
        """Build the strategy from a ``MeanCombinationStrategy`` message.

        :param proto: message providing ``mean_evaluation_technique`` and ``weights``.
        :param sdk: unused; accepted to match the base-class signature.
        """
        return cls(
            mean_evaluation_technique=MeanIndexEvaluationTechnique._coerce(proto.mean_evaluation_technique),
            weights=tuple(proto.weights)
        )

    def _to_proto(self) -> ProtoCombinationStrategy:
        """Serialize into a ``CombinationStrategy`` with ``mean_combination`` set.

        Only fields that carry a value are passed to the inner message.
        """
        kwargs: dict[str, Any] = {}
        if self.mean_evaluation_technique:
            kwargs['mean_evaluation_technique'] = int(self.mean_evaluation_technique)
        if self.weights is not None:
            # The key must match the proto field name ``weights`` (the same
            # field `_from_proto` reads); it was previously misspelled.
            kwargs['weights'] = tuple(self.weights)

        return ProtoCombinationStrategy(
            mean_combination=ProtoMeanCombinationStrategy(**kwargs)
        )
|
||
|
||
@dataclass(frozen=True)
class ReciprocalRankFusionIndexCombinationStrategy(BaseIndexCombinationStrategy):
    """Combination strategy backed by the ``ReciprocalRankFusionCombinationStrategy`` message."""

    # Optional ``k`` parameter; None means it is not sent in the proto.
    k: int | None = None

    @classmethod
    # pylint: disable=unused-argument
    def _from_proto(
        cls, proto: ProtoReciprocalRankFusionCombinationStrategy, sdk: BaseSDK
    ) -> ReciprocalRankFusionIndexCombinationStrategy:
        """Build the strategy from its protobuf message.

        ``k`` is taken from the wrapped value when the field is present and is
        None otherwise. ``sdk`` is unused.
        """
        k_value = proto.k.value if proto.HasField('k') else None
        return ReciprocalRankFusionIndexCombinationStrategy(k=k_value)

    def _to_proto(self) -> ProtoCombinationStrategy:
        """Serialize into a ``CombinationStrategy`` with ``rrf_combination`` set."""
        if self.k is None:
            # Leave the wrapper field unset rather than sending a value.
            rrf_message = ProtoReciprocalRankFusionCombinationStrategy()
        else:
            rrf_message = ProtoReciprocalRankFusionCombinationStrategy(
                k=Int64Value(value=self.k)
            )
        return ProtoCombinationStrategy(rrf_combination=rrf_message)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.