Python: introducing Vector Search for Qdrant Collection (#9621)
### Motivation and Context

Adds vector search functions to the existing QdrantCollection. Currently only VectorizedSearch is supported.
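
A minimal sketch of the new call shape, distilled from the updated `new_memory.py` sample in this diff (`record_collection` and `embedder` are assumed to be set up as in that sample):

```python
from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions


async def search_demo(record_collection, embedder) -> None:
    # Embed the query text, then search the collection by vector.
    vector = (await embedder.generate_raw_embeddings(["python"]))[0]
    search_results = await record_collection.vectorized_search(
        vector=vector,
        options=VectorSearchOptions(vector_field_name="vector", include_vectors=True),
    )
    async for result in search_results.results:
        print(result.record.id, result.score)
```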

### Description


### Contribution Checklist


- [x] The code builds clean without any errors or warnings
- [x] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [x] All unit tests pass, and I have added new tests where possible
- [x] I didn't break anyone 😄
eavanvalkenburg authored Nov 12, 2024
1 parent 5764c8c commit 02bb98e
Showing 10 changed files with 163 additions and 33 deletions.
3 changes: 2 additions & 1 deletion python/.cspell.json
@@ -73,6 +73,7 @@
"opentelemetry",
"SEMANTICKERNEL",
"OTEL",
- "vectorizable"
+ "vectorizable",
+ "desync"
]
}
54 changes: 41 additions & 13 deletions python/samples/concepts/memory/new_memory.py
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.

import argparse
+ import asyncio
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Annotated
@@ -29,6 +30,8 @@
vectorstoremodel,
)
from semantic_kernel.data.const import DistanceFunction, IndexKind
+ from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+ from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin


def get_data_model_array(index_kind: IndexKind, distance_function: DistanceFunction) -> type:
@@ -82,7 +85,7 @@ class DataModelList:
collection_name = "test"
# Depending on the vector database, the index kind and distance function may need to be adjusted,
# since not all combinations are supported by all databases.
- DataModel = get_data_model_array(IndexKind.HNSW, DistanceFunction.COSINE)
+ DataModel = get_data_model_array(IndexKind.HNSW, DistanceFunction.COSINE_SIMILARITY)

# A list of VectorStoreRecordCollection that can be used.
# Available collections are:
@@ -144,38 +147,63 @@ class DataModelList:


async def main(collection: str, use_azure_openai: bool, embedding_model: str):
+ print("-" * 30)
kernel = Kernel()
service_id = "embedding"
if use_azure_openai:
- kernel.add_service(AzureTextEmbedding(service_id=service_id, deployment_name=embedding_model))
+ embedder = AzureTextEmbedding(service_id=service_id, deployment_name=embedding_model)
else:
- kernel.add_service(OpenAITextEmbedding(service_id=service_id, ai_model_id=embedding_model))
+ embedder = OpenAITextEmbedding(service_id=service_id, ai_model_id=embedding_model)
+ kernel.add_service(embedder)
async with collections[collection]() as record_collection:
print(f"Creating {collection} collection!")
await record_collection.create_collection_if_not_exists()

- record1 = DataModel(content="My text", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
- record2 = DataModel(content="My other text", id="09caec77-f7e1-466a-bcec-f1d51c5b15be")
+ record1 = DataModel(content="Semantic Kernel is awesome", id="e6103c03-487f-4d7d-9c23-4723651c17f4")
+ record2 = DataModel(
+ content="Semantic Kernel is available in dotnet, python and Java.",
+ id="09caec77-f7e1-466a-bcec-f1d51c5b15be",
+ )

print("Adding records!")
records = await VectorStoreRecordUtils(kernel).add_vector_to_records(
[record1, record2], data_model_type=DataModel
)
keys = await record_collection.upsert_batch(records)
- print(f"upserted {keys=}")
-
+ print(f" Upserted {keys=}")
print("Getting records!")
results = await record_collection.get_batch([record1.id, record2.id])
if results:
for result in results:
- print(f"found {result.id=}")
- print(f"{result.content=}")
+ print(f" Found id: {result.id}")
+ print(f" Content: {result.content}")
if result.vector is not None:
- print(f"{result.vector[:5]=}")
+ print(f" Vector (first five): {result.vector[:5]}")
else:
- print("not found")
+ print("Nothing found...")
if isinstance(record_collection, VectorizedSearchMixin):
print("-" * 30)
print("Using vectorized search, the distance function is set to cosine_similarity.")
print("This means that the higher the score the more similar.")
search_results = await record_collection.vectorized_search(
vector=(await embedder.generate_raw_embeddings(["python"]))[0],
options=VectorSearchOptions(vector_field_name="vector", include_vectors=True),
)
results = [record async for record in search_results.results]
for result in results:
print(f" Found id: {result.record.id}")
print(f" Content: {result.record.content}")
if result.record.vector is not None:
print(f" Vector (first five): {result.record.vector[:5]}")
print(f" Score: {result.score:.4f}")
print("")
print("-" * 30)
print("Deleting collection!")
await record_collection.delete_collection()
print("Done!")


if __name__ == "__main__":
- import asyncio

argparse.ArgumentParser()

parser = argparse.ArgumentParser()
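The sample's closing note — with `COSINE_SIMILARITY`, a higher score means a closer match — can be sanity-checked with plain Python (this helper is illustrative, not part of the diff):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    # 1.0 = identical direction (most similar), 0.0 = orthogonal (unrelated).
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


print(cosine_similarity([1.0, 0.0], [1.0, 0.0]))  # 1.0 -> most similar
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # 0.0 -> dissimilar
```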
@@ -282,7 +282,7 @@ async def _inner_search(
]
raw_results = await self.search_client.search(**search_args)
return KernelSearchResults(
- results=self._get_vector_search_results_from_results(raw_results),
+ results=self._get_vector_search_results_from_results(raw_results, options),
total_count=await raw_results.get_count() if options.include_total_count else None,
)

@@ -132,7 +132,7 @@ async def _inner_search_text(
if return_records:
return KernelSearchResults(
results=self._get_vector_search_results_from_results(
- self._generate_return_list(return_records, options)
+ self._generate_return_list(return_records, options), options
),
total_count=len(return_records) if options and options.include_total_count else None,
)
@@ -167,7 +167,7 @@ async def _inner_search_vectorized(
if sorted_records:
return KernelSearchResults(
results=self._get_vector_search_results_from_results(
- self._generate_return_list(sorted_records, options)
+ self._generate_return_list(sorted_records, options), options
),
total_count=len(return_records) if options and options.include_total_count else None,
)
python/semantic_kernel/connectors/memory/qdrant/qdrant_collection.py
@@ -3,7 +3,7 @@
import logging
import sys
from collections.abc import Mapping, Sequence
- from typing import Any, ClassVar, TypeVar
+ from typing import Any, ClassVar, Generic, TypeVar

if sys.version_info >= (3, 12):
from typing import override # pragma: no cover
@@ -12,17 +12,22 @@

from pydantic import ValidationError
from qdrant_client.async_qdrant_client import AsyncQdrantClient
- from qdrant_client.models import PointStruct, VectorParams
+ from qdrant_client.models import FieldCondition, Filter, MatchAny, PointStruct, QueryResponse, ScoredPoint, VectorParams

from semantic_kernel.connectors.memory.qdrant.const import DISTANCE_FUNCTION_MAP, TYPE_MAPPER_VECTOR
from semantic_kernel.connectors.memory.qdrant.utils import AsyncQdrantClientWrapper
+ from semantic_kernel.data.kernel_search_results import KernelSearchResults
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition, VectorStoreRecordVectorField
- from semantic_kernel.data.vector_storage import VectorStoreRecordCollection
+ from semantic_kernel.data.vector_search.vector_search import VectorSearchBase
+ from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
+ from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
+ from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin
from semantic_kernel.exceptions import (
MemoryConnectorInitializationError,
VectorStoreModelValidationError,
)
from semantic_kernel.exceptions.memory_connector_exceptions import MemoryConnectorException
+ from semantic_kernel.exceptions.search_exceptions import VectorSearchExecutionException
from semantic_kernel.kernel_types import OneOrMany
from semantic_kernel.utils.experimental_decorator import experimental_class
from semantic_kernel.utils.telemetry.user_agent import APP_INFO, prepend_semantic_kernel_to_user_agent
@@ -34,7 +39,11 @@


@experimental_class
- class QdrantCollection(VectorStoreRecordCollection[str | int, TModel]):
+ class QdrantCollection(
+ VectorSearchBase[str | int, TModel],
+ VectorizedSearchMixin[TModel],
+ Generic[TModel],
+ ):
"""A QdrantCollection is a memory collection that uses Qdrant as the backend."""

qdrant_client: AsyncQdrantClient
@@ -163,6 +172,53 @@ async def _inner_delete(self, keys: Sequence[TKey], **kwargs: Any) -> None:
**kwargs,
)

@override
async def _inner_search(
self,
options: VectorSearchOptions,
search_text: str | None = None,
vectorizable_text: str | None = None,
vector: list[float | int] | None = None,
**kwargs: Any,
) -> KernelSearchResults[VectorSearchResult[TModel]]:
query_vector: tuple[str, list[float | int]] | list[float | int] | None = None
if vector is not None:
if self.named_vectors and options.vector_field_name:
query_vector = (options.vector_field_name, vector)
else:
query_vector = vector
if query_vector is None:
raise VectorSearchExecutionException("Search requires a vector.")
results = await self.qdrant_client.search(
collection_name=self.collection_name,
query_vector=query_vector,
query_filter=self._create_filter(options),
with_vectors=options.include_vectors,
limit=options.top,
offset=options.skip,
**kwargs,
)
return KernelSearchResults(
results=self._get_vector_search_results_from_results(results, options),
total_count=len(results) if options.include_total_count else None,
)

@override
def _get_record_from_result(self, result: ScoredPoint | QueryResponse) -> Any:
return result

@override
def _get_score_from_result(self, result: ScoredPoint) -> float:
return result.score

def _create_filter(self, options: VectorSearchOptions) -> Filter:
return Filter(
must=[
FieldCondition(key=filter.field_name, match=MatchAny(any=filter.value))
for filter in options.filter.filters
]
)

@override
def _serialize_dicts_to_store_models(
self,
@@ -183,15 +239,17 @@ def _serialize_dicts_to_store_models(
@override
def _deserialize_store_models_to_dicts(
self,
- records: Sequence[PointStruct],
+ records: Sequence[PointStruct] | Sequence[ScoredPoint],
**kwargs: Any,
) -> Sequence[dict[str, Any]]:
return [
{
self._key_field_name: record.id,
**(record.payload if record.payload else {}),
**(
- record.vector
+ {}
+ if not record.vector
+ else record.vector
if isinstance(record.vector, dict)
else {self.data_model_definition.vector_field_names[0]: record.vector}
),
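For reference, `_inner_search` above passes Qdrant a `(field_name, vector)` tuple when named vectors are enabled (a bare list otherwise), and `_create_filter` ANDs one `FieldCondition` per filter clause. A standalone sketch of the shapes involved — field names and values here are made up:

```python
from qdrant_client.models import FieldCondition, Filter, MatchAny

# Named vectors on: (vector_field_name, vector); off: just the list of floats.
query_vector = ("vector", [0.1, 0.2, 0.3])

# One FieldCondition per clause, combined with `must` (logical AND).
query_filter = Filter(must=[FieldCondition(key="content", match=MatchAny(any=["python"]))])
```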
11 changes: 8 additions & 3 deletions python/semantic_kernel/data/vector_search/vector_search.py
@@ -2,7 +2,7 @@

import logging
from abc import abstractmethod
- from collections.abc import AsyncIterable
+ from collections.abc import AsyncIterable, Sequence
from typing import Any, Generic, TypeVar

from semantic_kernel.data.kernel_search_results import KernelSearchResults
@@ -11,6 +11,7 @@
from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult
from semantic_kernel.data.vector_storage.vector_store_record_collection import VectorStoreRecordCollection
from semantic_kernel.utils.experimental_decorator import experimental_class
+ from semantic_kernel.utils.list_handler import desync_list

TModel = TypeVar("TModel")
TKey = TypeVar("TKey")
@@ -100,10 +101,14 @@ def _get_score_from_result(self, result: Any) -> float | None:
# region: New methods

async def _get_vector_search_results_from_results(
- self, results: AsyncIterable[Any]
+ self, results: AsyncIterable[Any] | Sequence[Any], options: VectorSearchOptions | None = None
) -> AsyncIterable[VectorSearchResult[TModel]]:
+ if isinstance(results, Sequence):
+ results = desync_list(results)
async for result in results:
- record = self.deserialize(self._get_record_from_result(result))
+ record = self.deserialize(
+ self._get_record_from_result(result), include_vectors=options.include_vectors if options else True
+ )
score = self._get_score_from_result(result)
if record:
# single records are always returned as single records by the deserializer
python/semantic_kernel/data/vector_storage/vector_store_record_collection.py
@@ -529,6 +529,7 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k
The input of this should come from the _deserialized_store_model_to_dict function.
"""
+ include_vectors = kwargs.get("include_vectors", True)
if self.data_model_definition.from_dict:
if isinstance(record, Sequence):
return self.data_model_definition.from_dict(record, **kwargs)
@@ -544,24 +545,28 @@ def _deserialize_dict_to_data_model(self, record: OneOrMany[dict[str, Any]], **k
try:
if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields):
return self.data_model_type.model_validate(record) # type: ignore
- for field in self.data_model_definition.vector_fields:
- if field.serialize_function:
- record[field.name] = field.serialize_function(record[field.name]) # type: ignore
+ if include_vectors:
+ for field in self.data_model_definition.vector_fields:
+ if field.serialize_function:
+ record[field.name] = field.serialize_function(record[field.name]) # type: ignore
return self.data_model_type.model_validate(record) # type: ignore
except Exception as exc:
raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc
if hasattr(self.data_model_type, "from_dict"):
try:
if not any(field.serialize_function is not None for field in self.data_model_definition.vector_fields):
return self.data_model_type.from_dict(record) # type: ignore
- for field in self.data_model_definition.vector_fields:
- if field.serialize_function:
- record[field.name] = field.serialize_function(record[field.name]) # type: ignore
+ if include_vectors:
+ for field in self.data_model_definition.vector_fields:
+ if field.serialize_function:
+ record[field.name] = field.serialize_function(record[field.name]) # type: ignore
return self.data_model_type.from_dict(record) # type: ignore
except Exception as exc:
raise VectorStoreModelDeserializationException(f"Error deserializing record: {exc}") from exc
data_model_dict: dict[str, Any] = {}
for field_name in self.data_model_definition.fields: # type: ignore
+ if not include_vectors and field_name in self.data_model_definition.vector_field_names:
+ continue
try:
value = record[field_name]
if func := getattr(self.data_model_definition.fields[field_name], "deserialize_function", None):
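The net effect of the new `include_vectors` gate is that vector fields are simply left off deserialized records unless the caller asked for them. A self-contained sketch of that behavior (helper name invented for illustration):

```python
from typing import Any


def strip_vectors(record: dict[str, Any], vector_field_names: list[str]) -> dict[str, Any]:
    # Mirror of include_vectors=False: drop vector fields, keep everything else.
    return {key: value for key, value in record.items() if key not in vector_field_names}


print(strip_vectors({"id": "1", "content": "hi", "vector": [0.1, 0.2]}, ["vector"]))
# {'id': '1', 'content': 'hi'}
```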
13 changes: 13 additions & 0 deletions python/semantic_kernel/utils/list_handler.py
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.


from collections.abc import AsyncIterable, Sequence
from typing import TypeVar

_T = TypeVar("_T")


async def desync_list(sync_list: Sequence[_T]) -> AsyncIterable[_T]: # noqa: RUF029
"""Turn a synchronous list into an asynchronous iterable."""
for x in sync_list:
yield x
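
This helper lets the search plumbing treat in-memory sequences and true async iterators uniformly. A quick usage demo (import path as added in this diff):

```python
import asyncio

from semantic_kernel.utils.list_handler import desync_list


async def main() -> None:
    # Consume a plain list through the async-iteration protocol.
    async for item in desync_list([1, 2, 3]):
        print(item)


asyncio.run(main())
```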
@@ -120,7 +120,7 @@ async def test_vectorized_search_similar(collection, distance_function):
await collection.upsert_batch([record1, record2])
results = await collection.vectorized_search(
vector=[0.9, 0.9, 0.9, 0.9, 0.9],
- options=VectorSearchOptions(vector_field_name="vector", include_total_count=True),
+ options=VectorSearchOptions(vector_field_name="vector", include_total_count=True, include_vectors=True),
)
assert results.total_count == 2
idx = 0
20 changes: 20 additions & 0 deletions python/tests/unit/connectors/memory/qdrant/test_qdrant.py
@@ -9,6 +9,7 @@
from semantic_kernel.connectors.memory.qdrant.qdrant_collection import QdrantCollection
from semantic_kernel.connectors.memory.qdrant.qdrant_store import QdrantStore
from semantic_kernel.data.record_definition.vector_store_record_fields import VectorStoreRecordVectorField
+ from semantic_kernel.data.vector_search.vector_search_options import VectorSearchOptions
from semantic_kernel.exceptions.memory_connector_exceptions import (
MemoryConnectorException,
MemoryConnectorInitializationError,
@@ -107,6 +108,17 @@ def mock_delete():
yield mock_delete


@fixture(autouse=True)
def mock_search():
with patch(f"{BASE_PATH}.search") as mock_search:
from qdrant_client.models import ScoredPoint

response1 = ScoredPoint(id="id1", version=1, score=0.0, payload={"content": "content"})
response2 = ScoredPoint(id="id2", version=1, score=0.0, payload={"content": "content"})
mock_search.return_value = [response1, response2]
yield mock_search


def test_vector_store_defaults(vector_store):
assert vector_store.qdrant_client is not None
assert vector_store.qdrant_client._client.rest_uri == "http://localhost:6333"
@@ -269,3 +281,11 @@ async def test_create_index_fail(collection_to_use, request):
collection.data_model_definition.fields["vector"].dimensions = None
with raises(MemoryConnectorException, match="Vector field must have dimensions."):
await collection.create_collection()


@mark.asyncio
async def test_search(collection):
results = await collection._inner_search(vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False))
async for result in results.results:
assert result.record["id"] == "id1"
break
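
For completeness, the same assertion could go through the public mixin method instead of the private `_inner_search`; an untested sketch using the same fixtures:

```python
@mark.asyncio
async def test_search_public(collection):
    # vectorized_search is the public entry point added by VectorizedSearchMixin.
    results = await collection.vectorized_search(
        vector=[1.0, 2.0, 3.0], options=VectorSearchOptions(include_vectors=False)
    )
    async for result in results.results:
        assert result.record["id"] == "id1"
        break
```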
