Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Qdrant RAG Storage #524

Merged
merged 10 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/tutorials/gallery_python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
#
# - [ragna.source_storages.Chroma][]
# - [ragna.source_storages.LanceDB][]
# - [ragna.source_storages.Qdrant][]

# %%
# ## Step 3: Select an assistant
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ all = [
"pymupdf<=1.24.10,>=1.23.6",
"python-docx",
"python-pptx",
"qdrant-client>=1.12.1",
"tiktoken",
]

Expand Down
3 changes: 2 additions & 1 deletion ragna-docker.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ authentication = "ragna.deploy.RagnaDemoAuthentication"
document = "ragna.core.LocalDocument"
source_storages = [
"ragna.source_storages.Chroma",
"ragna.source_storages.LanceDB"
"ragna.source_storages.LanceDB",
"ragna.source_storages.QdrantDB",
]
assistants = [
"ragna.assistants.ClaudeHaiku",
Expand Down
1 change: 1 addition & 0 deletions ragna/source_storages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ._chroma import Chroma
from ._demo import RagnaDemoSourceStorage
from ._lancedb import LanceDB
from ._qdrant import Qdrant

# isort: split

Expand Down
257 changes: 257 additions & 0 deletions ragna/source_storages/_qdrant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
from __future__ import annotations

import os
import uuid
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Optional, cast

import ragna
from ragna.core import (
Document,
MetadataFilter,
MetadataOperator,
PackageRequirement,
Requirement,
Source,
)

from ._utils import raise_no_corpuses_available, raise_non_existing_corpus
from ._vector_database import VectorDatabaseSourceStorage

if TYPE_CHECKING:
from qdrant_client import models


class Qdrant(VectorDatabaseSourceStorage):
    """[Qdrant vector database](https://qdrant.tech/)

    Use the `QDRANT_URL` and `QDRANT_API_KEY` env variables to configure the
    connection to a Qdrant server.

    Eg:
    > export QDRANT_URL="https://xyz-example.eu-central.aws.cloud.qdrant.io:6333"

    > export QDRANT_API_KEY="<your-api-key-here>"

    !!! info "Required packages"

        - `qdrant-client>=1.12.1`
    """

    # Payload key under which the raw chunk text is stored. It is filtered out
    # of list_metadata() together with the dunder-style bookkeeping keys.
    DOC_CONTENT_KEY = "__document"

    @classmethod
    def requirements(cls) -> list[Requirement]:
        return [
            *super().requirements(),
            PackageRequirement("qdrant-client>=1.12.1"),
        ]

    def __init__(self) -> None:
        super().__init__()

        from qdrant_client import QdrantClient

        url = os.getenv("QDRANT_URL")
        api_key = os.getenv("QDRANT_API_KEY")
        path = ragna.local_root() / "qdrant"

        # QdrantClient accepts either a remote url or a local path, but not
        # both. Fall back to an embedded on-disk instance when no url is set.
        self._client = (
            QdrantClient(url=url, api_key=api_key) if url else QdrantClient(path=path)
        )

    def list_corpuses(self) -> list[str]:
        """Return the names of all existing Qdrant collections."""
        return [c.name for c in self._client.get_collections().collections]

    def _ensure_table(self, corpus_name: str, *, create: bool = False) -> None:
        """Ensure that the collection backing ``corpus_name`` exists.

        When ``create`` is true, a missing collection is created with the
        embedding dimensions of this storage and cosine distance. Otherwise,
        a missing collection raises through the shared helper functions.
        """
        table_names = self.list_corpuses()
        no_corpuses = not table_names
        non_existing_corpus = corpus_name not in table_names

        if non_existing_corpus and create:
            from qdrant_client import models

            self._client.create_collection(
                collection_name=corpus_name,
                vectors_config=models.VectorParams(
                    size=self._embedding_dimensions, distance=models.Distance.COSINE
                ),
            )
        elif no_corpuses:
            raise_no_corpuses_available(self)
        elif non_existing_corpus:
            raise_non_existing_corpus(self, corpus_name)

    def list_metadata(
        self, corpus_name: Optional[str] = None
    ) -> dict[str, dict[str, tuple[str, list[Any]]]]:
        """Collect user-facing metadata keys and their values per corpus.

        Internal payload keys (dunder-style bookkeeping keys and the chunk
        content under ``DOC_CONTENT_KEY``) as well as falsy values are skipped.
        Returns ``{corpus: {key: (type_name, sorted_values)}}``.
        """
        if corpus_name is None:
            corpus_names = self.list_corpuses()
        else:
            corpus_names = [corpus_name]

        metadata = {}
        for corpus_name in corpus_names:
            corpus_metadata = defaultdict(set)

            # scroll() is paginated (default page size 10); keep requesting
            # pages until the returned next-page offset is exhausted so that
            # every point in the collection is taken into account.
            offset = None
            while True:
                points, offset = self._client.scroll(
                    collection_name=corpus_name,
                    with_payload=True,
                    offset=offset,
                )

                for point in points:
                    for key, value in point.payload.items():
                        if any(
                            [
                                (key.startswith("__") and key.endswith("__")),
                                key == self.DOC_CONTENT_KEY,
                                not value,
                            ]
                        ):
                            continue

                        corpus_metadata[key].add(value)

                if offset is None:
                    break

            metadata[corpus_name] = {
                key: ({type(value).__name__ for value in values}.pop(), sorted(values))
                for key, values in corpus_metadata.items()
            }

        return metadata

    def store(
        self,
        corpus_name: str,
        documents: list[Document],
        *,
        chunk_size: int = 500,
        chunk_overlap: int = 250,
    ) -> None:
        """Chunk, embed, and upsert ``documents`` into ``corpus_name``.

        The collection is created on demand. Each chunk becomes one point
        whose payload carries the document metadata plus internal keys for
        page numbers, token count, and the raw chunk text.
        """
        from qdrant_client import models

        self._ensure_table(corpus_name, create=True)

        points = []
        for document in documents:
            for chunk in self._chunk_pages(
                document.extract_pages(),
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
            ):
                points.append(
                    models.PointStruct(
                        # Qdrant point ids must be UUIDs or unsigned ints.
                        id=str(uuid.uuid4()),
                        vector=self._embedding_function([chunk.text])[0],
                        payload={
                            "document_id": str(document.id),
                            "document_name": document.name,
                            **document.metadata,
                            "__page_numbers__": self._page_numbers_to_str(
                                chunk.page_numbers
                            ),
                            "__num_tokens__": chunk.num_tokens,
                            self.DOC_CONTENT_KEY: chunk.text,
                        },
                    )
                )

        self._client.upsert(collection_name=corpus_name, points=points)

    def _build_condition(
        self, operator: MetadataOperator, key: str, value: Any
    ) -> models.FieldCondition:
        """Translate a single leaf metadata comparison into a Qdrant condition.

        Raises:
            ValueError: If ``operator`` has no Qdrant equivalent.
        """
        from qdrant_client import models

        # See https://qdrant.tech/documentation/concepts/filtering/#range
        if operator == MetadataOperator.EQ:
            return models.FieldCondition(key=key, match=models.MatchValue(value=value))
        elif operator == MetadataOperator.LT:
            return models.FieldCondition(key=key, range=models.Range(lt=value))
        elif operator == MetadataOperator.LE:
            return models.FieldCondition(key=key, range=models.Range(lte=value))
        elif operator == MetadataOperator.GT:
            return models.FieldCondition(key=key, range=models.Range(gt=value))
        elif operator == MetadataOperator.GE:
            return models.FieldCondition(key=key, range=models.Range(gte=value))
        elif operator == MetadataOperator.IN:
            return models.FieldCondition(key=key, match=models.MatchAny(any=value))
        elif operator in {MetadataOperator.NE, MetadataOperator.NOT_IN}:
            # NE is NOT_IN with a single-element list; "except" is a Python
            # keyword, so MatchExcept has to be constructed via **kwargs.
            except_value = [value] if operator == MetadataOperator.NE else value
            return models.FieldCondition(
                key=key, match=models.MatchExcept(**{"except": except_value})
            )
        else:
            raise ValueError(f"Unsupported operator: {operator}")

    def _translate_metadata_filter(
        self, metadata_filter: MetadataFilter
    ) -> models.Filter | models.FieldCondition:
        """Recursively translate a ragna metadata filter into Qdrant form.

        AND/OR nodes become ``models.Filter`` objects; leaf comparisons become
        ``models.FieldCondition`` objects (wrapped by the caller if needed).
        """
        from qdrant_client import models

        if metadata_filter.operator is MetadataOperator.RAW:
            return cast(models.Filter, metadata_filter.value)
        elif metadata_filter.operator == MetadataOperator.AND:
            return models.Filter(
                must=[
                    self._translate_metadata_filter(child)
                    for child in metadata_filter.value
                ]
            )
        elif metadata_filter.operator == MetadataOperator.OR:
            return models.Filter(
                should=[
                    self._translate_metadata_filter(child)
                    for child in metadata_filter.value
                ]
            )

        return self._build_condition(
            metadata_filter.operator, metadata_filter.key, metadata_filter.value
        )

    def retrieve(
        self,
        corpus_name: str,
        metadata_filter: Optional[MetadataFilter],
        prompt: str,
        *,
        chunk_size: int = 500,
        num_tokens: int = 1024,
    ) -> list[Source]:
        """Return the sources most similar to ``prompt`` from ``corpus_name``.

        The result is truncated so the combined sources stay within
        ``num_tokens`` tokens.
        """
        from qdrant_client import models

        self._ensure_table(corpus_name)

        # We cannot retrieve source by a maximum number of tokens. Thus, we estimate how
        # many sources we have to query. We overestimate by a factor of two to avoid
        # retrieving too few sources and needing to query again.
        limit = int(num_tokens * 2 / chunk_size)

        query_vector = self._embedding_function([prompt])[0]

        search_filter = (
            self._translate_metadata_filter(metadata_filter)
            if metadata_filter
            else None
        )
        # A bare leaf condition must be wrapped in a Filter for query_points.
        if isinstance(search_filter, models.FieldCondition):
            search_filter = models.Filter(must=[search_filter])

        points = self._client.query_points(
            collection_name=corpus_name,
            query=query_vector,
            limit=limit,
            query_filter=search_filter,
            with_payload=True,
        ).points

        return self._take_sources_up_to_max_tokens(
            (
                Source(
                    id=point.id,
                    document_id=point.payload["document_id"],
                    document_name=point.payload["document_name"],
                    location=point.payload["__page_numbers__"],
                    content=point.payload[self.DOC_CONTENT_KEY],
                    num_tokens=point.payload["__num_tokens__"],
                )
                for point in points
            ),
            max_tokens=num_tokens,
        )
4 changes: 2 additions & 2 deletions tests/source_storages/test_source_storages.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
PlainTextDocumentHandler,
RagnaException,
)
from ragna.source_storages import Chroma, LanceDB, RagnaDemoSourceStorage
from ragna.source_storages import Chroma, LanceDB, Qdrant, RagnaDemoSourceStorage

SOURCE_STORAGES = [Chroma, LanceDB, RagnaDemoSourceStorage]
SOURCE_STORAGES = [Chroma, LanceDB, Qdrant, RagnaDemoSourceStorage]

METADATAS = {
0: {"key": "value"},
Expand Down