Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add pgvector retriever support #6

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Override PGVector for correct async metadata deserialization, fixes #124
shamspias committed Oct 8, 2024
commit 6488c1a8205b71e1d8f9c724dc6335b3b364f540
73 changes: 57 additions & 16 deletions src/retrieval_graph/retrieval.py
Original file line number Diff line number Diff line change
@@ -8,15 +8,14 @@

import os
from contextlib import contextmanager
from typing import Generator, Any, Dict
from typing import Generator

from langchain_core.embeddings import Embeddings
from langchain_core.runnables import RunnableConfig
from langchain_core.vectorstores import VectorStoreRetriever

from retrieval_graph.configuration import Configuration, IndexConfiguration


## Encoder constructors


@@ -110,7 +109,60 @@ def make_pgvector_retriever(
configuration: IndexConfiguration, embedding_model: Embeddings
) -> Generator[VectorStoreRetriever, None, None]:
"""Configure this agent to connect to a pgvector index."""
from langchain_postgres.vectorstores import PGVector
import json
from typing import Any, List, Tuple
from langchain_postgres.vectorstores import PGVector as OverPGVector
from langchain_core.documents import Document

class PGVector(OverPGVector):
    """PGVector override that always hands back dictionary metadata.

    Works around the upstream issue "Metadata field not properly deserialized
    when using async_mode=True with PGVector #124": in async mode the
    ``cmetadata`` column may come back as something other than a ``dict``.
    This subclass normalises that value before building each ``Document``.

    Accepted ``cmetadata`` shapes:
        * ``dict`` - used as is.
        * an object exposing a ``buf`` attribute holding UTF-8 encoded JSON
          (the original code attributes this to asyncpg "Fragment" values -
          NOTE(review): confirm against the driver in use).
        * raw ``bytes`` / ``bytearray`` / ``memoryview`` holding UTF-8 JSON.
        * ``str`` holding JSON.
        * anything else (including ``None``) - replaced by an empty dict.

    Methods:
        _results_to_docs_and_scores: Convert query result rows into
            ``(Document, score)`` tuples with deserialized metadata.
    """

    @staticmethod
    def _deserialize_cmetadata(raw: Any) -> dict:
        """Coerce a raw ``cmetadata`` value into a dictionary.

        Args:
            raw: The value read from ``EmbeddingStore.cmetadata``.

        Returns:
            The metadata as a ``dict``; an empty dict when the value has an
            unrecognised type or its JSON does not decode to an object.

        Raises:
            json.JSONDecodeError: If a string/bytes payload is not valid JSON.
            UnicodeDecodeError: If a bytes payload is not valid UTF-8.
        """
        if isinstance(raw, dict):
            return raw

        # Unwrap objects that carry their payload in a ``buf`` attribute;
        # everything else is inspected directly.
        payload = getattr(raw, "buf", raw)

        if isinstance(payload, memoryview):
            payload = payload.tobytes()
        if isinstance(payload, (bytes, bytearray)):
            payload = bytes(payload).decode("utf-8")

        if isinstance(payload, str):
            parsed = json.loads(payload)
            # JSON such as ``null`` or a list is not usable as Document
            # metadata, so fall back to the same default as unknown types.
            return parsed if isinstance(parsed, dict) else {}

        return {}

    def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
        """Return docs and scores from results.

        Args:
            results: Iterable of rows, each exposing ``EmbeddingStore``
                (with ``id``, ``document`` and ``cmetadata``) and ``distance``.

        Returns:
            A list of ``(Document, score)`` tuples in the order of ``results``.
            The score is ``result.distance``, or ``None`` when
            ``self.embeddings`` is ``None``.
        """
        docs = []
        for result in results:
            store = result.EmbeddingStore
            doc = Document(
                id=str(store.id),
                page_content=store.document,
                metadata=self._deserialize_cmetadata(store.cmetadata),
            )
            score = result.distance if self.embeddings is not None else None
            docs.append((doc, score))
        return docs

connection_string = os.environ.get("PGVECTOR_CONNECTION_STRING")
if not connection_string:
@@ -129,21 +181,10 @@ def make_pgvector_retriever(
)

search_kwargs = configuration.search_kwargs

# Ensure search_kwargs is a dictionary
if not isinstance(search_kwargs, dict):
search_kwargs = {}

search_kwargs: Dict[str, Any] = search_kwargs # Explicit type annotation
# Add a filter to ensure we only retrieve documents for the given user_id
user_id = configuration.user_id
metadata_filter = search_kwargs.setdefault("filter", {})
metadata_filter["user_id"] = {"$eq": user_id}

# Create a retriever from the vector store
retriever = vstore.as_retriever(search_kwargs=search_kwargs)

yield retriever
yield vstore.as_retriever(search_kwargs=search_kwargs)


@contextmanager
@@ -178,4 +219,4 @@ def make_retriever(
"Unrecognized retriever_provider in configuration. "
f"Expected one of: {', '.join(Configuration.__annotations__['retriever_provider'].__args__)}\n"
f"Got: {configuration.retriever_provider}"
)
)