diff --git a/docs/quick-start.md b/docs/quick-start.md
index ea3578a..013c0df 100644
--- a/docs/quick-start.md
+++ b/docs/quick-start.md
@@ -25,7 +25,7 @@ To index your models:
 ```python
 from django.db import models
 from wagtail.models import Page
-from wagtail_vector_index.storage.models import VectorIndexedMixin, EmbeddingField
+from wagtail_vector_index.storage.django import VectorIndexedMixin, EmbeddingField


 class MyPage(VectorIndexedMixin, Page):
diff --git a/docs/vector-indexes.md b/docs/vector-indexes.md
index 5a13d93..180c2fa 100644
--- a/docs/vector-indexes.md
+++ b/docs/vector-indexes.md
@@ -19,7 +19,7 @@ If you don't need to customise the way your index behaves, you can automatically
 ```python
 from django.db import models
 from wagtail.models import Page
-from wagtail_vector_index.storage.models import VectorIndexedMixin, EmbeddingField
+from wagtail_vector_index.storage.django import VectorIndexedMixin, EmbeddingField


 class MyPage(VectorIndexedMixin, Page):
@@ -44,7 +44,7 @@ The `VectorIndexedMixin` class is made up of two other mixins:
 If you want to customise your vector index, you can build your own `VectorIndex` class and configure your model to use it with the `vector_index_class` property:

 ```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
     EmbeddableFieldsVectorIndexMixin,
     DefaultStorageVectorIndex,
 )
@@ -71,7 +71,7 @@ One of the things you might want to do with a custom index is query across multi
 To do this, override `querysets` or `_get_querysets()` on your custom Vector Index class:

 ```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
     EmbeddableFieldsVectorIndexMixin,
     DefaultStorageVectorIndex,
 )
@@ -121,7 +121,7 @@ You might want to customise this behavior.
 To do this you can create your own `DocumentConverter`:

 ```python
-from wagtail_vector_index.storage.models import (
+from wagtail_vector_index.storage.django import (
     EmbeddableFieldsVectorIndexMixin,
     EmbeddableFieldsDocumentConverter,
     DefaultStorageVectorIndex,
diff --git a/src/wagtail_vector_index/ai_utils/backends/echo.py b/src/wagtail_vector_index/ai_utils/backends/echo.py
index f7c284c..4efc3b9 100644
--- a/src/wagtail_vector_index/ai_utils/backends/echo.py
+++ b/src/wagtail_vector_index/ai_utils/backends/echo.py
@@ -10,7 +10,7 @@

 from django.core.exceptions import ImproperlyConfigured

-from wagtail_vector_index.storage.base import Document
+from wagtail_vector_index.storage.models import Document

 from ..types import (
     AIResponse,
diff --git a/src/wagtail_vector_index/apps.py b/src/wagtail_vector_index/apps.py
index 3c65c9c..232d880 100644
--- a/src/wagtail_vector_index/apps.py
+++ b/src/wagtail_vector_index/apps.py
@@ -8,6 +8,6 @@ class WagtailVectorIndexAppConfig(AppConfig):
     default_auto_field = "django.db.models.AutoField"

     def ready(self):
-        from wagtail_vector_index.storage.models import register_indexed_models
+        from wagtail_vector_index.storage.django import register_indexed_models

         register_indexed_models()
diff --git a/src/wagtail_vector_index/migrations/0002_rename_embedding_model.py b/src/wagtail_vector_index/migrations/0002_rename_embedding_model.py
new file mode 100644
index 0000000..b9d0772
--- /dev/null
+++ b/src/wagtail_vector_index/migrations/0002_rename_embedding_model.py
@@ -0,0 +1,16 @@
+# Generated by Django 5.0.1 on 2024-09-06 10:26
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("wagtail_vector_index", "0001_initial"),
+    ]
+
+    operations = [
+        migrations.RenameModel(
+            old_name="Embedding",
+            new_name="Document",
+        ),
+    ]
diff --git a/src/wagtail_vector_index/migrations/0003_adjust_document_fields.py b/src/wagtail_vector_index/migrations/0003_adjust_document_fields.py
new file mode 100644
index 0000000..9f2092a
--- /dev/null
+++ b/src/wagtail_vector_index/migrations/0003_adjust_document_fields.py
@@ -0,0 +1,38 @@
+# Generated by Django 5.0.1 on 2024-09-06 10:26
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("pgvector", "0003_alter_pgvectorembedding_embedding"),
+        (
+            "wagtail_vector_index",
+            "0002_rename_embedding_model",
+        ),
+    ]
+
+    operations = [
+        migrations.RemoveField(
+            model_name="document",
+            name="base_content_type",
+        ),
+        migrations.RemoveField(
+            model_name="document",
+            name="content_type",
+        ),
+        migrations.RemoveField(
+            model_name="document",
+            name="object_id",
+        ),
+        migrations.AddField(
+            model_name="document",
+            name="object_keys",
+            field=models.JSONField(default=list),
+        ),
+        migrations.AddField(
+            model_name="document",
+            name="metadata",
+            field=models.JSONField(default=dict),
+        ),
+    ]
diff --git a/src/wagtail_vector_index/storage/base.py b/src/wagtail_vector_index/storage/base.py
index 35c084c..0cb81c4 100644
--- a/src/wagtail_vector_index/storage/base.py
+++ b/src/wagtail_vector_index/storage/base.py
@@ -1,7 +1,8 @@
 import copy
+from abc import ABC
 from collections.abc import AsyncGenerator, Generator, Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, ClassVar, Generic, Protocol, TypeVar
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Protocol, Type, TypeVar

 from django.conf import settings
 from django.core.exceptions import ImproperlyConfigured
@@ -12,9 +13,15 @@
     get_storage_provider,
 )

+if
TYPE_CHECKING: + from wagtail_vector_index.storage.models import Document + StorageProviderClass = TypeVar("StorageProviderClass") ConfigClass = TypeVar("ConfigClass") IndexMixin = TypeVar("IndexMixin") +FromObjectType = TypeVar("FromObjectType", contravariant=True) +ChunkedObjectType = TypeVar("ChunkedObjectType", covariant=False) +ToObjectType = TypeVar("ToObjectType", covariant=True) class DocumentRetrievalVectorIndexMixinProtocol(Protocol): @@ -64,39 +71,86 @@ def __init_subclass__(cls, **kwargs: Any) -> None: return super().__init_subclass__(**kwargs) -@dataclass(kw_only=True, frozen=True) -class Document: - """Representation of some content that is passed to vector storage backends. +class FromDocumentOperator(Protocol[ToObjectType]): + """Protocol for a class that can convert a Document to an object""" - A document is usually a part of a model, e.g. some content split out from - a VectorIndexedMixin model. One model instance may have multiple documents. + def from_document(self, document: "Document") -> ToObjectType: ... + + def bulk_from_documents( + self, documents: Iterable["Document"] + ) -> Generator[ToObjectType, None, None]: ... + + async def abulk_from_documents( + self, documents: Iterable["Document"] + ) -> AsyncGenerator[ToObjectType, None]: ... + + +class ObjectChunkerOperator(Protocol[ChunkedObjectType]): + """Protocol for a class that can chunk an object into smaller chunks""" + + def chunk_object( + self, object: ChunkedObjectType, chunk_size: int + ) -> Iterable[ChunkedObjectType]: ... + + +class ToDocumentOperator(Protocol[FromObjectType]): + """Protocol for a class that can convert an object to a Document""" + + def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]): ... + + def to_documents( + self, object: FromObjectType, *, embedding_backend: BaseEmbeddingBackend + ) -> Generator["Document", None, None]: ... + + def bulk_to_documents( + self, + objects: Iterable[FromObjectType], + *, + embedding_backend: BaseEmbeddingBackend, + ) -> Generator["Document", None, None]: ... - The embedding_pk on a Document must be the PK of an Embedding model instance. - """ - vector: Sequence[float] - embedding_pk: int - metadata: Mapping +class DocumentConverter(ABC): + """Base class for a DocumentConverter that can convert objects to Documents and vice versa""" + to_document_operator_class: Type[ToDocumentOperator] + from_document_operator_class: Type[FromDocumentOperator] + object_chunker_operator_class: Type[ObjectChunkerOperator] + + @property + def to_document_operator(self) -> ToDocumentOperator: + return self.to_document_operator_class(self.object_chunker_operator_class) + + @property + def from_document_operator(self) -> FromDocumentOperator: + return self.from_document_operator_class() -class DocumentConverter(Protocol): def to_documents( self, object: object, *, embedding_backend: BaseEmbeddingBackend - ) -> Generator[Document, None, None]: ... + ) -> Generator["Document", None, None]: + return self.to_document_operator.to_documents( + object, embedding_backend=embedding_backend + ) - def from_document(self, document: Document) -> object: ... + def from_document(self, document: "Document") -> object: + return self.from_document_operator.from_document(document) def bulk_to_documents( self, objects: Iterable[object], *, embedding_backend: BaseEmbeddingBackend - ) -> Generator[Document, None, None]: ... 
+ ) -> Generator["Document", None, None]: + return self.to_document_operator.bulk_to_documents( + objects, embedding_backend=embedding_backend + ) def bulk_from_documents( - self, documents: Iterable[Document] - ) -> Generator[object, None, None]: ... + self, documents: Sequence["Document"] + ) -> Generator[object, None, None]: + return self.from_document_operator.bulk_from_documents(documents) - async def abulk_from_documents( - self, documents: Iterable[Document] - ) -> AsyncGenerator[object, None]: ... + def abulk_from_documents( + self, documents: Sequence["Document"] + ) -> AsyncGenerator[object, None]: + return self.from_document_operator.abulk_from_documents(documents) @dataclass @@ -129,7 +183,7 @@ class VectorIndex(Generic[ConfigClass]): def get_embedding_backend(self) -> BaseEmbeddingBackend: return get_embedding_backend(self.embedding_backend_alias) - def get_documents(self) -> Iterable[Document]: + def get_documents(self) -> Iterable["Document"]: raise NotImplementedError def get_converter(self) -> DocumentConverter: @@ -159,7 +213,7 @@ def query( sources = list(self.get_converter().bulk_from_documents(similar_documents)) - merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents) + merged_context = "\n".join(doc.content for doc in similar_documents) prompt = ( getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None) or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer." @@ -196,14 +250,10 @@ async def aquery( ) ] - sources = [ - source - async for source in self.get_converter().abulk_from_documents( - similar_documents - ) - ] + similar_objects = self.get_converter().abulk_from_documents(similar_documents) + sources = [source async for source in similar_objects] - merged_context = "\n".join(doc.metadata["content"] for doc in similar_documents) + merged_context = "\n".join([doc.content for doc in similar_documents]) prompt = ( getattr(settings, "WAGTAIL_VECTOR_INDEX_QUERY_PROMPT", None) or "You are a helpful assistant. Use the following context to answer the question. Don't mention the context in your answer." 
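Here `DocumentConverter` changes from a flat `Protocol` into an ABC that composes three pluggable operator classes and delegates each call to them. A minimal sketch of what wiring a custom converter looks like under the new scheme — `myapp.operators` and the three operator classes are hypothetical stand-ins for illustration, not part of this changeset:

```python
from wagtail_vector_index.storage.base import DocumentConverter

from myapp.operators import (  # hypothetical module, for illustration only
    MyFromDocumentOperator,
    MyObjectChunkerOperator,
    MyToDocumentOperator,
)


class MyDocumentConverter(DocumentConverter):
    # Each hook is a class attribute; the DocumentConverter base builds the
    # operator instances (passing the chunker class to the to-document
    # operator) and delegates to_documents()/from_document() calls to them.
    to_document_operator_class = MyToDocumentOperator
    from_document_operator_class = MyFromDocumentOperator
    object_chunker_operator_class = MyObjectChunkerOperator
```

The effect of the split is that a converter only declares composition, while chunking and (de)serialisation logic live in operators that can be reused across converters.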
@@ -258,8 +308,10 @@ def search( query_embedding = next(self.get_embedding_backend().embed([query])) except StopIteration as e: raise ValueError("No embeddings were generated for the given query.") from e - similar_documents = self.get_similar_documents( - query_embedding, limit=limit, similarity_threshold=similarity_threshold + similar_documents = list( + self.get_similar_documents( + query_embedding, limit=limit, similarity_threshold=similarity_threshold + ) ) return list(self.get_converter().bulk_from_documents(similar_documents)) @@ -293,10 +345,10 @@ def get_similar_documents( *, limit: int = 5, similarity_threshold: float = 0.0, - ) -> Generator[Document, None, None]: + ) -> Generator["Document", None, None]: raise NotImplementedError def aget_similar_documents( self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0 - ) -> AsyncGenerator[Document, None]: + ) -> AsyncGenerator["Document", None]: raise NotImplementedError diff --git a/src/wagtail_vector_index/storage/django.py b/src/wagtail_vector_index/storage/django.py new file mode 100644 index 0000000..e26fac6 --- /dev/null +++ b/src/wagtail_vector_index/storage/django.py @@ -0,0 +1,610 @@ +import logging +from collections import defaultdict +from collections.abc import ( + AsyncGenerator, + Generator, + Iterable, + MutableSequence, + Sequence, +) +from itertools import chain, islice +from typing import ( + TYPE_CHECKING, + ClassVar, + Optional, + Type, + TypeAlias, + cast, +) + +from django.apps import apps +from django.core import checks +from django.core.exceptions import FieldDoesNotExist +from django.db import models, transaction +from django.utils.functional import classproperty # type: ignore +from wagtail.models import Page +from wagtail.query import PageQuerySet +from wagtail.search.index import BaseField + +from wagtail_vector_index.ai_utils.backends.base import BaseEmbeddingBackend +from wagtail_vector_index.ai_utils.text_splitting.langchain import ( + LangchainRecursiveCharacterTextSplitter, +) +from wagtail_vector_index.ai_utils.text_splitting.naive import ( + NaiveTextSplitterCalculator, +) +from wagtail_vector_index.ai_utils.types import TextSplitterProtocol +from wagtail_vector_index.storage import get_storage_provider, registry +from wagtail_vector_index.storage.base import ( + DocumentConverter, + DocumentRetrievalVectorIndexMixinProtocol, + FromDocumentOperator, + ObjectChunkerOperator, + ToDocumentOperator, + VectorIndex, +) +from wagtail_vector_index.storage.exceptions import IndexedTypeFromDocumentError +from wagtail_vector_index.storage.models import Document + +logger = logging.getLogger(__name__) + +""" Everything related to indexing Django models is in this file. 
+
+This includes:
+
+- The ModelKey class, which builds the "<model_label>:<pk>" identifiers stored in a Document's object_keys
+- The EmbeddableFieldsMixin, which is a mixin for Django models that lets users define which fields should be used to generate embeddings
+- The EmbeddableFieldsVectorIndexMixin, which is a VectorIndex mixin that expects EmbeddableFieldsMixin models
+- The EmbeddableFieldsDocumentConverter, which is a DocumentConverter that knows how to convert a model instance using the EmbeddableFieldsMixin protocol to and from a Document
+"""
+
+ModelLabel: TypeAlias = str
+ObjectId: TypeAlias = str
+
+
+# If `batched` is not available (Python < 3.12), provide a fallback implementation
+try:
+    from itertools import batched  # type: ignore
+except ImportError:
+
+    def batched(iterable, n):
+        if n < 1:
+            raise ValueError("n must be at least one")
+        iterator = iter(iterable)
+        while batch := tuple(islice(iterator, n)):
+            yield batch
+
+
+class ModelKey(str):
+    """A unique identifier for a model instance.
+
+    The string is of the form "<model_label>:<pk>". This can be used as the object_key
+    for Documents.
+    """
+
+    @classmethod
+    def from_instance(cls, instance: models.Model) -> "ModelKey":
+        return cls(f"{instance._meta.label}:{instance.pk}")
+
+    @property
+    def model_label(self) -> ModelLabel:
+        return self.split(":")[0]
+
+    @property
+    def object_id(self) -> ObjectId:
+        return self.split(":")[1]
+
+
+# ###########
+# Classes that allow users to automatically generate documents from their models based on fields specified
+# ###########
+
+
+class EmbeddingField(BaseField):
+    """A field that can be used to specify which fields of a model should be used to generate embeddings"""
+
+    def __init__(self, *args, important=False, **kwargs):
+        self.important = important
+        super().__init__(*args, **kwargs)
+
+
+class EmbeddableFieldsMixin(models.Model):
+    """Mixin for Django models that allows the user to specify which fields should be used to generate embeddings."""
+
+    embedding_fields = []
+
+    class Meta:
+        abstract = True
+
+    @classmethod
+    def _get_embedding_fields(cls) -> list["EmbeddingField"]:
+        embedding_fields = {
+            (type(field), field.field_name): field for field in cls.embedding_fields
+        }
+        return list(embedding_fields.values())
+
+    @classmethod
+    def check(cls, **kwargs):
+        """Extend model checks to include validation of embedding_fields in the
+        same way that Wagtail's Indexed class does it."""
+        errors = super().check(**kwargs)
+        errors.extend(cls._check_embedding_fields(**kwargs))
+        return errors
+
+    @classmethod
+    def _has_field(cls, name):
+        try:
+            cls._meta.get_field(name)
+        except FieldDoesNotExist:
+            return hasattr(cls, name)
+        else:
+            return True
+
+    @classmethod
+    def _check_embedding_fields(cls, **kwargs):
+        errors = []
+        for field in cls._get_embedding_fields():
+            message = "{model}.embedding_fields contains non-existent field '{name}'"
+            if not cls._has_field(field.field_name):
+                errors.append(
+                    checks.Warning(
+                        message.format(model=cls.__name__, name=field.field_name),
+                        obj=cls,
+                        id="wagtailai.WA001",
+                    )
+                )
+        return errors
+
+
+class ModelFromDocumentOperator(FromDocumentOperator[models.Model]):
+    """A class that can convert Documents into model instances"""
+
+    def from_document(self, document: Document) -> models.Model:
+        # Use the first key in the list, which is the most specific model class
+        key = ModelKey(document.object_keys[0])
+        model_class = self._model_class_from_label(key.model_label)
+        try:
+            return model_class.objects.filter(pk=key.object_id).get()
+        except model_class.DoesNotExist as e:
+            raise IndexedTypeFromDocumentError("No object found for document") from e
+
+    def bulk_from_documents(
+        self, documents: Sequence[Document]
+    ) -> Generator[models.Model, None, None]:
+        keys_by_model_label = self._get_keys_by_model_label(documents)
+        objects_by_key = self._get_models_by_key(keys_by_model_label)
+
+        yield from self._get_deduplicated_objects_generator(documents, objects_by_key)
+
+    async def abulk_from_documents(
+        self, documents: Sequence[Document]
+    ) -> AsyncGenerator[models.Model, None]:
+        """A copy of `bulk_from_documents`, but async"""
+        keys_by_model_label = self._get_keys_by_model_label(documents)
+        objects_by_key = await self._aget_models_by_key(keys_by_model_label)
+
+        # N.B. `yield from` cannot be used in async functions, so we have to use a loop
+        for object_from_document in self._get_deduplicated_objects_generator(
+            documents, objects_by_key
+        ):
+            yield object_from_document
+
+    @staticmethod
+    def _model_class_from_label(label: ModelLabel) -> type[models.Model]:
+        model_class = apps.get_model(label)
+
+        if model_class is None:
+            raise ValueError(f"Failed to find model class for {label!r}")
+
+        return model_class
+
+    @staticmethod
+    def _get_keys_by_model_label(
+        documents: Sequence[Document],
+    ) -> dict[ModelLabel, list[ModelKey]]:
+        keys_by_model_label = defaultdict(list)
+        for doc in documents:
+            key = ModelKey(doc.object_keys[0])
+            keys_by_model_label[key.model_label].append(key)
+        return keys_by_model_label
+
+    @staticmethod
+    def _get_deduplicated_objects_generator(
+        documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model]
+    ) -> Generator[models.Model, None, None]:
+        seen_keys = set()  # de-dupe as we go
+        for doc in documents:
+            key = ModelKey(doc.object_keys[0])
+            if key in seen_keys:
+                continue
+            seen_keys.add(key)
+            yield objects_by_key[key]
+
+    @staticmethod
+    def _get_models_by_key(keys_by_model_label: dict) -> dict[ModelKey, models.Model]:
+        """
+        ModelKey keys are required to reliably map data
+        from multiple models. This function loads the models from the database
+        and groups them by such a key.
+        """
+        objects_by_key: dict[ModelKey, models.Model] = {}
+        for model_label, keys in keys_by_model_label.items():
+            model_class = ModelFromDocumentOperator._model_class_from_label(model_label)
+            model_objects = model_class.objects.filter(
+                pk__in=[key.object_id for key in keys]
+            )
+            objects_by_key.update(
+                {ModelKey.from_instance(obj): obj for obj in model_objects}
+            )
+        return objects_by_key
+
+    @staticmethod
+    async def _aget_models_by_key(
+        keys_by_model_label: dict,
+    ) -> dict[ModelKey, models.Model]:
+        """
+        Same as `_get_models_by_key`, but async.
+        """
+        objects_by_key: dict[ModelKey, models.Model] = {}
+        for model_label, keys in keys_by_model_label.items():
+            model_class = ModelFromDocumentOperator._model_class_from_label(model_label)
+            model_objects = model_class.objects.filter(
+                pk__in=[key.object_id for key in keys]
+            )
+            objects_by_key.update(
+                {ModelKey.from_instance(obj): obj async for obj in model_objects}
+            )
+        return objects_by_key
+
+
+class ModelToDocumentOperator(ToDocumentOperator[models.Model]):
+    """A class that can generate Documents from model instances"""
+
+    def __init__(self, object_chunker_operator_class: Type[ObjectChunkerOperator]):
+        self.object_chunker_operator = object_chunker_operator_class()
+
+    @staticmethod
+    def _existing_documents_match(
+        documents: Iterable[Document], splits: list[str]
+    ) -> bool:
+        """Determine whether the documents passed in match the text content passed in"""
+        if not documents:
+            return False
+
+        document_content = {document.content for document in documents}
+
+        return set(splits) == document_content
+
+    @staticmethod
+    def _keys_for_instance(instance: models.Model) -> list[ModelKey]:
+        """Get keys for all the parent classes and the object itself in MRO order"""
+        parent_classes = instance._meta.get_parent_list()
+        keys = [ModelKey(f"{cls._meta.label}:{instance.pk}") for cls in parent_classes]
+        keys = [ModelKey.from_instance(instance), *keys]
+        return keys
+
+    @transaction.atomic
+    def generate_documents(
+        self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend
+    ) -> list[Document]:
+        """Use the AI backend to generate and store Documents for this object"""
+        chunks = list(
+            self.object_chunker_operator.chunk_object(
+                object, chunk_size=embedding_backend.config.token_limit
+            )
+        )
+        documents = Document.objects.for_key(ModelKey.from_instance(object))
+
+        # If the existing embeddings all match on content, we return them
+        # without generating new ones
+        if self._existing_documents_match(documents, chunks):
+            return list(documents)
+
+        # Otherwise we delete all the existing Documents and get new ones
+        documents.delete()
+
+        embedding_vectors = embedding_backend.embed(chunks)
+        generated_documents: MutableSequence[Document] = []
+        for idx, returned_embedding in enumerate(embedding_vectors):
+            chunk = chunks[idx]
+            document = Document.objects.create(
+                object_keys=[str(key) for key in self._keys_for_instance(object)],
+                vector=returned_embedding,
+                content=chunk,
+            )
+            generated_documents.append(document)
+
+        return generated_documents
+
+    @transaction.atomic
+    def bulk_generate_documents(self, objects, *, embedding_backend):
+        objects_by_key = {ModelKey.from_instance(obj): obj for obj in objects}
+        documents = Document.objects.for_keys(list(objects_by_key.keys()))
+
+        documents_by_object_key = defaultdict(list)
+        for document in documents:
+            documents_by_object_key[document.object_keys[0]].append(document)
+
+        objects_to_rebuild = {}
+
+        # Maintain a list of object keys in the order they appear in the chunks
+        # so we can map the embeddings from the backend to the correct object
+        chunk_mapping = []
+
+        # Determine which objects need to be rebuilt
+        for key, object in objects_by_key.items():
+            documents_for_object = documents_by_object_key[key]
+            chunks = list(
+                self.object_chunker_operator.chunk_object(
+                    object, chunk_size=embedding_backend.config.token_limit
+                )
+            )
+
+            if not self._existing_documents_match(documents_for_object, chunks):
+                objects_to_rebuild[key] = {"object": object, "chunks": chunks}
+                chunk_mapping += [key] * len(chunks)
+
+        if not objects_to_rebuild:
+            return documents
+
+        all_chunks = list(
+            chain(*[obj["chunks"] for obj in objects_to_rebuild.values()])
+        )
+
+        embedding_vectors = list(embedding_backend.embed(all_chunks))
+        documents_by_object = defaultdict(list)
+
+        for idx, embedding in enumerate(embedding_vectors):
+            object_key = chunk_mapping[idx]
+            documents_by_object[object_key].append((idx, embedding))
+
+        existing_documents = Document.objects.for_keys(list(documents_by_object.keys()))
+        existing_documents.delete()
+
+        for object_key, documents in documents_by_object.items():
+            for idx, returned_embedding in documents:
+                all_keys = self._keys_for_instance(objects_by_key[object_key])
+                chunk = all_chunks[idx]
+                Document.objects.create(
+                    object_keys=all_keys,
+                    vector=returned_embedding,
+                    content=chunk,
+                )
+
+        # Return every document object, regardless of whether it was rebuilt, retaining
+        # the order they appeared in the original list
+        documents = list(Document.objects.for_keys(list(objects_by_key.keys())))
+        return sorted(
+            documents,
+            key=lambda doc: list(objects_by_key.keys()).index(
+                ModelKey(doc.object_keys[0])
+            ),
+        )
+
+    def to_documents(
+        self, object: models.Model, *, embedding_backend: BaseEmbeddingBackend
+    ) -> Generator[Document, None, None]:
+        yield from self.generate_documents(object, embedding_backend=embedding_backend)
+
+    def bulk_to_documents(
+        self,
+        objects: Iterable[models.Model],
+        *,
+        batch_size: int = 100,
+        embedding_backend: BaseEmbeddingBackend,
+    ) -> Generator[Document, None, None]:
+        batches = list(batched(objects, batch_size))
+        for idx, batch in enumerate(batches):
+            logger.info(f"Generating documents for batch {idx + 1} of {len(batches)}")
+            yield from self.bulk_generate_documents(
+                batch, embedding_backend=embedding_backend
+            )
+
+
+class EmbeddableFieldsObjectChunkerOperator(
+    ObjectChunkerOperator[EmbeddableFieldsMixin]
+):
+    def chunk_object(
+        self, object: EmbeddableFieldsMixin, *, chunk_size: int
+    ) -> list[str]:
+        """Split the contents of a model instance's `embedding_fields` in to smaller chunks"""
+        splittable_content = []
+        important_content = []
+        embedding_fields = object._meta.model._get_embedding_fields()
+
+        for field in embedding_fields:
+            value = field.get_value(object)
+            if value is None:
+                continue
+            if isinstance(value, str):
+                final_value = value
+            else:
+                final_value: str = "\n".join((str(v) for v in value))
+            if field.important:
+                important_content.append(final_value)
+            else:
+                splittable_content.append(final_value)
+
+        text = "\n".join(splittable_content)
+        important_text = "\n".join(important_content)
+        splitter = self._get_text_splitter_class(chunk_size=chunk_size)
+        return [f"{important_text}\n{text}" for text in splitter.split_text(text)]
+
+    @staticmethod
+    def _get_text_splitter_class(chunk_size: int) -> TextSplitterProtocol:
+        length_calculator = NaiveTextSplitterCalculator()
+        return LangchainRecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=100,
+            length_function=length_calculator.get_splitter_length,
+        )
+
+
+class EmbeddableFieldsDocumentConverter(DocumentConverter):
+    """Implementation of DocumentConverter that knows how to convert a model instance using the
+    EmbeddableFieldsMixin to and from a Document.
+    """
+
+    to_document_operator_class = ModelToDocumentOperator
+    from_document_operator_class = ModelFromDocumentOperator
+    object_chunker_operator_class = EmbeddableFieldsObjectChunkerOperator
+
+
+# ###########
+# VectorIndex mixins which add model-specific behaviour
+# ###########
+
+if TYPE_CHECKING:
+    MixinBase = DocumentRetrievalVectorIndexMixinProtocol
+else:
+    MixinBase = object
+
+
+class EmbeddableFieldsVectorIndexMixin(MixinBase):
+    """A Mixin for VectorIndex which indexes the results of querysets of EmbeddableFieldsMixin models"""
+
+    querysets: ClassVar[Sequence[models.QuerySet]]
+
+    def _get_querysets(self) -> Sequence[models.QuerySet]:
+        return self.querysets
+
+    def get_converter_class(self) -> type[EmbeddableFieldsDocumentConverter]:
+        return EmbeddableFieldsDocumentConverter
+
+    def get_converter(self) -> EmbeddableFieldsDocumentConverter:
+        return self.get_converter_class()()
+
+    def get_documents(self) -> Iterable[Document]:
+        querysets = self._get_querysets()
+        all_documents = []
+
+        for queryset in querysets:
+            # We need to consume the generator here to ensure that the
+            # Document models are created, even if it is not consumed
+            # by the caller
+            all_documents += list(
+                self.get_converter().bulk_to_documents(
+                    queryset, embedding_backend=self.get_embedding_backend()
+                )
+            )
+        return all_documents
+
+
+class PageEmbeddableFieldsVectorIndexMixin(EmbeddableFieldsVectorIndexMixin):
+    """A mixin for VectorIndex for use with Wagtail pages that automatically
+    restricts indexed models to live pages."""
+
+    querysets: Sequence[PageQuerySet]
+
+    def _get_querysets(self) -> list[PageQuerySet]:
+        qs_list = super()._get_querysets()
+
+        # Technically a manager instance, not a queryset, but we want to use the custom
+        # methods.
+        return [cast(PageQuerySet, qs).live() for qs in qs_list]
+
+
+# ###########
+# Classes related to automatic generation of indexes for models
+# ###########
+
+
+def camel_case(snake_str: str):
+    """Convert a snake_case string to CamelCase"""
+    parts = snake_str.split("_")
+    return "".join(map(str.title, parts))
+
+
+def build_vector_index_base_for_storage_provider(
+    storage_provider_alias: str = "default",
+):
+    """Build a VectorIndex base class for a given storage provider alias.
+
+    e.g. If WAGTAIL_VECTOR_INDEX_STORAGE_PROVIDERS includes a provider with alias "default" referencing the PgvectorStorageProvider,
+    this function will return a class that is a subclass of PgvectorIndexMixin and VectorIndex.
+    """
+
+    storage_provider = get_storage_provider(storage_provider_alias)
+    alias_camel = camel_case(storage_provider_alias)
+    return type(
+        f"{alias_camel}VectorIndex", (storage_provider.index_mixin, VectorIndex), {}
+    )
+
+
+# A VectorIndex built from whatever mixin belongs to the storage provider with the "default" alias
+DefaultStorageVectorIndex = build_vector_index_base_for_storage_provider("default")
+
+
+class GeneratedIndexMixin(models.Model):
+    """Mixin for Django models that automatically generates and registers a VectorIndex for the model.
+ + The model can still have custom VectorIndex classes registered if needed.""" + + vector_index_class: ClassVar[Optional[type[VectorIndex]]] = None + + class Meta: + abstract = True + + @classmethod + def generated_index_class_name(cls): + """Return the class name to be used for the index generated by this mixin""" + return f"{cls.__name__}Index" + + @classmethod + def build_vector_index(cls) -> VectorIndex: + """Build a VectorIndex instance for this model""" + + class_list = () + # If the user has specified a custom `vector_index_class`, use that + if cls.vector_index_class: + class_list = (cls.vector_index_class,) + else: + storage_provider = get_storage_provider("default") + base_cls = VectorIndex + storage_mixin_cls = storage_provider.index_mixin + # If the model is a Wagtail Page, use a special PageEmbeddableFieldsVectorIndexMixin + if issubclass(cls, Page): + mixin_cls = PageEmbeddableFieldsVectorIndexMixin + # Otherwise use the standard EmbeddableFieldsVectorIndexMixin + else: + mixin_cls = EmbeddableFieldsVectorIndexMixin + class_list = ( + mixin_cls, + storage_mixin_cls, + base_cls, + ) + + return cast( + type[VectorIndex], + type( + cls.generated_index_class_name(), + class_list, + { + "querysets": [cls.objects.all()], + }, + ), + )() + + @classproperty + def vector_index(cls): + """Get a vector index instance for this model""" + + return registry[cls.generated_index_class_name()] + + +class VectorIndexedMixin(EmbeddableFieldsMixin, GeneratedIndexMixin, models.Model): + """Model mixin which adds both the embeddable fields behaviour and the automatic index behaviour to a model.""" + + class Meta: + abstract = True + + +def register_indexed_models(): + """Discover and register all models that are a subclass of GeneratedIndexMixin.""" + indexed_models = [ + model + for model in apps.get_models() + if issubclass(model, GeneratedIndexMixin) and not model._meta.abstract + ] + for model in indexed_models: + registry.register_index(model.build_vector_index()) diff --git a/src/wagtail_vector_index/storage/models.py b/src/wagtail_vector_index/storage/models.py index e467541..a9dbd37 100644 --- a/src/wagtail_vector_index/storage/models.py +++ b/src/wagtail_vector_index/storage/models.py @@ -1,564 +1,53 @@ -from collections import defaultdict -from collections.abc import ( - AsyncGenerator, - Generator, - Iterable, - MutableSequence, - Sequence, -) -from typing import TYPE_CHECKING, ClassVar, Optional, TypeAlias, TypeVar, cast +import operator +from functools import reduce +from typing import cast -from django.apps import apps -from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation -from django.contrib.contenttypes.models import ContentType -from django.core import checks -from django.core.exceptions import FieldDoesNotExist -from django.db import models, transaction -from django.utils.functional import classproperty # type: ignore -from wagtail.models import Page -from wagtail.query import PageQuerySet -from wagtail.search.index import BaseField +from django.db import connection, models +from django.db.models import Q -from wagtail_vector_index.ai_utils.backends.base import BaseEmbeddingBackend -from wagtail_vector_index.ai_utils.text_splitting.langchain import ( - LangchainRecursiveCharacterTextSplitter, -) -from wagtail_vector_index.ai_utils.text_splitting.naive import ( - NaiveTextSplitterCalculator, -) -from wagtail_vector_index.ai_utils.types import TextSplitterProtocol -from wagtail_vector_index.storage import get_storage_provider, registry -from 
wagtail_vector_index.storage.base import ( - Document, - DocumentRetrievalVectorIndexMixinProtocol, - VectorIndex, -) -from wagtail_vector_index.storage.exceptions import IndexedTypeFromDocumentError -""" Everything related to indexing Django models is in this file. - -This includes: - -- The Embedding Django model, which is used to store embeddings for model instances in the database -- The EmbeddableFieldsMixin, which is a mixin for Django models that lets user define which fields should be used to generate embeddings -- The EmbeddableFieldsVectorIndexMixin, which is a VectorIndex mixin that expects EmbeddableFieldsMixin models -- The EmbeddableFieldsDocumentConverter, which is a DocumentConverter that knows how to convert a model instance using the EmbeddableFieldsMixin protocol to and from a Document -""" - -ContentTypeId: TypeAlias = str -ObjectId: TypeAlias = str -ModelKey: TypeAlias = tuple[ContentTypeId, ObjectId] - - -class Embedding(models.Model): - """Stores an embedding for a model instance""" - - content_type = models.ForeignKey( - ContentType, on_delete=models.CASCADE, related_name="+" - ) - content_type_id: int - base_content_type = models.ForeignKey( - ContentType, on_delete=models.CASCADE, related_name="+" - ) - object_id = models.CharField( - max_length=255, - ) - content_object = GenericForeignKey( - "content_type", "object_id", for_concrete_model=False - ) - vector = models.JSONField() - content = models.TextField() - - def __str__(self): - return f"Embedding for {self.object_id}" - - @classmethod - def _get_base_content_type(cls, model_or_object): - if parents := model_or_object._meta.get_parent_list(): - return ContentType.objects.get_for_model( - parents[-1], for_concrete_model=False - ) +class DocumentQuerySet(models.QuerySet): + def for_key(self, object_key: str): + if connection.vendor != "sqlite": + return self.filter(object_keys__contains=[object_key]) else: - return ContentType.objects.get_for_model( - model_or_object, for_concrete_model=False - ) - - @classmethod - def from_instance(cls, instance: models.Model) -> "Embedding": - """Create an Embedding instance for a model instance""" - content_type = ContentType.objects.get_for_model(instance) - return Embedding( - content_type=content_type, - base_content_type=cls._get_base_content_type(instance), - object_id=instance.pk, - ) - - @classmethod - def get_for_instance(cls, instance: models.Model): - """Get all Embedding instances that are related to a model instance""" - content_type = ContentType.objects.get_for_model(instance) - return Embedding.objects.filter( - content_type=content_type, object_id=instance.pk - ) - - def to_document(self) -> Document: - return Document( - vector=self.vector, - embedding_pk=self.pk, - metadata={ - "object_id": str(self.object_id), - "content_type_id": str(self.content_type_id), - "content": self.content, - }, - ) - - -# ########### -# Classes that allow users to automatically generate documents from their models based on fields specified -# ########### - - -class EmbeddingField(BaseField): - """A field that can be used to specify which fields of a model should be used to generate embeddings""" - - def __init__(self, *args, important=False, **kwargs): - self.important = important - super().__init__(*args, **kwargs) - - -class EmbeddableFieldsMixin(models.Model): - """Mixin for Django models that allows the user to specify which fields should be used to generate embeddings.""" + # SQLite doesn't support the __contains lookup for JSON fields + # so we use icontains which just 
does a string search + return self.filter(object_keys__icontains=object_key) - embedding_fields = [] - embeddings = GenericRelation( - Embedding, content_type_field="content_type", for_concrete_model=False - ) - - class Meta: - abstract = True + def for_keys(self, object_keys: list[str]): + q_objs = [Q(object_keys__icontains=object_key) for object_key in object_keys] + return self.filter(reduce(operator.or_, q_objs)) @classmethod - def _get_embedding_fields(cls) -> list["EmbeddingField"]: - embedding_fields = { - (type(field), field.field_name): field for field in cls.embedding_fields - } - return list(embedding_fields.values()) + def as_manager(cls) -> "DocumentManager": + return cast(DocumentManager, super().as_manager()) - @classmethod - def check(cls, **kwargs): - """Extend model checks to include validation of embedding_fields in the - same way that Wagtail's Indexed class does it.""" - errors = super().check(**kwargs) - errors.extend(cls._check_embedding_fields(**kwargs)) - return errors - @classmethod - def _has_field(cls, name): - try: - cls._meta.get_field(name) - except FieldDoesNotExist: - return hasattr(cls, name) - else: - return True +class DocumentManager(models.Manager["Document"]): + # Workaround for typing issues + def for_key(self, object_key: str) -> DocumentQuerySet: ... - @classmethod - def _check_embedding_fields(cls, **kwargs): - errors = [] - for field in cls._get_embedding_fields(): - message = "{model}.embedding_fields contains non-existent field '{name}'" - if not cls._has_field(field.field_name): - errors.append( - checks.Warning( - message.format(model=cls.__name__, name=field.field_name), - obj=cls, - id="wagtailai.WA001", - ) - ) - return errors + def for_keys(self, object_keys: list[str]) -> DocumentQuerySet: ... -IndexedType = TypeVar("IndexedType") +class Document(models.Model): + """Stores an embedding for an arbitrary chunk""" + object_keys = models.JSONField(default=list) + vector = models.JSONField() + content = models.TextField() + metadata = models.JSONField(default=dict) -class DocumentToModelMixin: - """A mixin for DocumentConverter classes that need to efficiently convert Documents - into model instances of the relevant type. 
- """ + objects: DocumentManager = DocumentQuerySet.as_manager() - @staticmethod - def _model_class_from_ctid(id: str) -> type[models.Model]: - ct = ContentType.objects.get_for_id(int(id)) - model_class = ct.model_class() - if model_class is None: - raise ValueError(f"Failed to find model class for {ct!r}") - return model_class + def __str__(self): + keys = ", ".join(self.object_keys) + return f"Document for {keys}" @classmethod - async def _amodel_class_from_ctid(cls, id: str) -> type[models.Model]: - ct = await cls._aget_content_type_for_id(int(id)) - model_class = ct.model_class() - if model_class is None: - raise ValueError(f"Failed to find model class for {ct!r}") - return model_class - - def from_document(self, document: Document) -> models.Model: - model_class = self._model_class_from_ctid(document.metadata["content_type_id"]) - try: - return model_class.objects.filter(pk=document.metadata["object_id"]).get() - except model_class.DoesNotExist as e: - raise IndexedTypeFromDocumentError("No object found for document") from e - - def bulk_from_documents( - self, documents: Iterable[Document] - ) -> Generator[models.Model, None, None]: - documents = tuple(documents) - - ids_by_content_type = self._get_ids_by_content_type(documents) - objects_by_key = self._get_models_by_key(ids_by_content_type) - - yield from self._get_deduplicated_objects_generator(documents, objects_by_key) - - async def abulk_from_documents( - self, documents: Iterable[Document] - ) -> AsyncGenerator[models.Model, None]: - """A copy of `bulk_from_documents`, but async""" - # Force evaluate generators to allow value to be reused - documents = tuple(documents) - - ids_by_content_type = self._get_ids_by_content_type(documents) - objects_by_key = await self._aget_models_by_key(ids_by_content_type) - - # N.B. `yield from` cannot be used in async functions, so we have to use a loop - for object_from_document in self._get_deduplicated_objects_generator( - documents, objects_by_key - ): - yield object_from_document - - @staticmethod - def _get_ids_by_content_type( - documents: Sequence[Document], - ) -> dict[ContentTypeId, list[ObjectId]]: - ids_by_content_type = defaultdict(list) - for doc in documents: - ids_by_content_type[doc.metadata["content_type_id"]].append( - doc.metadata["object_id"] - ) - return ids_by_content_type - - @staticmethod - def _get_deduplicated_objects_generator( - documents: Sequence[Document], objects_by_key: dict[ModelKey, models.Model] - ) -> Generator[models.Model, None, None]: - seen_keys = set() # de-dupe as we go - for doc in documents: - key = (doc.metadata["content_type_id"], doc.metadata["object_id"]) - if key in seen_keys: - continue - seen_keys.add(key) - yield objects_by_key[key] - - def _get_models_by_key( - self, ids_by_content_type: dict - ) -> dict[ModelKey, models.Model]: - """ - (content_type_id, object_id) combo keys are required to reliably map data - from multiple models. This function loads the models from the database - and groups them by such a key. - """ - objects_by_key: dict[ModelKey, models.Model] = {} - for content_type_id, ids in ids_by_content_type.items(): - model_class = self._model_class_from_ctid(content_type_id) - model_objects = model_class.objects.filter(pk__in=ids) - objects_by_key.update( - {(content_type_id, str(obj.pk)): obj for obj in model_objects} - ) - return objects_by_key - - async def _aget_models_by_key( - self, ids_by_content_type: dict - ) -> dict[ModelKey, models.Model]: - """ - Same as `_get_models_by_key`, but async. 
- """ - objects_by_key: dict[ModelKey, models.Model] = {} - for content_type_id, ids in ids_by_content_type.items(): - model_class = await self._amodel_class_from_ctid(content_type_id) - model_objects = model_class.objects.filter(pk__in=ids) - objects_by_key.update( - {(content_type_id, str(obj.pk)): obj async for obj in model_objects} - ) - return objects_by_key - - @staticmethod - async def _aget_content_type_for_id(id: int) -> ContentType: - """ - Same as `ContentTypeManager.get_for_id`, but async. - """ - manager = ContentType.objects - try: - ct = manager._cache[manager.db][id] # type: ignore[reportAttributeAccessIssue] - except KeyError: - ct = await manager.aget(pk=id) - manager._add_to_cache(manager.db, ct) # type: ignore[reportAttributeAccessIssue] - return ct - - -class EmbeddableFieldsDocumentConverter(DocumentToModelMixin): - """Implementation of DocumentConverter that knows how to convert a model instance using the - EmbeddableFieldsMixin to and from a Document. - - Stores and retrieves embeddings from an Embedding model.""" - - def _get_split_content( - self, object: EmbeddableFieldsMixin, *, chunk_size: int - ) -> list[str]: - """Split the contents of a model instance's `embedding_fields` in to smaller chunks""" - splittable_content = [] - important_content = [] - embedding_fields = object._meta.model._get_embedding_fields() - - for field in embedding_fields: - value = field.get_value(object) - if value is None: - continue - if isinstance(value, str): - final_value = value - else: - final_value: str = "\n".join((str(v) for v in value)) - if field.important: - important_content.append(final_value) - else: - splittable_content.append(final_value) - - text = "\n".join(splittable_content) - important_text = "\n".join(important_content) - splitter = self._get_text_splitter_class(chunk_size=chunk_size) - return [f"{important_text}\n{text}" for text in splitter.split_text(text)] - - def _get_text_splitter_class(self, chunk_size: int) -> TextSplitterProtocol: - length_calculator = NaiveTextSplitterCalculator() - return LangchainRecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=100, - length_function=length_calculator.get_splitter_length, - ) - - def _existing_embeddings_match( - self, embeddings: Iterable[Embedding], splits: list[str] - ) -> bool: - """Determine whether the embeddings passed in match the text content passed in""" - if not embeddings: - return False - - embedding_content = {embedding.content for embedding in embeddings} - - return set(splits) == embedding_content - - @transaction.atomic - def generate_embeddings( - self, object: EmbeddableFieldsMixin, *, embedding_backend: BaseEmbeddingBackend - ) -> list[Embedding]: - """Use the AI backend to generate and store embeddings for this object""" - splits = self._get_split_content( - object, chunk_size=embedding_backend.config.token_limit + def from_keys(cls, object_keys: list[str]) -> "Document": + """Create a Document instance for a list of object keys""" + return Document( + object_keys=object_keys, ) - embeddings = Embedding.get_for_instance(object) - - # If the existing embeddings all match on content, we return them - # without generating new ones - if self._existing_embeddings_match(embeddings, splits): - return list(embeddings) - - # Otherwise we delete all the existing embeddings and get new ones - embeddings.delete() - - embedding_vectors = embedding_backend.embed(splits) - generated_embeddings: MutableSequence[Embedding] = [] - for idx, returned_embedding in enumerate(embedding_vectors): 
- split = splits[idx] - embedding = Embedding.from_instance(object) - embedding.vector = returned_embedding - embedding.content = split - embedding.save() - generated_embeddings.append(embedding) - - return generated_embeddings - - def to_documents( - self, object: EmbeddableFieldsMixin, *, embedding_backend: BaseEmbeddingBackend - ) -> Generator[Document, None, None]: - for embedding in self.generate_embeddings( - object, embedding_backend=embedding_backend - ): - yield embedding.to_document() - - def bulk_to_documents( - self, - objects: Iterable[EmbeddableFieldsMixin], - *, - embedding_backend: BaseEmbeddingBackend, - ) -> Generator[Document, None, None]: - # TODO: Implement a more efficient bulk embedding approach - for object in objects: - yield from self.to_documents(object, embedding_backend=embedding_backend) - - -# ########### -# VectorIndex mixins which add model-specific behaviour -# ########### - -if TYPE_CHECKING: - MixinBase = DocumentRetrievalVectorIndexMixinProtocol -else: - MixinBase = object - - -class EmbeddableFieldsVectorIndexMixin(MixinBase): - """A Mixin for VectorIndex which indexes the results of querysets of EmbeddableFieldsMixin models""" - - querysets: ClassVar[Sequence[models.QuerySet]] - - def _get_querysets(self) -> Sequence[models.QuerySet]: - return self.querysets - - def get_converter_class(self) -> type[EmbeddableFieldsDocumentConverter]: - return EmbeddableFieldsDocumentConverter - - def get_converter(self) -> EmbeddableFieldsDocumentConverter: - return self.get_converter_class()() - - def get_documents(self) -> Iterable[Document]: - querysets = self._get_querysets() - all_documents = [] - - for queryset in querysets: - instances = queryset.prefetch_related("embeddings") - # We need to consume the generator here to ensure that the - # Embedding models are created, even if it is not consumed - # by the caller - all_documents += list( - self.get_converter().bulk_to_documents( - instances, embedding_backend=self.get_embedding_backend() - ) - ) - return all_documents - - -class PageEmbeddableFieldsVectorIndexMixin(EmbeddableFieldsVectorIndexMixin): - """A mixin for VectorIndex for use with Wagtail pages that automatically - restricts indexed models to live pages.""" - - querysets: Sequence[PageQuerySet] - - def _get_querysets(self) -> list[PageQuerySet]: - qs_list = super()._get_querysets() - - # Technically a manager instance, not a queryset, but we want to use the custom - # methods. - return [cast(PageQuerySet, qs).live() for qs in qs_list] - - -# ########### -# Classes related to automatic generation of indexes for models -# ########### - - -def camel_case(snake_str: str): - """Convert a snake_case string to CamelCase""" - parts = snake_str.split("_") - return "".join(*map(str.title, parts)) - - -def build_vector_index_base_for_storage_provider( - storage_provider_alias: str = "default", -): - """Build a VectorIndex base class for a given storage provider alias. - - e.g. If WAGATAIL_VECTOR_INDEX_STORAGE_PROVIDERS includes a provider with alias "default" referencing the PgvectorStorageProvider, - this function will return a class that is a subclass of PgvectorIndexMixin and VectorIndex. 
- """ - - storage_provider = get_storage_provider(storage_provider_alias) - alias_camel = camel_case(storage_provider_alias) - return type( - f"{alias_camel}VectorIndex", (storage_provider.index_mixin, VectorIndex), {} - ) - - -# A VectorIndex built from whatever mixin belongs to the storage provider with the "default" alias -DefaultStorageVectorIndex = build_vector_index_base_for_storage_provider("default") - - -class GeneratedIndexMixin(models.Model): - """Mixin for Django models that automatically generates and registers a VectorIndex for the model. - - The model can still have custom VectorIndex classes registered if needed.""" - - vector_index_class: ClassVar[Optional[type[VectorIndex]]] = None - - class Meta: - abstract = True - - @classmethod - def generated_index_class_name(cls): - """Return the class name to be used for the index generated by this mixin""" - return f"{cls.__name__}Index" - - @classmethod - def build_vector_index(cls) -> VectorIndex: - """Build a VectorIndex instance for this model""" - - class_list = () - # If the user has specified a custom `vector_index_class`, use that - if cls.vector_index_class: - class_list = (cls.vector_index_class,) - else: - storage_provider = get_storage_provider("default") - base_cls = VectorIndex - storage_mixin_cls = storage_provider.index_mixin - # If the model is a Wagtail Page, use a special PageEmbeddableFieldsVectorIndexMixin - if issubclass(cls, Page): - mixin_cls = PageEmbeddableFieldsVectorIndexMixin - # Otherwise use the standard EmbeddableFieldsVectorIndexMixin - else: - mixin_cls = EmbeddableFieldsVectorIndexMixin - class_list = ( - mixin_cls, - storage_mixin_cls, - base_cls, - ) - - return cast( - type[VectorIndex], - type( - cls.generated_index_class_name(), - class_list, - { - "querysets": [cls.objects.all()], - }, - ), - )() - - @classproperty - def vector_index(cls): - """Get a vector index instance for this model""" - - return registry[cls.generated_index_class_name()] - - -class VectorIndexedMixin(EmbeddableFieldsMixin, GeneratedIndexMixin, models.Model): - """Model mixin which adds both the embeddable fields behaviour and the automatic index behaviour to a model.""" - - class Meta: - abstract = True - - -def register_indexed_models(): - """Discover and register all models that are a subclass of GeneratedIndexMixin.""" - indexed_models = [ - model - for model in apps.get_models() - if issubclass(model, GeneratedIndexMixin) and not model._meta.abstract - ] - for model in indexed_models: - registry.register_index(model.build_vector_index()) diff --git a/src/wagtail_vector_index/storage/numpy/provider.py b/src/wagtail_vector_index/storage/numpy/provider.py index 455c917..02de3f5 100644 --- a/src/wagtail_vector_index/storage/numpy/provider.py +++ b/src/wagtail_vector_index/storage/numpy/provider.py @@ -6,7 +6,6 @@ import numpy as np from wagtail_vector_index.storage.base import ( - Document, StorageProvider, StorageVectorIndexMixinProtocol, ) @@ -19,6 +18,8 @@ class ProviderConfig: ... 
if TYPE_CHECKING: + from wagtail_vector_index.storage.models import Document + MixinBase = StorageVectorIndexMixinProtocol["NumpyStorageProvider"] else: MixinBase = object @@ -28,7 +29,7 @@ class NumpyIndexMixin(MixinBase): def rebuild_index(self) -> None: self.get_documents() - def upsert(self, *, documents: Iterable[Document]) -> None: + def upsert(self, *, documents: Iterable["Document"]) -> None: pass def delete(self, *, document_ids: Sequence[str]) -> None: @@ -40,7 +41,7 @@ def get_similar_documents( *, limit: int = 5, similarity_threshold: float = 0.0, - ) -> Generator[Document, None, None]: + ) -> Generator["Document", None, None]: similarities = [] for document in self.get_documents(): cosine_similarity = ( diff --git a/src/wagtail_vector_index/storage/pgvector/migrations/0002_initial.py b/src/wagtail_vector_index/storage/pgvector/migrations/0002_initial.py index dd4810a..c720140 100644 --- a/src/wagtail_vector_index/storage/pgvector/migrations/0002_initial.py +++ b/src/wagtail_vector_index/storage/pgvector/migrations/0002_initial.py @@ -37,7 +37,7 @@ class Migration(migrations.Migration): models.ForeignKey( on_delete=django.db.models.deletion.CASCADE, related_name="+", - to="wagtail_vector_index.embedding", + to="wagtail_vector_index.document", ), ), ], diff --git a/src/wagtail_vector_index/storage/pgvector/migrations/0003_alter_pgvectorembedding_embedding.py b/src/wagtail_vector_index/storage/pgvector/migrations/0003_alter_pgvectorembedding_embedding.py new file mode 100644 index 0000000..34e4510 --- /dev/null +++ b/src/wagtail_vector_index/storage/pgvector/migrations/0003_alter_pgvectorembedding_embedding.py @@ -0,0 +1,26 @@ +# Generated by Django 5.0.1 on 2024-09-06 10:26 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("pgvector", "0002_initial"), + ( + "wagtail_vector_index", + "0002_rename_embedding_model", + ), + ] + + operations = [ + migrations.AlterField( + model_name="pgvectorembedding", + name="embedding", + field=models.ForeignKey( + on_delete=django.db.models.deletion.CASCADE, + related_name="+", + to="wagtail_vector_index.document", + ), + ), + ] diff --git a/src/wagtail_vector_index/storage/pgvector/migrations/0004_rename_pgvector_embedding_col.py b/src/wagtail_vector_index/storage/pgvector/migrations/0004_rename_pgvector_embedding_col.py new file mode 100644 index 0000000..c14d0af --- /dev/null +++ b/src/wagtail_vector_index/storage/pgvector/migrations/0004_rename_pgvector_embedding_col.py @@ -0,0 +1,29 @@ +# Generated by Django 5.0.1 on 2024-09-13 12:32 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("pgvector", "0003_alter_pgvectorembedding_embedding"), + ("wagtail_vector_index", "0003_adjust_document_fields"), + ] + + operations = [ + migrations.RemoveConstraint( + model_name="pgvectorembedding", + name="unique_pgvector_embedding_per_index_and_dimensions", + ), + migrations.RenameField( + model_name="pgvectorembedding", + old_name="embedding", + new_name="document", + ), + migrations.AddConstraint( + model_name="pgvectorembedding", + constraint=models.UniqueConstraint( + fields=("document", "index_name", "embedding_output_dimensions"), + name="unique_pgvector_embedding_per_index_and_dimensions", + ), + ), + ] diff --git a/src/wagtail_vector_index/storage/pgvector/models.py b/src/wagtail_vector_index/storage/pgvector/models.py index 0ebb92b..eaec952 100644 --- 
a/src/wagtail_vector_index/storage/pgvector/models.py +++ b/src/wagtail_vector_index/storage/pgvector/models.py @@ -63,8 +63,8 @@ class PgvectorEmbeddingManager(models.Manager.from_queryset(PgvectorEmbeddingQue class PgvectorEmbedding(models.Model): - embedding = models.ForeignKey( - "wagtail_vector_index.Embedding", on_delete=models.CASCADE, related_name="+" + document = models.ForeignKey( + "wagtail_vector_index.Document", on_delete=models.CASCADE, related_name="+" ) vector = VectorField() embedding_output_dimensions = models.PositiveIntegerField(db_index=True) @@ -75,7 +75,7 @@ class PgvectorEmbedding(models.Model): class Meta: constraints = [ models.UniqueConstraint( - fields=["embedding", "index_name", "embedding_output_dimensions"], + fields=["document", "index_name", "embedding_output_dimensions"], name="unique_pgvector_embedding_per_index_and_dimensions", ) ] @@ -87,4 +87,4 @@ class Meta: # https://github.com/pgvector/pgvector-python/tree/master#django def __str__(self) -> str: - return "pgvector embedding for {}".format(self.embedding) + return "pgvector embedding for {}".format(self.document) diff --git a/src/wagtail_vector_index/storage/pgvector/provider.py b/src/wagtail_vector_index/storage/pgvector/provider.py index e5b0b13..cd95149 100644 --- a/src/wagtail_vector_index/storage/pgvector/provider.py +++ b/src/wagtail_vector_index/storage/pgvector/provider.py @@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, cast from wagtail_vector_index.storage.base import ( - Document, StorageProvider, StorageVectorIndexMixinProtocol, ) @@ -18,6 +17,8 @@ from .types import DistanceMethod if TYPE_CHECKING: + from wagtail_vector_index.storage.models import Document + from .models import PgvectorEmbedding, PgvectorEmbeddingQuerySet MixinBase = StorageVectorIndexMixinProtocol["PgvectorStorageProvider"] @@ -61,7 +62,7 @@ def rebuild_index(self) -> None: self.clear() self.upsert(documents=self.get_documents()) - def upsert(self, *, documents: Iterable[Document]) -> None: + def upsert(self, *, documents: Iterable["Document"]) -> None: counter = 0 objs_to_create: MutableSequence["PgvectorEmbedding"] = [] for document in documents: @@ -81,21 +82,19 @@ def clear(self): def get_similar_documents( self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0 - ) -> Generator[Document, None, None]: + ) -> Generator["Document", None, None]: for pgvector_embedding in self._get_similar_documents_queryset( query_vector, limit=limit, similarity_threshold=similarity_threshold ).iterator(): - embedding = pgvector_embedding.embedding - yield embedding.to_document() + yield pgvector_embedding.document async def aget_similar_documents( self, query_vector, *, limit: int = 5, similarity_threshold: float = 0.0 - ) -> AsyncGenerator[Document, None]: + ) -> AsyncGenerator["Document", None]: async for pgvector_embedding in self._get_similar_documents_queryset( query_vector, limit=limit, similarity_threshold=similarity_threshold ): - embedding = pgvector_embedding.embedding - yield embedding.to_document() + yield pgvector_embedding.document def _get_queryset(self) -> "PgvectorEmbeddingQuerySet": # objects is technically a Manager instance but we want to use the custom @@ -109,7 +108,7 @@ def _get_similar_documents_queryset( ) -> "PgvectorEmbeddingQuerySet": queryset = ( self._get_queryset() - .select_related("embedding") + .select_related("document") .filter(embedding_output_dimensions=len(query_vector)) .order_by_distance( query_vector, @@ -130,9 +129,9 @@ def _bulk_create(self, embeddings: 
Sequence["PgvectorEmbedding"]) -> None: ignore_conflicts=self.bulk_create_ignore_conflicts, ) - def _document_to_embedding(self, document: Document) -> "PgvectorEmbedding": + def _document_to_embedding(self, document: "Document") -> "PgvectorEmbedding": return _embedding_model()( - embedding_id=document.embedding_pk, + document_id=document.pk, embedding_output_dimensions=len(document.vector), vector=document.vector, index_name=type(self).__name__, diff --git a/src/wagtail_vector_index/storage/qdrant/provider.py b/src/wagtail_vector_index/storage/qdrant/provider.py index 504658b..6ac479b 100644 --- a/src/wagtail_vector_index/storage/qdrant/provider.py +++ b/src/wagtail_vector_index/storage/qdrant/provider.py @@ -44,7 +44,7 @@ def rebuild_index(self) -> None: def upsert(self, *, documents: Iterable[Document]) -> None: points = [ qdrant_models.PointStruct( - id=document.embedding_pk, + id=document.pk, vector=document.vector, payload=document.metadata, ) diff --git a/src/wagtail_vector_index/storage/weaviate/provider.py b/src/wagtail_vector_index/storage/weaviate/provider.py index 91fbc37..12659ac 100644 --- a/src/wagtail_vector_index/storage/weaviate/provider.py +++ b/src/wagtail_vector_index/storage/weaviate/provider.py @@ -52,7 +52,7 @@ def upsert(self, *, documents: Iterable[Document]) -> None: batch.add_data_object( { "metadata": json.dumps(document.metadata), - "embedding_pk": document.embedding_pk, + "embedding_pk": document.pk, }, self.index_name, vector=document.vector, diff --git a/tests/factories.py b/tests/factories.py index dc193ea..fdbe608 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -2,7 +2,7 @@ import wagtail_factories from faker import Faker from testapp.models import DifferentPage, ExampleModel, ExamplePage -from wagtail_vector_index.storage.models import Embedding +from wagtail_vector_index.storage.models import Document fake = Faker() @@ -30,9 +30,10 @@ class Meta: body = factory.LazyFunction(lambda: "\n".join(fake.paragraphs())) -class EmbeddingFactory(factory.django.DjangoModelFactory): +class DocumentFactory(factory.django.DjangoModelFactory): class Meta: - model = Embedding + model = Document vector = factory.LazyFunction(lambda: [fake.pyfloat() for _ in range(300)]) content = factory.LazyFunction(lambda: "\n".join(fake.paragraphs())) + object_keys = factory.Iterator([fake.uuid4() for _ in range(3)]) diff --git a/tests/test_django_converter.py b/tests/test_django_converter.py new file mode 100644 index 0000000..768cd91 --- /dev/null +++ b/tests/test_django_converter.py @@ -0,0 +1,313 @@ +import factory +import pytest +from factories import ( + DifferentPageFactory, + DocumentFactory, + ExampleModelFactory, + ExamplePageFactory, +) +from faker import Faker +from testapp.models import DifferentPage, ExamplePage +from wagtail_vector_index.ai import get_embedding_backend +from wagtail_vector_index.storage.django import ( + EmbeddableFieldsDocumentConverter, + EmbeddableFieldsObjectChunkerOperator, + EmbeddingField, + ModelFromDocumentOperator, + ModelLabel, + ModelToDocumentOperator, +) + +fake = Faker() + + +class TestChunking: + def test_get_chunks_splits_content_into_multiple_chunks( + self, patch_embedding_fields + ): + with patch_embedding_fields(ExamplePage, [EmbeddingField("body")]): + body = fake.text(max_nb_chars=1000) + instance = ExamplePageFactory.build(title="Important Title", body=body) + chunker = EmbeddableFieldsObjectChunkerOperator() + chunks = chunker.chunk_object(instance, chunk_size=100) + assert len(chunks) > 1 + + def 
test_get_chunks_adds_important_field_to_each_chunk( + self, patch_embedding_fields + ): + with patch_embedding_fields( + ExamplePage, + [EmbeddingField("title", important=True), EmbeddingField("body")], + ): + body = fake.text(max_nb_chars=200) + instance = ExamplePageFactory.build(title="Important Title", body=body) + chunker = EmbeddableFieldsObjectChunkerOperator() + chunks = chunker.chunk_object(instance, chunk_size=100) + assert all(chunk.startswith(instance.title) for chunk in chunks) + + +class TestFromDocument: + def test_extract_model_class_from_label(self): + label = ModelLabel("testapp.ExamplePage") + model_class = ModelFromDocumentOperator._model_class_from_label(label) + assert model_class == ExamplePage + + @pytest.mark.django_db + def test_get_keys_by_model_label(self): + example_pages = ExamplePageFactory.create_batch(3) + different_pages = DifferentPageFactory.create_batch(3) + documents = DocumentFactory.create_batch( + 6, + object_keys=factory.Iterator( + [[f"testapp.ExamplePage:{page.pk}"] for page in example_pages] + + [[f"testapp.DifferentPage:{page.pk}"] for page in different_pages] + ), + ) + keys_by_model_label = ModelFromDocumentOperator._get_keys_by_model_label( + documents + ) + + assert len(keys_by_model_label) == 2 + assert "testapp.ExamplePage" in keys_by_model_label + assert "testapp.DifferentPage" in keys_by_model_label + assert len(keys_by_model_label["testapp.ExamplePage"]) == 3 + assert len(keys_by_model_label["testapp.DifferentPage"]) == 3 + + @pytest.mark.django_db + def test_get_models_by_key(self): + example_pages = ExamplePageFactory.create_batch(3) + different_pages = DifferentPageFactory.create_batch(3) + documents = DocumentFactory.create_batch( + 6, + object_keys=factory.Iterator( + [[f"testapp.ExamplePage:{page.pk}"] for page in example_pages] + + [[f"testapp.DifferentPage:{page.pk}"] for page in different_pages] + ), + ) + keys_by_model_label = ModelFromDocumentOperator._get_keys_by_model_label( + documents + ) + models_by_key = ModelFromDocumentOperator._get_models_by_key( + keys_by_model_label + ) + assert len(models_by_key) == 6 + assert all( + isinstance(model, (ExamplePage, DifferentPage)) + for model in models_by_key.values() + ) + + @pytest.mark.django_db + def test_from_document_returns_model_object(self): + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + document = DocumentFactory.create( + object_keys=[f"testapp.ExamplePage:{instance.pk}"], + ) + operator = ModelFromDocumentOperator() + recovered_instance = operator.from_document(document) + assert isinstance(recovered_instance, ExamplePage) + assert recovered_instance.pk == instance.pk + + @pytest.mark.django_db + def test_bulk_from_documents_returns_model_objects(self): + instances = ExamplePageFactory.create_batch(3) + documents = DocumentFactory.create_batch( + 3, + object_keys=factory.Iterator( + [[f"testapp.ExamplePage:{page.pk}"] for page in instances] + ), + ) + operator = ModelFromDocumentOperator() + recovered_instances = list(operator.bulk_from_documents(documents)) + assert len(recovered_instances) == 3 + assert all( + isinstance(instance, ExamplePage) for instance in recovered_instances + ) + assert all( + instance.pk in [page.pk for page in instances] + for instance in recovered_instances + ) + + @pytest.mark.django_db + def test_bulk_from_documents_returns_model_objects_in_order(self): + instances = ExamplePageFactory.create_batch(3) + documents = DocumentFactory.create_batch( + 3, + object_keys=factory.Iterator( 
+ [[f"testapp.ExamplePage:{page.pk}"] for page in instances] + ), + ) + operator = ModelFromDocumentOperator() + recovered_instances = list(operator.bulk_from_documents(documents)) + assert recovered_instances == instances + + @pytest.mark.django_db + def test_bulk_from_documents_returns_model_objects_for_multiple_models(self): + example_pages = ExamplePageFactory.create_batch(3) + different_pages = DifferentPageFactory.create_batch(3) + documents = DocumentFactory.create_batch( + 6, + object_keys=factory.Iterator( + [[f"testapp.ExamplePage:{page.pk}"] for page in example_pages] + + [[f"testapp.DifferentPage:{page.pk}"] for page in different_pages] + ), + ) + operator = ModelFromDocumentOperator() + recovered_instances = list(operator.bulk_from_documents(documents)) + assert len(recovered_instances) == 6 + assert all( + isinstance(instance, (ExamplePage, DifferentPage)) + for instance in recovered_instances + ) + assert all( + instance.pk in [page.pk for page in example_pages + different_pages] + for instance in recovered_instances + ) + + @pytest.mark.django_db + def test_bulk_from_documents_returns_deduplicated_model_objects(self): + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + documents = DocumentFactory.create_batch( + 3, + object_keys=[f"testapp.ExamplePage:{instance.pk}"], + ) + operator = ModelFromDocumentOperator() + recovered_instances = list(operator.bulk_from_documents(documents)) + assert len(recovered_instances) == 1 + assert recovered_instances[0].pk == instance.pk + + +class TestToDocument: + def test_existing_documents_match(self): + text_contents = ["This is a test", "Another test", "More testing content"] + documents = [ + DocumentFactory.build(content=content) for content in text_contents + ] + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + assert operator._existing_documents_match(documents, text_contents) + + @pytest.mark.django_db + def test_keys_for_instance(self): + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + keys = operator._keys_for_instance(instance) + assert len(keys) == 2 + assert keys[0] == f"testapp.ExamplePage:{instance.pk}" + assert keys[1] == f"wagtailcore.Page:{instance.pk}" + + @pytest.mark.django_db + def test_generate_documents_returns_documents(self): + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + documents = list( + operator.to_documents( + instance, embedding_backend=get_embedding_backend("default") + ) + ) + assert len(documents) == 1 + assert documents[0].content == f"{instance.title}\n{instance.body}" + + @pytest.mark.django_db + def test_bulk_generate_documents_returns_documents(self): + instances = ExamplePageFactory.create_batch(3) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + documents = list( + operator.bulk_to_documents( + instances, embedding_backend=get_embedding_backend("default") + ) + ) + assert len(documents) == 3 + assert all( + document.content == f"{instance.title}\n{instance.body}" + for document, instance in zip(documents, instances, strict=False) + ) + + @pytest.mark.django_db + def test_bulk_generate_documents_returns_documents_for_multiple_models(self): + example_pages = ExamplePageFactory.create_batch(3) + different_pages = 
DifferentPageFactory.create_batch(3) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + documents = list( + operator.bulk_to_documents( + example_pages + different_pages, + embedding_backend=get_embedding_backend("default"), + ) + ) + assert len(documents) == 6 + assert all( + document.content == f"{instance.title}\n{instance.body}" + for document, instance in zip( + documents, example_pages + different_pages, strict=False + ) + ) + + @pytest.mark.django_db + def test_bulk_to_documents_batches_objects(self, mocker): + instances = ExamplePageFactory.create_batch(10) + operator = ModelToDocumentOperator(EmbeddableFieldsObjectChunkerOperator) + bulk_generate_mock = mocker.patch.object(operator, "bulk_generate_documents") + list( + operator.bulk_to_documents( + instances, + embedding_backend=get_embedding_backend("default"), + batch_size=2, + ) + ) + assert bulk_generate_mock.call_count == 5 + + +class TestConverter: + @pytest.mark.django_db + def test_returns_original_object(self, patch_embedding_fields): + with patch_embedding_fields(ExamplePage, [EmbeddingField("body")]): + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + converter = EmbeddableFieldsDocumentConverter() + document = next( + converter.to_documents( + instance, embedding_backend=get_embedding_backend("default") + ) + ) + recovered_instance = converter.from_document(document) + assert isinstance(recovered_instance, ExamplePage) + assert recovered_instance.pk == instance.pk + + +@pytest.mark.django_db +def test_convert_single_document_to_object(): + converter = EmbeddableFieldsDocumentConverter() + instance = ExamplePageFactory.create( + title="Important Title", body=fake.text(max_nb_chars=200) + ) + documents = list( + converter.to_documents( + instance, embedding_backend=get_embedding_backend("default") + ) + ) + recovered_instance = converter.from_document(documents[0]) + assert isinstance(recovered_instance, ExamplePage) + assert recovered_instance.pk == instance.pk + + +@pytest.mark.django_db +def test_convert_multiple_documents_to_objects(): + converter = EmbeddableFieldsDocumentConverter() + example_objects = ExampleModelFactory.create_batch(5) + example_pages = ExamplePageFactory.create_batch(5) + different_pages = DifferentPageFactory.create_batch(5) + all_objects = list(example_objects + example_pages + different_pages) + documents = list( + converter.bulk_to_documents( + all_objects, embedding_backend=get_embedding_backend("default") + ) + ) + recovered_objects = list(converter.bulk_from_documents(documents)) + assert recovered_objects == all_objects diff --git a/tests/test_index.py b/tests/test_index.py index be18501..a9cc296 100644 --- a/tests/test_index.py +++ b/tests/test_index.py @@ -8,8 +8,9 @@ from wagtail_vector_index.storage import ( registry, ) -from wagtail_vector_index.storage.base import Document, VectorIndex -from wagtail_vector_index.storage.models import EmbeddingField +from wagtail_vector_index.storage.base import VectorIndex +from wagtail_vector_index.storage.django import EmbeddingField, ModelKey +from wagtail_vector_index.storage.models import Document fake = Faker() @@ -53,7 +54,7 @@ def gen_documents(cls, *args, **kwargs): for page in test_pages: vector = get_vector_for_text(page.title) yield Document( - embedding_pk=page.pk, + object_keys=[ModelKey.from_instance(page)], metadata={ "title": page.title, "object_id": str(page.pk), @@ -75,7 +76,7 @@ def mock_vector_index(mocker, mock_embedding_backend, 
document_generator): ) mocker.patch( - "wagtail_vector_index.storage.models.EmbeddableFieldsDocumentConverter.bulk_to_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_to_documents", side_effect=document_generator, ) @@ -132,7 +133,9 @@ def test_index_get_documents_returns_at_least_one_document_per_page(): index = registry["ExamplePageIndex"] index.rebuild_index() documents = index.get_documents() - found_pages = {document.metadata.get("object_id") for document in documents} + found_pages = { + ModelKey(document.object_keys[0]).object_id for document in documents + } assert found_pages == {str(page.pk) for page in pages} @@ -147,7 +150,8 @@ def test_index_with_multiple_models(): example_pages_ids = {str(page.pk) for page in example_pages} different_page_ids = {str(page.pk) for page in different_pages} found_page_ids = { - document.metadata["object_id"] for document in index.get_documents() + ModelKey(document.object_keys[0]).object_id + for document in index.get_documents() } assert found_page_ids == example_pages_ids.union(different_page_ids) @@ -172,7 +176,7 @@ def gen_pages(cls, *args, **kwargs): yield from pages mocker.patch( - "wagtail_vector_index.storage.models.EmbeddableFieldsDocumentConverter.bulk_from_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", side_effect=gen_pages, ) @@ -197,7 +201,7 @@ def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.0): yield from documents query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.metadata["content"] for doc in documents]) + expected_content = "\n".join([doc.content for doc in documents]) similar_documents_mock = mocker.patch.object(index, "get_similar_documents") similar_documents_mock.side_effect = get_similar_documents index.query("") @@ -215,7 +219,7 @@ def get_similar_documents(query_embedding, limit=0, similarity_threshold=0.5): yield from documents query_mock = mocker.patch("conftest.ChatMockBackend.chat") - expected_content = "\n".join([doc.metadata["content"] for doc in documents]) + expected_content = "\n".join([doc.content for doc in documents]) similar_documents_mock = mocker.patch.object(index, "get_similar_documents") similar_documents_mock.side_effect = get_similar_documents index.query("", similarity_threshold=0.5) @@ -232,7 +236,7 @@ def gen_pages(cls, *args, **kwargs): yield from pages mocker.patch( - "wagtail_vector_index.storage.models.EmbeddableFieldsDocumentConverter.bulk_from_documents", + "wagtail_vector_index.storage.django.EmbeddableFieldsDocumentConverter.bulk_from_documents", side_effect=gen_pages, ) diff --git a/tests/test_model_converter.py b/tests/test_model_converter.py deleted file mode 100644 index b677965..0000000 --- a/tests/test_model_converter.py +++ /dev/null @@ -1,53 +0,0 @@ -import pytest -from factories import DifferentPageFactory, ExampleModelFactory, ExamplePageFactory -from faker import Faker -from testapp.models import ExamplePage -from wagtail_vector_index.ai import get_embedding_backend -from wagtail_vector_index.storage.models import ( - EmbeddableFieldsDocumentConverter, - EmbeddingField, -) - -fake = Faker() - - -@pytest.mark.django_db -def test_get_split_content_adds_important_field_to_each_split(patch_embedding_fields): - with patch_embedding_fields( - ExamplePage, [EmbeddingField("title", important=True), EmbeddingField("body")] - ): - body = fake.text(max_nb_chars=200) - instance = ExamplePageFactory.create(title="Important 
Title", body=body) - converter = EmbeddableFieldsDocumentConverter() - splits = converter._get_split_content(instance, chunk_size=100) - assert all(split.startswith(instance.title) for split in splits) - - -@pytest.mark.django_db -def test_convert_single_document_to_object(): - converter = EmbeddableFieldsDocumentConverter() - instance = ExamplePageFactory.create( - title="Important Title", body=fake.text(max_nb_chars=200) - ) - documents = list( - converter.to_documents( - instance, embedding_backend=get_embedding_backend("default") - ) - ) - recovered_instance = converter.from_document(documents[0]) - assert isinstance(recovered_instance, ExamplePage) - assert recovered_instance.pk == instance.pk - - -@pytest.mark.django_db -def test_convert_multiple_documents_to_objects(): - converter = EmbeddableFieldsDocumentConverter() - example_objects = ExampleModelFactory.create_batch(5) - example_pages = ExamplePageFactory.create_batch(5) - different_pages = DifferentPageFactory.create_batch(5) - all_objects = list(example_objects + example_pages + different_pages) - documents = converter.bulk_to_documents( - all_objects, embedding_backend=get_embedding_backend("default") - ) - recovered_objects = list(converter.bulk_from_documents(documents)) - assert recovered_objects == all_objects diff --git a/tests/test_model_index.py b/tests/test_model_index.py index ebcfc94..1755211 100644 --- a/tests/test_model_index.py +++ b/tests/test_model_index.py @@ -3,7 +3,7 @@ from faker import Faker from testapp.models import ExamplePage from wagtail_vector_index.storage import registry -from wagtail_vector_index.storage.models import Embedding +from wagtail_vector_index.storage.models import Document fake = Faker() @@ -46,4 +46,4 @@ def test_rebuilding_model_index_creates_embeddings(): ExamplePageFactory.create_batch(10) index = ExamplePage.vector_index index.rebuild_index() - assert Embedding.objects.count() == 10 + assert Document.objects.count() == 10 diff --git a/tests/test_storage/test_models.py b/tests/test_storage/test_models.py index d41abcb..b26163b 100644 --- a/tests/test_storage/test_models.py +++ b/tests/test_storage/test_models.py @@ -1,4 +1,4 @@ -from wagtail_vector_index.storage.models import ( +from wagtail_vector_index.storage.django import ( DefaultStorageVectorIndex, build_vector_index_base_for_storage_provider, ) @@ -19,7 +19,7 @@ def test_build_vector_index_base_for_default_storage_provider(settings): def test_build_vector_index_base_for_alias_storage_provider(settings): - from wagtail_vector_index.storage.pgvector import PgvectorIndexMixin + from wagtail_vector_index.storage.pgvector.provider import PgvectorIndexMixin settings.WAGTAIL_VECTOR_INDEX_STORAGE_PROVIDERS = { "default": { diff --git a/tests/testapp/models.py b/tests/testapp/models.py index e92dda1..9cba8f6 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -2,7 +2,7 @@ from wagtail.admin.panels import FieldPanel from wagtail.fields import RichTextField from wagtail.models import Page -from wagtail_vector_index.storage.models import ( +from wagtail_vector_index.storage.django import ( DefaultStorageVectorIndex, EmbeddingField, PageEmbeddableFieldsVectorIndexMixin, diff --git a/tests/testapp/settings.py b/tests/testapp/settings.py index 59b65e9..7ec17d7 100644 --- a/tests/testapp/settings.py +++ b/tests/testapp/settings.py @@ -252,17 +252,17 @@ "loggers": { "wagtail_vector_index": { "handlers": ["console"], - "level": "DEBUG", + "level": "INFO", "propagate": False, }, "llm": { "handlers": ["console"], - "level": 
"DEBUG", + "level": "INFO", "propagate": False, }, "testapp": { "handlers": ["console"], - "level": "DEBUG", + "level": "INFO", "propagate": False, }, },