diff --git a/README.md b/README.md index 37bdcdf..380e1dd 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ The following environment variables are required to run the application: - `AWS_DEFAULT_REGION`: (Optional) defaults to `us-east-1` - `AWS_ACCESS_KEY_ID`: (Optional) needed for bedrock embeddings - `AWS_SECRET_ACCESS_KEY`: (Optional) needed for bedrock embeddings +- `QDRANT_API_KEY`: (Optional) API key, not needed for a local Docker container Make sure to set these environment variables before running the application. You can set them in a `.env` file or as system environment variables. @@ -118,6 +119,19 @@ The `ATLAS_MONGO_DB_URI` could be the same or different from what is used by Lib Follow one of the [four documented methods](https://www.mongodb.com/docs/atlas/atlas-vector-search/create-index/#procedure) to create the vector index. +### Use Qdrant as Vector Database + +Additionally, we can use [Qdrant](https://qdrant.tech/documentation/) as the vector database. To do so, set the following environment variables: + +```env +VECTOR_DB_TYPE=qdrant +DB_HOST= +DB_PORT= +COLLECTION_NAME= +QDRANT_API_KEY= +``` + +To download the Qdrant Docker image, run `docker pull qdrant/qdrant`, then start the container on port 6333 with `docker run -p 6333:6333 qdrant/qdrant` and set `DB_HOST=localhost` and `DB_PORT=6333`. A collection named `COLLECTION_NAME` will be created automatically if it does not already exist. `QDRANT_API_KEY` is not necessary in this case. 
### Cloud Installation Settings: diff --git a/config.py b/config.py index 7b94a97..f4b39eb 100644 --- a/config.py +++ b/config.py @@ -15,6 +15,7 @@ class VectorDBType(Enum): PGVECTOR = "pgvector" ATLAS_MONGO = "atlas-mongo" + QDRANT="qdrant" class EmbeddingsProvider(Enum): @@ -60,13 +61,16 @@ def get_env_variable( MONGO_VECTOR_COLLECTION = get_env_variable( "MONGO_VECTOR_COLLECTION", None ) # Deprecated, backwards compatability +COLLECTION_NAME = get_env_variable( + "COLLECTION_NAME", "vector_collection" +) +QDRANT_API_KEY = get_env_variable("QDRANT_API_KEY") CHUNK_SIZE = int(get_env_variable("CHUNK_SIZE", "1500")) CHUNK_OVERLAP = int(get_env_variable("CHUNK_OVERLAP", "100")) env_value = get_env_variable("PDF_EXTRACT_IMAGES", "False").lower() PDF_EXTRACT_IMAGES = True if env_value == "true" else False -CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" DSN = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" ## Logging @@ -255,11 +259,12 @@ def init_embeddings(provider, model): # Vector store if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: + CONNECTION_STRING = f"postgresql+psycopg2://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{DB_HOST}:{DB_PORT}/{POSTGRES_DB}" vector_store = get_vector_store( connection_string=CONNECTION_STRING, embeddings=embeddings, collection_name=COLLECTION_NAME, - mode="async", + mode="async-PGVector", ) elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: # Backward compatability check @@ -276,6 +281,16 @@ def init_embeddings(provider, model): mode="atlas-mongo", search_index=ATLAS_SEARCH_INDEX, ) +elif VECTOR_DB_TYPE == VectorDBType.QDRANT: + CONNECTION_STRING = f"{DB_HOST}:{DB_PORT}" + vector_store = get_vector_store( + connection_string=CONNECTION_STRING, + embeddings=embeddings, + collection_name=COLLECTION_NAME, + mode="qdrant", + api_key=QDRANT_API_KEY + +) else: raise ValueError(f"Unsupported vector store type: {VECTOR_DB_TYPE}") diff --git a/main.py 
b/main.py index c82b550..1301989 100644 --- a/main.py +++ b/main.py @@ -49,7 +49,8 @@ from middleware import security_middleware from mongo import mongo_health_check from constants import ERROR_MESSAGES -from store import AsyncPgVector +from store import AsyncPgVector, AsyncQdrant +from store_factory import async_DB load_dotenv(find_dotenv()) @@ -105,7 +106,7 @@ async def lifespan(app: FastAPI): @app.get("/ids") async def get_all_ids(): try: - if isinstance(vector_store, AsyncPgVector): + if isinstance(vector_store, async_DB): ids = await vector_store.get_all_ids() else: ids = vector_store.get_all_ids() @@ -118,7 +119,7 @@ async def get_all_ids(): def isHealthOK(): if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: return pg_health_check() - if VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: + elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: return mongo_health_check() else: return True @@ -135,7 +136,7 @@ async def health_check(): @app.get("/documents", response_model=list[DocumentResponse]) async def get_documents_by_ids(ids: list[str] = Query(...)): try: - if isinstance(vector_store, AsyncPgVector): + if isinstance(vector_store, async_DB): existing_ids = await vector_store.get_all_ids() documents = await vector_store.get_documents_by_ids(ids) else: @@ -162,7 +163,7 @@ async def get_documents_by_ids(ids: list[str] = Query(...)): @app.delete("/documents") async def delete_documents(document_ids: List[str] = Body(...)): try: - if isinstance(vector_store, AsyncPgVector): + if isinstance(vector_store, async_DB): existing_ids = await vector_store.get_all_ids() await vector_store.delete(ids=document_ids) else: @@ -198,6 +199,12 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): k=body.k, filter={"file_id": body.file_id}, ) + elif isinstance(vector_store, AsyncQdrant): + documents = await vector_store.similarity_search_many( + query=body.query, + k=body.k, + ids=[body.file_id], + ) else: documents = 
vector_store.similarity_search_with_score_by_vector( embedding, k=body.k, filter={"file_id": body.file_id} @@ -260,7 +267,7 @@ async def store_data_in_vector_db( ] try: - if isinstance(vector_store, AsyncPgVector): + if isinstance(vector_store, async_DB): ids = await vector_store.aadd_documents( docs, ids=[file_id] * len(documents) ) @@ -443,7 +450,7 @@ async def embed_file( async def load_document_context(id: str): ids = [id] try: - if isinstance(vector_store, AsyncPgVector): + if isinstance(vector_store, async_DB): existing_ids = await vector_store.get_all_ids() documents = await vector_store.get_documents_by_ids(ids) else: @@ -536,6 +543,11 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody): k=body.k, filter={"file_id": {"$in": body.file_ids}}, ) + elif isinstance(vector_store, AsyncQdrant): + documents = await vector_store.similarity_search_many( + query=body.query, + k=body.k, + ids=body.file_ids) else: documents = vector_store.similarity_search_with_score_by_vector( embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}} diff --git a/requirements.lite.txt b/requirements.lite.txt index 3f554bb..abcfdd3 100644 --- a/requirements.lite.txt +++ b/requirements.lite.txt @@ -30,3 +30,4 @@ python-pptx==0.6.23 xlrd==2.0.1 langchain-aws==0.2.1 boto3==1.34.144 +qdrant-client==1.11.2 diff --git a/requirements.txt b/requirements.txt index 28f18a0..4b71f74 100644 --- a/requirements.txt +++ b/requirements.txt @@ -33,4 +33,5 @@ langchain-huggingface==0.1.0 cryptography==42.0.7 python-magic==0.4.27 python-pptx==0.6.23 -xlrd==2.0.1 \ No newline at end of file +xlrd==2.0.1 +qdrant-client==1.11.2 \ No newline at end of file diff --git a/store.py b/store.py index 92694ef..3a0050e 100644 --- a/store.py +++ b/store.py @@ -1,9 +1,13 @@ from typing import Any, Optional -from sqlalchemy import delete +from sqlalchemy import delete, func from langchain_community.vectorstores.pgvector import PGVector from langchain_core.documents import Document from 
langchain_core.runnables.config import run_in_executor from sqlalchemy.orm import Session +import qdrant_client as client +from langchain_qdrant import QdrantVectorStore +from qdrant_client.http import models +from uuid import uuid1 from langchain_mongodb import MongoDBAtlasVectorSearch from langchain_core.embeddings import Embeddings @@ -15,6 +19,7 @@ import copy + class ExtendedPgVector(PGVector): def get_all_ids(self) -> list[str]: @@ -75,6 +80,107 @@ async def delete( ) -> None: await run_in_executor(None, self._delete_multiple, ids, collection_only) + +class ExtendedQdrant(QdrantVectorStore): + @property + def embedding_function(self) -> Embeddings: + return self.embeddings + + def delete_vectors_by_source_document(self, source_document_ids: list[str]) -> None: + points_selector = models.Filter( + must=[ + models.FieldCondition( + key="metadata.file_id", + match=models.MatchAny(any=source_document_ids), + ), + ], + ) + response = self.client.delete(collection_name=self.collection_name, points_selector=points_selector) + status = response.status.name + return status + + + def get_all_ids(self) -> list[str]: + collection_info = self.client.get_collection(self.collection_name) + total_points = collection_info.points_count + + # Scroll through all points in the collection + unique_values = set() + pointsRec = 0 + limit = 500 #How much to load each time + next_offset = None + while pointsRec < total_points: + points, next_offset = self.client.scroll( + collection_name=self.collection_name, + limit=limit, + with_payload=True, + offset=next_offset, + ) + for point in points: + unique_values.add(point.payload['metadata']['file_id']) + pointsRec += limit + return list(unique_values) + + def get_documents_by_ids(self, ids: list[str]): + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.file_id", + match=models.MatchAny(any=ids) + ) + ] + ) + limit = 500 + next_offset = None + docList = [] + while True: + results = self.client.scroll( + 
collection_name=self.collection_name, + scroll_filter=filter, + limit=limit, + offset=next_offset + ) + points, next_offset = results + if points: + docList.extend([Document(page_content=point.payload['page_content'], + metadata=point.payload['metadata'] + ) + for point in points + ]) + if not next_offset: + break + return docList + + +class AsyncQdrant(ExtendedQdrant): + + async def get_all_ids(self) -> list[str]: + return await run_in_executor(None, super().get_all_ids) + + async def get_documents_by_ids(self, ids: list[str]) -> list[Document]: + return await run_in_executor(None, super().get_documents_by_ids, ids) + + async def delete( + self, + ids: list[str] + ) -> None: + # Ensure the correct argument is passed through + await run_in_executor(None, self.delete_vectors_by_source_document, ids) + + async def similarity_search_many(self, query:str, k:int, ids:list[str])-> list[tuple[Document, float]]: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.file_id", + match=models.MatchAny(any=ids) + ) + ] + ) + results = await self.asimilarity_search_with_score(query=query, k=k, filter=filter) + return results + + async def aadd_documents(self, docs: list[Document], ids: list[str]): + return await super().aadd_documents(docs) class AtlasMongoVector(MongoDBAtlasVectorSearch): @property @@ -150,3 +256,4 @@ def delete(self, ids: Optional[list[str]] = None) -> None: # implement the deletion of documents by file_id in self._collection if ids is not None: self._collection.delete_many({"file_id": {"$in": ids}}) + diff --git a/store_factory.py b/store_factory.py index 6549e29..97e4a80 100644 --- a/store_factory.py +++ b/store_factory.py @@ -1,24 +1,30 @@ -from typing import Optional +from typing import Optional, TypedDict from langchain_core.embeddings import Embeddings from store import AsyncPgVector, ExtendedPgVector from store import AtlasMongoVector +from store import AsyncQdrant +import qdrant_client from pymongo import MongoClient 
+async_DB = (AsyncPgVector, AsyncQdrant)  # Async vector-store implementations; extend when adding a new async backend + + def get_vector_store( connection_string: str, embeddings: Embeddings, collection_name: str, - mode: str = "sync", - search_index: Optional[str] = None + mode: str = "sync-PGVector", + search_index: Optional[str] = None, + api_key: Optional[str] = None ): - if mode == "sync": + if mode == "sync-PGVector": return ExtendedPgVector( connection_string=connection_string, embedding_function=embeddings, collection_name=collection_name, ) - elif mode == "async": + elif mode == "async-PGVector": return AsyncPgVector( connection_string=connection_string, embedding_function=embeddings, @@ -30,6 +36,35 @@ def get_vector_store( return AtlasMongoVector( collection=mong_collection, embedding=embeddings, index_name=search_index ) + elif mode == "qdrant": + embeddings_dimension = len(embeddings.embed_query("Dimension")) + client = qdrant_client.QdrantClient( + url=connection_string, + api_key=api_key + ) + if not client.collection_exists(collection_name): + collection_config = qdrant_client.http.models.VectorParams( + size=embeddings_dimension, + distance=qdrant_client.http.models.Distance.COSINE + ) + client.create_collection( + collection_name=collection_name, + vectors_config=collection_config + ) + client.create_payload_index( + collection_name=collection_name, + field_name="metadata.file_id", + field_schema="keyword", + ) + return AsyncQdrant( + client=client, + collection_name=collection_name, + embedding=embeddings + ) else: raise ValueError("Invalid mode specified. Choose 'sync-PGVector', 'async-PGVector', 'atlas-mongo' or 'qdrant'.")