feat: Qdrant Vector Database #81

Open · wants to merge 2 commits into `main`

14 changes: 14 additions & 0 deletions README.md
@@ -83,6 +83,7 @@ The following environment variables are required to run the application:
- `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1`
- `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings
- `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings
- `QDRANT_API_KEY`: (Optional) API key; not needed for a local Docker container

Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables.

@@ -118,6 +119,19 @@ The `ATLAS_MONGO_DB_URI` could be the same or different from what is used by Lib

Follow one of the [four documented methods](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure) to create the vector index.

### Use Qdrant as Vector Database

Additionally, you can use [Qdrant](https://qdrant.tech/documentation/) as the vector database. To do so, set the following environment variables:

```env
VECTOR_DB_TYPE=qdrant
DB_HOST=<host address>
DB_PORT=<port number>
COLLECTION_NAME=<vector collection>
QDRANT_API_KEY=<api key>
```

To run Qdrant locally, download the Docker image with `docker pull qdrant/qdrant` and start a container on port 6333 with `docker run -p 6333:6333 qdrant/qdrant`, then set `DB_HOST=localhost` and `DB_PORT=6333`. A collection named `COLLECTION_NAME` is created automatically if it does not already exist. `QDRANT_API_KEY` is not necessary for a local container.
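
Before starting the service, you can confirm the container is reachable. The following is a minimal sketch using `qdrant-client`; the URL and collection name are assumptions matching the defaults above:

```python
from qdrant_client import QdrantClient

# Assumes a local container started as described above;
# pass api_key=... for a secured instance.
client = QdrantClient(url="localhost:6333")

collection_name = "vector_collection"  # must match COLLECTION_NAME
if client.collection_exists(collection_name):
    info = client.get_collection(collection_name)
    print(f"'{collection_name}' exists with {info.points_count} points")
else:
    print(f"'{collection_name}' does not exist yet; it will be created on first run")
```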

### Cloud Installation Settings:

19 changes: 17 additions & 2 deletions config.py
@@ -15,6 +15,7 @@
class VectorDBType(Enum):
PGVECTOR = "pgvector"
ATLAS_MONGO = "atlas-mongo"
QDRANT = "qdrant"


class EmbeddingsProvider(Enum):
@@ -60,13 +61,16 @@ def get_env_variable(
MONGO_VECTOR_COLLECTION = get_env_variable(
"MONGO_VECTOR_COLLECTION", None
) # Deprecated, backwards compatability
COLLECTION_NAME = get_env_variable("COLLECTION_NAME", "vector_collection")
QDRANT_API_KEY = get_env_variable("QDRANT_API_KEY", None)
CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100"))

env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower()
PDF_EXTRACT_IMAGES = True if env_value == "true" else False

CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"
DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"

## Logging
@@ -255,11 +259,12 @@ def init_embeddings(provider, model):

# Vector store
if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}"
vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="async",
mode="async-PGVector",
)
elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO:
# Backward compatability check
@@ -276,6 +281,16 @@ def init_embeddings(provider, model):
mode="atlas-mongo",
search_index=ATLAS_SEARCH_INDEX,
)
elif VECTOR_DB_TYPE == VectorDBType.QDRANT:
CONNECTION_STRING = f"{DB_HOST}:{DB_PORT}"
vector_store = get_vector_store(
connection_string=CONNECTION_STRING,
embeddings=embeddings,
collection_name=COLLECTION_NAME,
mode="qdrant",
api_key=QDRANT_API_KEY

)
else:
raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}")

26 changes: 19 additions & 7 deletions main.py
@@ -49,7 +49,8 @@
from middleware import security_middleware
from mongo import mongo_health_check
from constants import ERROR_MESSAGES
from store import AsyncPgVector
from store import AsyncPgVector, AsyncQdrant
from store_factory import async_DB

load_dotenv(find_dotenv())

@@ -105,7 +106,7 @@ async def lifespan(app: FastAPI):
@app.get("/ids")
async def get_all_ids():
try:
if isinstance(vector_store, AsyncPgVector):
if isinstance(vector_store, async_DB):
ids = await vector_store.get_all_ids()
else:
ids = vector_store.get_all_ids()
@@ -118,7 +119,7 @@ async def get_all_ids():
def isHealthOK():
if VECTOR_DB_TYPE == VectorDBType.PGVECTOR:
return pg_health_check()
if VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO:
elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO:
return mongo_health_check()
else:
return True
@@ -135,7 +136,7 @@ async def health_check():
@app.get("/documents", response_model=list[DocumentResponse])
async def get_documents_by_ids(ids: list[str] = Query(...)):
try:
if isinstance(vector_store, AsyncPgVector):
if isinstance(vector_store, async_DB):
existing_ids = await vector_store.get_all_ids()
documents = await vector_store.get_documents_by_ids(ids)
else:
@@ -162,7 +163,7 @@ async def get_documents_by_ids(ids: list[str] = Query(...)):
@app.delete("/documents")
async def delete_documents(document_ids: List[str] = Body(...)):
try:
if isinstance(vector_store, AsyncPgVector):
if isinstance(vector_store, async_DB):
existing_ids = await vector_store.get_all_ids()
await vector_store.delete(ids=document_ids)
else:
@@ -198,6 +199,12 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
                k=body.k,
                filter={"file_id": body.file_id},
            )
        elif isinstance(vector_store, AsyncQdrant):
            documents = await vector_store.similarity_search_many(
                query=body.query,
                k=body.k,
                ids=[body.file_id],
            )
        else:
            documents = vector_store.similarity_search_with_score_by_vector(
                embedding, k=body.k, filter={"file_id": body.file_id}
@@ -260,7 +267,7 @@ async def store_data_in_vector_db(
]

try:
if isinstance(vector_store, AsyncPgVector):
if isinstance(vector_store, async_DB):
ids = await vector_store.aadd_documents(
docs, ids=[file_id] * len(documents)
)
@@ -443,7 +450,7 @@ async def embed_file(
async def load_document_context(id: str):
ids = [id]
try:
if isinstance(vector_store, AsyncPgVector):
if isinstance(vector_store, async_DB):
existing_ids = await vector_store.get_all_ids()
documents = await vector_store.get_documents_by_ids(ids)
else:
@@ -536,6 +543,11 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
                k=body.k,
                filter={"file_id": {"$in": body.file_ids}},
            )
        elif isinstance(vector_store, AsyncQdrant):
            documents = await vector_store.similarity_search_many(
                query=body.query,
                k=body.k,
                ids=body.file_ids,
            )
        else:
            documents = vector_store.similarity_search_with_score_by_vector(
                embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}}
1 change: 1 addition & 0 deletions requirements.lite.txt
@@ -30,3 +30,4 @@ python-pptx==0.6.23
xlrd==2.0.1
langchain-aws==0.2.1
boto3==1.34.144
qdrant-client==1.11.2
3 changes: 2 additions & 1 deletion requirements.txt
@@ -33,4 +33,5 @@ langchain-huggingface==0.1.0
cryptography==42.0.7
python-magic==0.4.27
python-pptx==0.6.23
xlrd==2.0.1
xlrd==2.0.1
qdrant-client==1.11.2
109 changes: 108 additions & 1 deletion store.py
@@ -1,9 +1,13 @@
from typing import Any, Optional
from sqlalchemy import delete
from sqlalchemy import delete, func
from langchain_community.vectorstores.pgvector import PGVector
from langchain_core.documents import Document
from langchain_core.runnables.config import run_in_executor
from sqlalchemy.orm import Session
import qdrant_client as client
from langchain_qdrant import QdrantVectorStore
from qdrant_client.http import models
from uuid import uuid1

from langchain_mongodb import MongoDBAtlasVectorSearch
from langchain_core.embeddings import Embeddings
@@ -15,6 +19,7 @@
import copy



class ExtendedPgVector(PGVector):

def get_all_ids(self) -> list[str]:
@@ -75,6 +80,107 @@ async def delete(
) -> None:
await run_in_executor(None, self._delete_multiple, ids, collection_only)


class ExtendedQdrant(QdrantVectorStore):
    @property
    def embedding_function(self) -> Embeddings:
        return self.embeddings

    def delete_vectors_by_source_document(self, source_document_ids: list[str]) -> str:
        # Delete every point whose metadata.file_id matches one of the given IDs.
        points_selector = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.file_id",
                    match=models.MatchAny(any=source_document_ids),
                ),
            ],
        )
        response = self.client.delete(
            collection_name=self.collection_name, points_selector=points_selector
        )
        return response.status.name


    def get_all_ids(self) -> list[str]:
        # Scroll through the whole collection and collect the unique file_ids
        # stored in each point's metadata. Termination follows the scroll
        # cursor, so a short final page cannot cause an extra pass.
        unique_values = set()
        limit = 500  # page size for each scroll request
        next_offset = None
        while True:
            points, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                limit=limit,
                with_payload=True,
                offset=next_offset,
            )
            for point in points:
                unique_values.add(point.payload["metadata"]["file_id"])
            if next_offset is None:
                break
        return list(unique_values)

    def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
        scroll_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.file_id",
                    match=models.MatchAny(any=ids),
                )
            ]
        )
        limit = 500
        next_offset = None
        doc_list = []
        while True:
            points, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=scroll_filter,
                limit=limit,
                offset=next_offset,
            )
            doc_list.extend(
                Document(
                    page_content=point.payload["page_content"],
                    metadata=point.payload["metadata"],
                )
                for point in points
            )
            if not next_offset:
                break
        return doc_list


class AsyncQdrant(ExtendedQdrant):

    async def get_all_ids(self) -> list[str]:
        return await run_in_executor(None, super().get_all_ids)

    async def get_documents_by_ids(self, ids: list[str]) -> list[Document]:
        return await run_in_executor(None, super().get_documents_by_ids, ids)

    async def delete(self, ids: list[str]) -> None:
        # Ensure the correct argument is passed through.
        await run_in_executor(None, self.delete_vectors_by_source_document, ids)

    async def similarity_search_many(
        self, query: str, k: int, ids: list[str]
    ) -> list[tuple[Document, float]]:
        # Restrict the search to points whose metadata.file_id is in ids.
        search_filter = models.Filter(
            must=[
                models.FieldCondition(
                    key="metadata.file_id",
                    match=models.MatchAny(any=ids),
                )
            ]
        )
        return await self.asimilarity_search_with_score(query=query, k=k, filter=search_filter)

    async def aadd_documents(self, docs: list[Document], ids: list[str]):
        # Qdrant assigns its own point IDs; the file_id lives in each
        # document's metadata, so the ids argument is intentionally unused.
        return await super().aadd_documents(docs)

class AtlasMongoVector(MongoDBAtlasVectorSearch):
@property
@@ -150,3 +256,4 @@ def delete(self, ids: Optional[list[str]] = None) -> None:
# implement the deletion of documents by file_id in self._collection
if ids is not None:
self._collection.delete_many({"file_id": {"$in": ids}})
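
A hypothetical usage sketch of the new `AsyncQdrant` store (not part of the diff): it assumes a local Qdrant instance and any LangChain `Embeddings` implementation (`OpenAIEmbeddings` here is only an example), and it goes through the `get_vector_store` factory from `store_factory.py` below:

```python
import asyncio

from langchain_openai import OpenAIEmbeddings  # assumption: any Embeddings implementation works

from store_factory import get_vector_store

async def demo() -> None:
    store = get_vector_store(
        connection_string="localhost:6333",
        embeddings=OpenAIEmbeddings(),
        collection_name="vector_collection",
        mode="qdrant",
    )
    # file_ids act as logical document IDs stored in each point's metadata.
    file_ids = await store.get_all_ids()
    results = await store.similarity_search_many(query="example", k=4, ids=file_ids[:1])
    for doc, score in results:
        print(f"{score:.3f} {doc.page_content[:80]}")

asyncio.run(demo())
```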

45 changes: 40 additions & 5 deletions store_factory.py
@@ -1,24 +1,30 @@
from typing import Optional
from typing import Optional, TypedDict
from langchain_core.embeddings import Embeddings
from store import AsyncPgVector, ExtendedPgVector
from store import AtlasMongoVector
from store import AsyncQdrant
import qdrant_client
from pymongo import MongoClient


async_DB = (AsyncPgVector, AsyncQdrant)  # async vector store implementations; extend when adding a new async backend


def get_vector_store(
connection_string: str,
embeddings: Embeddings,
collection_name: str,
mode: str = "sync",
search_index: Optional[str] = None
mode: str = "sync-PGVector",
search_index: Optional[str] = None,
api_key: Optional[str] = None
):
if mode == "sync":
if mode == "sync-PGVector":
return ExtendedPgVector(
connection_string=connection_string,
embedding_function=embeddings,
collection_name=collection_name,
)
elif mode == "async":
elif mode == "async-PGVector":
return AsyncPgVector(
connection_string=connection_string,
embedding_function=embeddings,
Expand All @@ -30,6 +36,35 @@ def get_vector_store(
return AtlasMongoVector(
collection=mong_collection, embedding=embeddings, index_name=search_index
)
    elif mode == "qdrant":
        # Determine the embedding dimension by embedding a sample string.
        embeddings_dimension = len(embeddings.embed_query("Dimension"))
        client = qdrant_client.QdrantClient(
            url=connection_string,
            api_key=api_key,
        )
        if not client.collection_exists(collection_name):
            collection_config = qdrant_client.http.models.VectorParams(
                size=embeddings_dimension,
                distance=qdrant_client.http.models.Distance.COSINE,
            )
            client.create_collection(
                collection_name=collection_name,
                vectors_config=collection_config,
            )
            # Index metadata.file_id so filtered searches and deletes are efficient.
            client.create_payload_index(
                collection_name=collection_name,
                field_name="metadata.file_id",
                field_schema="keyword",
            )
        return AsyncQdrant(
            client=client,
            collection_name=collection_name,
            embedding=embeddings,
        )

    else:
        raise ValueError(
            "Invalid mode specified. Choose 'sync-PGVector', 'async-PGVector', 'atlas-mongo', or 'qdrant'."
        )