Restructure Documents to support bulk embedding #87

Merged · 14 commits · Sep 23, 2024

2 changes: 1 addition & 1 deletion docs/quick-start.md
@@ -25,7 +25,7 @@ To index your models:
```python
from django.db import models
from wagtail.models import Page
-from wagtail_vector_index.storage.models import VectorIndexedMixin, EmbeddingField
+from wagtail_vector_index.storage.django import VectorIndexedMixin, EmbeddingField


class MyPage(VectorIndexedMixin, Page):
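    # --- Illustrative continuation (not part of this diff) ---
    # A quick-start model typically also declares which fields to embed;
    # ``embedding_fields`` and the field names below are assumptions based on
    # the ``EmbeddingField`` import above, not content shown in this hunk.
    body = models.TextField()

    embedding_fields = [
        EmbeddingField("title"),
        EmbeddingField("body"),
    ]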
8 changes: 4 additions & 4 deletions docs/vector-indexes.md
@@ -19,7 +19,7 @@ If you don't need to customise the way your index behaves, you can automatically
```python
from django.db import models
from wagtail.models import Page
-from wagtail_vector_index.storage.models import VectorIndexedMixin, EmbeddingField
+from wagtail_vector_index.storage.django import VectorIndexedMixin, EmbeddingField


class MyPage(VectorIndexedMixin, Page):
@@ -44,7 +44,7 @@ The `VectorIndexedMixin` class is made up of two other mixins:
If you want to customise your vector index, you can build your own `VectorIndex` class and configure your model to use it with the `vector_index_class` property:

```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
EmbeddableFieldsVectorIndexMixin,
DefaultStorageVectorIndex,
)
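
# --- Illustrative sketch (not part of this diff) ---
# Building on the imports above, a custom index combines the mixin with the
# default storage index, and a model opts in via the ``vector_index_class``
# property described in the prose. ``MyVectorIndex`` and ``MyPage`` are
# hypothetical names; ``VectorIndexedMixin`` and ``Page`` are assumed to be
# imported as in the earlier example.


class MyVectorIndex(EmbeddableFieldsVectorIndexMixin, DefaultStorageVectorIndex):
    pass


class MyPage(VectorIndexedMixin, Page):
    vector_index_class = MyVectorIndex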
@@ -71,7 +71,7 @@ One of the things you might want to do with a custom index is query across multi
To do this, override `querysets` or `_get_querysets()` on your custom Vector Index class:

```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
EmbeddableFieldsVectorIndexMixin,
DefaultStorageVectorIndex,
)
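
# --- Illustrative sketch (not part of this diff) ---
# Querying across more than one model by overriding ``querysets`` on a custom
# index, as the prose above describes. ``InformationPage`` and ``BlogPage`` are
# hypothetical models; if ``querysets`` is not a plain class attribute in your
# version, override ``_get_querysets()`` instead.


class MultiModelVectorIndex(EmbeddableFieldsVectorIndexMixin, DefaultStorageVectorIndex):
    querysets = [
        InformationPage.objects.all(),
        BlogPage.objects.all(),
    ]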
@@ -121,7 +121,7 @@ You might want to customise this behavior. To do this you can create your own `D


```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
EmbeddableFieldsVectorIndexMixin,
EmbeddableFieldsDocumentConverter,
DefaultStorageVectorIndex,
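
# --- Illustrative sketch (not part of this diff; the import above is
# truncated by the diff view) ---
# One possible wiring for a custom converter, using the ``get_converter()``
# hook visible in the storage/base.py changes below. ``MyDocumentConverter``
# and ``MyVectorIndex`` are hypothetical names, and the real extension point
# may differ.


class MyDocumentConverter(EmbeddableFieldsDocumentConverter):
    pass


class MyVectorIndex(EmbeddableFieldsVectorIndexMixin, DefaultStorageVectorIndex):
    def get_converter(self) -> MyDocumentConverter:
        return MyDocumentConverter()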
2 changes: 1 addition & 1 deletion src/wagtail_vector_index/ai_utils/backends/echo.py
@@ -10,7 +10,7 @@

from django.core.exceptions import ImproperlyConfigured

-from wagtail_vector_index.storage.base import Document
+from wagtail_vector_index.storage.models import Document

from ..types import (
AIResponse,
2 changes: 1 addition & 1 deletion src/wagtail_vector_index/apps.py
@@ -8,6 +8,6 @@ class WagtailVectorIndexAppConfig(AppConfig):
    default_auto_field = "django.db.models.AutoField"

    def ready(self):
-        from wagtail_vector_index.storage.models import register_indexed_models
+        from wagtail_vector_index.storage.django import register_indexed_models

        register_indexed_models()
16 changes: 16 additions & 0 deletions src/wagtail_vector_index/migrations/0002_rename_embedding_model.py
@@ -0,0 +1,16 @@
# Generated by Django 5.0.1 on 2024-09-06 10:26

from django.db import migrations


class Migration(migrations.Migration):
    dependencies = [
        ("wagtail_vector_index", "0001_initial"),
    ]

    operations = [
        migrations.RenameModel(
            old_name="Embedding",
            new_name="Document",
        ),
    ]
38 changes: 38 additions & 0 deletions src/wagtail_vector_index/migrations/0003_adjust_document_fields.py
@@ -0,0 +1,38 @@
# Generated by Django 5.0.1 on 2024-09-06 10:26

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("pgvector", "0003_alter_pgvectorembedding_embedding"),
        (
            "wagtail_vector_index",
            "0002_rename_embedding_model",
        ),
    ]

    operations = [
        migrations.RemoveField(
            model_name="document",
            name="base_content_type",
        ),
        migrations.RemoveField(
            model_name="document",
            name="content_type",
        ),
        migrations.RemoveField(
            model_name="document",
            name="object_id",
        ),
        migrations.AddField(
            model_name="document",
            name="object_keys",
            field=models.JSONField(default=list),
        ),
        migrations.AddField(
            model_name="document",
            name="metadata",
            field=models.JSONField(default=dict),
        ),
    ]
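
Taken together, migrations 0002 and 0003 rename `Embedding` to `Document` and swap its per-object generic foreign key fields (`base_content_type`, `content_type`, `object_id`) for an `object_keys` list and a `metadata` dict. A partial sketch of the resulting model shape (fields that already existed on the renamed `Embedding` model, such as the vector itself, are not visible in this diff and are omitted):

```python
from django.db import models


class Document(models.Model):
    # Added by 0003_adjust_document_fields.
    object_keys = models.JSONField(default=list)
    metadata = models.JSONField(default=dict)
```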
118 changes: 85 additions & 33 deletions src/wagtail_vector_index/storage/base.py
@@ -1,7 +1,8 @@
import copy
from abc import ABC
from collections.abc import AsyncGenerator, Generator, Iterable, Mapping, Sequence
from dataclasses import dataclass
from typing import Any, ClassVar, Generic, Protocol, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, Type, TypeVar

from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
@@ -12,9 +13,15 @@
get_storage_provider,
)

if TYPE_CHECKING:
from wagtail_vector_index.storage.models import Document

StorageProviderClass = TypeVar("StorageProviderClass")
ConfigClass = TypeVar("ConfigClass")
IndexMixin = TypeVar("IndexMixin")
FromObjectType = TypeVar("FromObjectType", contravariant=True)
ChunkedObjectType = TypeVar("ChunkedObjectType", covariant=False)
ToObjectType = TypeVar("ToObjectType", covariant=True)


class DocumentRetrievalVectorIndexMixinProtocol(Protocol):
@@ -64,39 +71,86 @@ def __init_subclass__(cls, **kwargs: Any) -> None:
return super().__init_subclass__(**kwargs)


@dataclass(kw_only=True, frozen=True)
class Document:
"""Representation of some content that is passed to vector storage backends.
class FromDocumentOperator(Protocol[ToObjectType]):
"""Protocol for a class that can convert a Document to an object"""

A document is usually a part of a model, e.g. some content split out from
a VectorIndexedMixin model. One model instance may have multiple documents.
def from_document(self, document: "Document") -> ToObjectType: ...

def bulk_from_documents(
self, documents: Iterable["Document"]
) -> Generator[ToObjectType, None, None]: ...

async def abulk_from_documents(
self, documents: Iterable["Document"]
) -> AsyncGenerator[ToObjectType, None]: ...


class ObjectChunkerOperator(Protocol[ChunkedObjectType]):
"""Protocol for a class that can chunk an object into smaller chunks"""

def chunk_object(
self, object: ChunkedObjectType, chunk_size: int
) -> Iterable[ChunkedObjectType]: ...


class ToDocumentOperator(Protocol[FromObjectType]):
"""Protocol for a class that can convert an object to a Document"""

def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): ...

def to_documents(
self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend
) -> Generator["Document", None, None]: ...

def bulk_to_documents(
self,
objects: Iterable[FromObjectType],
*,
embedding_backend: BaseEmbeddingBackend,
) -> Generator["Document", None, None]: ...

The embedding_pk on a Document must be the PK of an Embedding model instance.
"""

vector: Sequence[float]
embedding_pk: int
metadata: Mapping
class DocumentConverter(ABC):
"""Base class for a DocumentConverter that can convert objects to Documents and vice versa"""

to_document_operator_class: Type[ToDocumentOperator]
from_document_operator_class: Type[FromDocumentOperator]
object_chunker_operator_class: Type[ObjectChunkerOperator]

@property
def to_document_operator(self) -> ToDocumentOperator:
return self.to_document_operator_class(self.object_chunker_operator_class)

@property
def from_document_operator(self) -> FromDocumentOperator:
return self.from_document_operator_class()

class DocumentConverter(Protocol):
def to_documents(
self, object: object, *, embedding_backend: BaseEmbeddingBackend
) -> Generator[Document, None, None]: ...
) -> Generator["Document", None, None]:
return self.to_document_operator.to_documents(
object, embedding_backend=embedding_backend
)

def from_document(self, document: Document) -> object: ...
def from_document(self, document: "Document") -> object:
return self.from_document_operator.from_document(document)

def bulk_to_documents(
self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend
) -> Generator[Document, None, None]: ...
) -> Generator["Document", None, None]:
return self.to_document_operator.bulk_to_documents(
objects, embedding_backend=embedding_backend
)

def bulk_from_documents(
self, documents: Iterable[Document]
) -> Generator[object, None, None]: ...
self, documents: Sequence["Document"]
) -> Generator[object, None, None]:
return self.from_document_operator.bulk_from_documents(documents)

async def abulk_from_documents(
self, documents: Iterable[Document]
) -> AsyncGenerator[object, None]: ...
def abulk_from_documents(
self, documents: Sequence["Document"]
) -> AsyncGenerator[object, None]:
return self.from_document_operator.abulk_from_documents(documents)


@dataclass
@@ -129,7 +183,7 @@ class VectorIndex(Generic[ConfigClass]):
def get_embedding_backend(self) -> BaseEmbeddingBackend:
return get_embedding_backend(self.embedding_backend_alias)

def get_documents(self) -> Iterable[Document]:
def get_documents(self) -> Iterable["Document"]:
raise NotImplementedError

def get_converter(self) -> DocumentConverter:
@@ -159,7 +213,7 @@ def query(

sources = list(self.get_converter().bulk_from_documents(similar_documents))

merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents)
merged_context = "\n".join(doc.content for doc in similar_documents)
prompt = (
getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None)
or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer."
@@ -196,14 +250,10 @@ async def aquery(
)
]

sources = [
source
async for source in self.get_converter().abulk_from_documents(
similar_documents
)
]
similar_objects = self.get_converter().abulk_from_documents(similar_documents)
sources = [source async for source in similar_objects]

merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents)
merged_context = "\n".join([doc.content for doc in similar_documents])
prompt = (
getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None)
or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer."
@@ -258,8 +308,10 @@ def search(
query_embedding = next(self.get_embedding_backend().embed([query]))
except StopIteration as e:
raise ValueError("No embeddings were generated for the given query.") from e
similar_documents = self.get_similar_documents(
query_embedding, limit=limit, similarity_threshold=similarity_threshold
similar_documents = list(
self.get_similar_documents(
query_embedding, limit=limit, similarity_threshold=similarity_threshold
)
)
return list(self.get_converter().bulk_from_documents(similar_documents))

@@ -293,10 +345,10 @@ def get_similar_documents(
*,
limit: int = 5,
similarity_threshold: float = 0.0,
) -> Generator[Document, None, None]:
) -> Generator["Document", None, None]:
raise NotImplementedError

def aget_similar_documents(
self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0
) -> AsyncGenerator[Document, None]:
) -> AsyncGenerator["Document", None]:
raise NotImplementedError
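
The storage/base.py changes replace the old `Document` dataclass with a set of protocols — `FromDocumentOperator`, `ObjectChunkerOperator`, and `ToDocumentOperator` — that the new `DocumentConverter` base class delegates to, while `Document` itself becomes a Django model in `wagtail_vector_index.storage.models` (per the migrations above). A toy, self-contained sketch of the operator split (plain strings and dicts stand in for real models and `Document` instances; only the method names and call pattern mirror the diff):

```python
from collections.abc import Generator, Iterable
from typing import Any


class FixedWidthChunker:
    """Toy ObjectChunkerOperator: split a string into fixed-width chunks."""

    def chunk_object(self, object: str, chunk_size: int) -> Iterable[str]:
        return [object[i : i + chunk_size] for i in range(0, len(object), chunk_size)]


class StringToDocumentOperator:
    """Toy ToDocumentOperator: chunk every object, then embed chunks in one batch."""

    def __init__(self, object_chunker_operator_class=FixedWidthChunker):
        self.chunker = object_chunker_operator_class()

    def bulk_to_documents(
        self, objects: Iterable[str], *, embedding_backend: Any
    ) -> Generator[dict, None, None]:
        # Gathering every chunk before calling the backend once is the
        # "bulk embedding" idea in the PR title.
        chunks = [
            chunk
            for obj in objects
            for chunk in self.chunker.chunk_object(obj, chunk_size=100)
        ]
        for chunk, vector in zip(chunks, embedding_backend.embed(chunks)):
            yield {"content": chunk, "vector": vector, "metadata": {}}


class StringFromDocumentOperator:
    """Toy FromDocumentOperator: map documents back to their source content."""

    def bulk_from_documents(
        self, documents: Iterable[dict]
    ) -> Generator[str, None, None]:
        for document in documents:
            yield document["content"]
```

A real converter would yield `Document` model instances rather than dicts, and would also implement the single-object and async variants (`to_documents`, `from_document`, `abulk_from_documents`) required by the protocols.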