Skip to content

Commit

Permalink
Add hybrid search index (#46)
Browse files Browse the repository at this point in the history
  • Loading branch information
vhaldemar authored Dec 18, 2024
1 parent c9be1bb commit 48c43d5
Show file tree
Hide file tree
Showing 10 changed files with 964 additions and 38 deletions.
78 changes: 78 additions & 0 deletions examples/async/assistants/hybrid_search_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

from __future__ import annotations

import asyncio
import pathlib
import pprint

from yandex_cloud_ml_sdk import AsyncYCloudML
from yandex_cloud_ml_sdk.search_indexes import (
HybridSearchIndexType, ReciprocalRankFusionIndexCombinationStrategy, StaticIndexChunkingStrategy,
TextSearchIndexType, VectorSearchIndexType
)


def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


async def main() -> None:
sdk = AsyncYCloudML(folder_id='b1ghsjum2v37c2un8h64')

file_coros = (
sdk.files.upload(
local_path(path),
ttl_days=5,
expiration_policy="static",
)
for path in ['turkey_example.txt', 'maldives_example.txt']
)
files = await asyncio.gather(*file_coros)

# How to create search index with all default settings:
operation = await sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType()
)
default_search_index = await operation.wait()
print("new hybrid search index with default settings:")
pprint.pprint(default_search_index)

# But you could override any default:
operation = await sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType(
chunking_strategy=StaticIndexChunkingStrategy(
max_chunk_size_tokens=700,
chunk_overlap_tokens=300
),
# you could also override some text/vector indexes settings
text_search_index=TextSearchIndexType(),
vector_search_index=VectorSearchIndexType(),
normalization_strategy='L2',
# you don't really want to change `k` parameter if you don't
# really know what you are doing
combination_strategy=ReciprocalRankFusionIndexCombinationStrategy(
k=60
)
)
)
search_index = await operation.wait()
print("new hybrid search index with overridden settings:")
pprint.pprint(search_index)

# And how to use your index you could learn in example file "assistant_with_search_index.py".
# Working with hybrid index does not differ from working with any other index besides creation.

# Created resources cleanup:
for file in files:
await file.delete()

for search_index in [default_search_index, search_index]:
print(f"delete {search_index.id=}")
await search_index.delete()


if __name__ == '__main__':
asyncio.run(main())
76 changes: 76 additions & 0 deletions examples/sync/assistants/hybrid_search_index.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python3

from __future__ import annotations

import pathlib
import pprint

from yandex_cloud_ml_sdk import YCloudML
from yandex_cloud_ml_sdk.search_indexes import (
HybridSearchIndexType, ReciprocalRankFusionIndexCombinationStrategy, StaticIndexChunkingStrategy,
TextSearchIndexType, VectorSearchIndexType
)


def local_path(path: str) -> pathlib.Path:
return pathlib.Path(__file__).parent / path


def main() -> None:
sdk = YCloudML(folder_id='b1ghsjum2v37c2un8h64')

files = []
for path in ['turkey_example.txt', 'maldives_example.txt']:
file = sdk.files.upload(
local_path(path),
ttl_days=5,
expiration_policy="static",
)
files.append(file)

# How to create search index with all default settings:
operation = sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType()
)
default_search_index = operation.wait()
print("new hybrid search index with default settings:")
pprint.pprint(default_search_index)

# But you could override any default:
operation = sdk.search_indexes.create_deferred(
files,
index_type=HybridSearchIndexType(
chunking_strategy=StaticIndexChunkingStrategy(
max_chunk_size_tokens=700,
chunk_overlap_tokens=300
),
# you could also override some text/vector indexes settings
text_search_index=TextSearchIndexType(),
vector_search_index=VectorSearchIndexType(),
normalization_strategy='L2',
# you don't really want to change `k` parameter if you don't
# really know what you are doing
combination_strategy=ReciprocalRankFusionIndexCombinationStrategy(
k=60
)
)
)
search_index = operation.wait()
print("new hybrid search index with overridden settings:")
pprint.pprint(search_index)

# And how to use your index you could learn in example file "assistant_with_search_index.py".
# Working with hybrid index does not differ from working with any other index besides creation.

# Created resources cleanup:
for file in files:
file.delete()

for search_index in [default_search_index, search_index]:
print(f"delete {search_index.id=}")
search_index.delete()


if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ exclude-protected= [
"_to_proto",
"_result_type",
"_proto_result_type",
"_proto_field_name",
"_coerce",
]
valid-classmethod-first-arg="cls"
valid-metaclass-classmethod-first-arg="cls"
Expand Down
109 changes: 109 additions & 0 deletions src/yandex_cloud_ml_sdk/_search_indexes/combination_strategy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# pylint: disable=no-name-in-module,protected-access
from __future__ import annotations

import abc
import enum
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Collection

from google.protobuf.wrappers_pb2 import Int64Value
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import CombinationStrategy as ProtoCombinationStrategy
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import MeanCombinationStrategy as ProtoMeanCombinationStrategy
from yandex.cloud.ai.assistants.v1.searchindex.common_pb2 import (
ReciprocalRankFusionCombinationStrategy as ProtoReciprocalRankFusionCombinationStrategy
)

if TYPE_CHECKING:
from yandex_cloud_ml_sdk._sdk import BaseSDK


class BaseIndexCombinationStrategy(abc.ABC):
@classmethod
@abc.abstractmethod
def _from_proto(cls, proto: Any, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
pass

@abc.abstractmethod
def _to_proto(self) -> ProtoCombinationStrategy:
pass

@classmethod
def _from_upper_proto(cls, proto: ProtoCombinationStrategy, sdk: BaseSDK) -> BaseIndexCombinationStrategy:
if proto.HasField('mean_combination'):
return MeanIndexCombinationStrategy._from_proto(
proto=proto.mean_combination,
sdk=sdk
)
if proto.HasField('rrf_combination'):
return ReciprocalRankFusionIndexCombinationStrategy._from_proto(
proto=proto.rrf_combination,
sdk=sdk
)
raise NotImplementedError(
'combination strategies other then Mean and RRF are not supported in this SDK version'
)


_orig = ProtoMeanCombinationStrategy.MeanEvaluationTechnique

class MeanIndexEvaluationTechnique(enum.IntEnum):
MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED = _orig.MEAN_EVALUATION_TECHNIQUE_UNSPECIFIED
ARITHMETIC = _orig.ARITHMETIC
GEOMETRIC = _orig.GEOMETRIC
HARMONIC = _orig.HARMONIC

@classmethod
def _coerce(cls, technique: str | int ) -> MeanIndexEvaluationTechnique:
if isinstance(technique, str):
technique = _orig.Value(technique.upper())
return cls(technique)


@dataclass(frozen=True)
class MeanIndexCombinationStrategy(BaseIndexCombinationStrategy):
mean_evaluation_technique: MeanIndexEvaluationTechnique | None
weights: Collection[float] | None

@classmethod
# pylint: disable=unused-argument
def _from_proto(cls, proto: ProtoMeanCombinationStrategy, sdk: BaseSDK) -> MeanIndexCombinationStrategy:
return cls(
mean_evaluation_technique=MeanIndexEvaluationTechnique._coerce(proto.mean_evaluation_technique),
weights=tuple(proto.weights)
)

def _to_proto(self) -> ProtoCombinationStrategy:
kwargs: dict[str, Any] = {}
if self.mean_evaluation_technique:
kwargs['mean_evaluation_technique'] = int(self.mean_evaluation_technique)
if self.weights is not None:
kwargs['weghts'] = tuple(self.weights)

return ProtoCombinationStrategy(
mean_combination=ProtoMeanCombinationStrategy(**kwargs)
)


@dataclass(frozen=True)
class ReciprocalRankFusionIndexCombinationStrategy(BaseIndexCombinationStrategy):
k: int | None = None

@classmethod
# pylint: disable=unused-argument
def _from_proto(
cls, proto: ProtoReciprocalRankFusionCombinationStrategy, sdk: BaseSDK
) -> ReciprocalRankFusionIndexCombinationStrategy:
kwargs = {}
if proto.HasField('k'):
kwargs['k'] = proto.k.value
return ReciprocalRankFusionIndexCombinationStrategy(
**kwargs
)

def _to_proto(self) -> ProtoCombinationStrategy:
kwargs = {}
if self.k is not None:
kwargs['k'] = Int64Value(value=self.k)
return ProtoCombinationStrategy(
rrf_combination=ProtoReciprocalRankFusionCombinationStrategy(**kwargs)
)
19 changes: 7 additions & 12 deletions src/yandex_cloud_ml_sdk/_search_indexes/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import AsyncIterator, Generic, Iterator

from yandex.cloud.ai.assistants.v1.searchindex.search_index_pb2 import SearchIndex as ProtoSearchIndex
from yandex.cloud.ai.assistants.v1.searchindex.search_index_pb2 import TextSearchIndex, VectorSearchIndex
from yandex.cloud.ai.assistants.v1.searchindex.search_index_service_pb2 import (
CreateSearchIndexRequest, GetSearchIndexRequest, ListSearchIndicesRequest, ListSearchIndicesResponse
)
Expand All @@ -19,7 +18,7 @@
from yandex_cloud_ml_sdk._utils.coerce import ResourceType, coerce_resource_ids
from yandex_cloud_ml_sdk._utils.sync import run_sync, run_sync_generator

from .index_type import BaseSearchIndexType, TextSearchIndexType, VectorSearchIndexType
from .index_type import BaseSearchIndexType
from .search_index import AsyncSearchIndex, SearchIndex, SearchIndexTypeT


Expand Down Expand Up @@ -47,14 +46,11 @@ async def _create_deferred(

expiration_config = ExpirationConfig.coerce(ttl_days=ttl_days, expiration_policy=expiration_policy)

vector_search_index: VectorSearchIndex | None = None
text_search_index: TextSearchIndex | None = None
if isinstance(index_type, VectorSearchIndexType):
vector_search_index = index_type._to_proto()
elif isinstance(index_type, TextSearchIndexType):
text_search_index = index_type._to_proto()
elif is_defined(index_type):
raise TypeError('index type must be instance of SearchIndexType')
kwargs = {}
if is_defined(index_type):
if not isinstance(index_type, BaseSearchIndexType):
raise TypeError('index type must be instance of BaseSearchIndexType')
kwargs[index_type._proto_field_name] = index_type._to_proto()

request = CreateSearchIndexRequest(
folder_id=self._folder_id,
Expand All @@ -63,8 +59,7 @@ async def _create_deferred(
description=get_defined_value(description, ''),
labels=get_defined_value(labels, {}),
expiration_config=expiration_config.to_proto(),
vector_search_index=vector_search_index,
text_search_index=text_search_index,
**kwargs, # type: ignore[arg-type]
)

async with self._client.get_service_stub(SearchIndexServiceStub, timeout=timeout) as stub:
Expand Down
Loading

0 comments on commit 48c43d5

Please sign in to comment.