Skip to content

Commit

Permalink
[BUG] 1965 Split up embedding functions - Redux (chroma-core#2395)
Browse files Browse the repository at this point in the history
## Description of changes
The original attempt to split up the embedding functions failed because
of Python 3.9 and 3.10 incompatibilities with `issubclass`.

Original PR here: chroma-core#2034

Failing tests here:
https://github.com/chroma-core/chroma/actions/runs/9605053108/job/26491923410

The fix is changing `issubclass` to `isinstance`, which has the same
functionality.

## Test plan

Along with CI, tested locally with Python 3.9 and 3.10 and confirmed
passing.

## Documentation Changes
N/A
  • Loading branch information
atroyn authored Jun 21, 2024
1 parent 7da06e7 commit dd56ded
Show file tree
Hide file tree
Showing 19 changed files with 1,240 additions and 1,035 deletions.
13 changes: 8 additions & 5 deletions chromadb/api/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast
from numpy.typing import NDArray
import numpy as np
from typing_extensions import Literal, TypedDict, Protocol
from typing_extensions import Literal, TypedDict, Protocol, runtime_checkable
import chromadb.errors as errors
from chromadb.types import (
Metadata,
Expand Down Expand Up @@ -56,7 +56,7 @@ def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs:


def maybe_cast_one_to_many_embedding(
target: Union[OneOrMany[Embedding], OneOrMany[np.ndarray]]
target: Union[OneOrMany[Embedding], OneOrMany[np.ndarray]] # type: ignore[type-arg]
) -> Embeddings:
if isinstance(target, List):
# One Embedding
Expand Down Expand Up @@ -101,7 +101,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents:


# Images
ImageDType = Union[np.uint, np.int_, np.float_]
ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined]
Image = NDArray[ImageDType]
Images = List[Image]

Expand Down Expand Up @@ -184,6 +184,7 @@ class IndexMetadata(TypedDict):
time_created: float


@runtime_checkable
class EmbeddingFunction(Protocol[D]):
def __call__(self, input: D) -> Embeddings:
...
Expand All @@ -199,8 +200,10 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings:

setattr(cls, "__call__", __call__)

def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings:
return retry(**retry_kwargs)(self.__call__)(input)
def embed_with_retries(
self, input: D, **retry_kwargs: Dict[str, Any]
) -> Embeddings:
return cast(Embeddings, retry(**retry_kwargs)(self.__call__)(input))


def validate_embedding_function(
Expand Down
5 changes: 4 additions & 1 deletion chromadb/test/ef/test_default_ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@
import pytest
from hypothesis import given, settings

from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import (
ONNXMiniLM_L6_V2,
_verify_sha256,
)


def unique_by(x: Hashable) -> Hashable:
Expand Down
53 changes: 53 additions & 0 deletions chromadb/test/ef/test_ef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from chromadb.utils import embedding_functions
from chromadb.api.types import EmbeddingFunction


def test_get_builtins_holds() -> None:
    """
    Pin the exact set of names returned by `embedding_functions.get_builtins()`.

    This guards the ef migration against silently dropping or renaming a
    builtin. It is intended to be temporary: the pinned set goes stale as
    soon as a new embedding function is added.
    REMOVE ME ON THE NEXT EF ADDITION
    """
    # Names are kept in alphabetical order to make additions easy to review.
    pinned_names = {
        "AmazonBedrockEmbeddingFunction",
        "ChromaLangchainEmbeddingFunction",
        "CohereEmbeddingFunction",
        "GoogleGenerativeAiEmbeddingFunction",
        "GooglePalmEmbeddingFunction",
        "GoogleVertexEmbeddingFunction",
        "HuggingFaceEmbeddingFunction",
        "HuggingFaceEmbeddingServer",
        "InstructorEmbeddingFunction",
        "JinaEmbeddingFunction",
        "ONNXMiniLM_L6_V2",
        "OllamaEmbeddingFunction",
        "OpenAIEmbeddingFunction",
        "OpenCLIPEmbeddingFunction",
        "RoboflowEmbeddingFunction",
        "SentenceTransformerEmbeddingFunction",
        "Text2VecEmbeddingFunction",
    }

    reported_names = embedding_functions.get_builtins()
    assert pinned_names == reported_names


def test_default_ef_exists() -> None:
    """
    Check that `DefaultEmbeddingFunction` is exported by the
    `embedding_functions` package, can be constructed without arguments,
    and is recognised as an `EmbeddingFunction` by `isinstance`.
    """
    assert hasattr(embedding_functions, "DefaultEmbeddingFunction")

    instance = embedding_functions.DefaultEmbeddingFunction()
    assert instance is not None
    assert isinstance(instance, EmbeddingFunction)


def test_ef_imports() -> None:
    """
    Check that every name reported by `get_builtins()` is exposed as an
    attribute of the `embedding_functions` package, that the attribute is a
    class, and that the class is a subclass of `EmbeddingFunction`.
    """
    # The Langchain embedding function is a special case and is exempt
    # from these checks.
    exempt_name = "ChromaLangchainEmbeddingFunction"

    for name in embedding_functions.get_builtins():
        if name != exempt_name:
            assert hasattr(embedding_functions, name)
            exported = getattr(embedding_functions, name)
            assert isinstance(exported, type)
            assert issubclass(exported, EmbeddingFunction)
Loading

0 comments on commit dd56ded

Please sign in to comment.