From 213d850ebbda4dc9c6757ed4610e8fd5e6545c51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Mon, 20 May 2024 17:47:51 +0200 Subject: [PATCH 1/4] rename embeddings module --- docs/quickstart/quickstart2_code.py | 2 +- docs/quickstart/quickstart3_code.py | 2 +- src/dbally/{embedding_client => embeddings}/__init__.py | 0 src/dbally/{embedding_client => embeddings}/base.py | 0 src/dbally/{embedding_client => embeddings}/openai.py | 2 +- src/dbally/similarity/chroma_store.py | 2 +- src/dbally/similarity/faiss_store.py | 2 +- tests/integration/test_index_with_chroma.py | 2 +- tests/unit/similarity/test_chroma.py | 2 +- 9 files changed, 7 insertions(+), 7 deletions(-) rename src/dbally/{embedding_client => embeddings}/__init__.py (100%) rename src/dbally/{embedding_client => embeddings}/base.py (100%) rename src/dbally/{embedding_client => embeddings}/openai.py (96%) diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index cd9669d0..9ba8a5ef 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -11,7 +11,7 @@ from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex -from dbally.embedding_client.openai import OpenAiEmbeddingClient +from dbally.embeddings.openai import OpenAiEmbeddingClient from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 3732c9da..cc4e3e74 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -11,7 +11,7 @@ from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex -from dbally.embedding_client.openai import OpenAiEmbeddingClient +from dbally.embeddings.openai import OpenAiEmbeddingClient from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') diff --git a/src/dbally/embedding_client/__init__.py b/src/dbally/embeddings/__init__.py similarity index 100% rename from src/dbally/embedding_client/__init__.py rename to src/dbally/embeddings/__init__.py diff --git a/src/dbally/embedding_client/base.py b/src/dbally/embeddings/base.py similarity index 100% rename from src/dbally/embedding_client/base.py rename to src/dbally/embeddings/base.py diff --git a/src/dbally/embedding_client/openai.py b/src/dbally/embeddings/openai.py similarity index 96% rename from src/dbally/embedding_client/openai.py rename to src/dbally/embeddings/openai.py index 5f601ead..f3f482e5 100644 --- a/src/dbally/embedding_client/openai.py +++ b/src/dbally/embeddings/openai.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient class OpenAiEmbeddingClient(EmbeddingClient): diff --git a/src/dbally/similarity/chroma_store.py b/src/dbally/similarity/chroma_store.py index 53ef657a..95dd88c9 100644 --- a/src/dbally/similarity/chroma_store.py +++ b/src/dbally/similarity/chroma_store.py @@ -3,7 +3,7 @@ import chromadb -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity.store import SimilarityStore diff --git a/src/dbally/similarity/faiss_store.py b/src/dbally/similarity/faiss_store.py index 7f43e34b..46f2c725 100644 --- a/src/dbally/similarity/faiss_store.py +++ b/src/dbally/similarity/faiss_store.py @@ -4,7 +4,7 @@ import faiss import numpy as np -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity.store import SimilarityStore diff --git a/tests/integration/test_index_with_chroma.py b/tests/integration/test_index_with_chroma.py index b28ec1db..a2698018 100644 --- a/tests/integration/test_index_with_chroma.py +++ b/tests/integration/test_index_with_chroma.py @@ -2,7 +2,7 @@ import pytest from chromadb import Documents, EmbeddingFunction, Embeddings -from dbally.embedding_client.base import EmbeddingClient +from dbally.embeddings.base import EmbeddingClient from dbally.similarity import ChromadbStore from dbally.similarity.fetcher import SimilarityFetcher from dbally.similarity.index import SimilarityIndex diff --git a/tests/unit/similarity/test_chroma.py b/tests/unit/similarity/test_chroma.py index f54301ed..4cad4497 100644 --- a/tests/unit/similarity/test_chroma.py +++ b/tests/unit/similarity/test_chroma.py @@ -4,7 +4,7 @@ import chromadb import pytest -from dbally.embedding_client import EmbeddingClient +from dbally.embeddings import EmbeddingClient from dbally.similarity import ChromadbStore DEFAULT_METADATA = {"hnsw:space": "l2"} From e916c1deae15d50d6748f3c8ca1fb36602514839 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 21 May 2024 10:36:42 +0200 Subject: [PATCH 2/4] add litellm embeddings --- docs/how-to/use_chromadb_store.md | 6 +- docs/how-to/use_custom_similarity_fetcher.md | 4 +- docs/quickstart/quickstart2.md | 6 +- docs/quickstart/quickstart2_code.py | 4 +- docs/quickstart/quickstart3_code.py | 4 +- src/dbally/__init__.py | 23 ++++++ src/dbally/embeddings/__init__.py | 4 +- src/dbally/embeddings/_exceptions.py | 39 +++++++++ src/dbally/embeddings/base.py | 13 +-- src/dbally/embeddings/litellm.py | 85 ++++++++++++++++++++ src/dbally/embeddings/openai.py | 52 ------------ src/dbally/llms/clients/litellm.py | 25 ++++-- src/dbally/llms/litellm.py | 25 ++++-- 13 files changed, 205 insertions(+), 85 deletions(-) create mode 100644 src/dbally/embeddings/_exceptions.py create mode 100644 src/dbally/embeddings/litellm.py delete mode 100644 src/dbally/embeddings/openai.py diff --git a/docs/how-to/use_chromadb_store.md b/docs/how-to/use_chromadb_store.md index b715d8e8..4184ee33 100644 --- a/docs/how-to/use_chromadb_store.md +++ b/docs/how-to/use_chromadb_store.md @@ -33,12 +33,12 @@ or [set up Chromadb in the client/server mode](https://docs.trychroma.com/usage- chroma_client = chromadb.HttpClient(host='localhost', port=8000) ``` -Next, you can either use [one of dbally embedding clients][dbally.embedding_client.EmbeddingClient], such as [OpenAiEmbeddingClient][dbally.embedding_client.OpenAiEmbeddingClient] +Next, you can either use [one of dbally embedding clients][dbally.embeddings.EmbeddingClient], such as [LiteLLMEmbeddingClient][dbally.embeddings.LiteLLMEmbeddingClient] ```python -from dbally.embedding_client import OpenAiEmbeddingClient +from dbally.embeddings.litellm import LiteLLMEmbeddingClient -embedding_client=OpenAiEmbeddingClient( +embedding_client=LiteLLMEmbeddingClient( api_key="your-api-key", ) diff --git a/docs/how-to/use_custom_similarity_fetcher.md b/docs/how-to/use_custom_similarity_fetcher.md index 06a01197..8bdc77e7 100644 --- a/docs/how-to/use_custom_similarity_fetcher.md +++ b/docs/how-to/use_custom_similarity_fetcher.md @@ -43,13 +43,13 @@ breeds_similarity = SimilarityIndex( store=FaissStore( index_dir="./similarity_indexes", index_name="breeds_similarity", - embedding_client=OpenAiEmbeddingClient( + embedding_client=LiteLLMEmbeddingClient( api_key="your-api-key", ) ) ``` -In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `OpenAiEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md). +In this example, we used the FaissStore, which utilizes the `faiss` library for rapid similarity search. We also employed the `LiteLLMEmbeddingClient` to get the semantic embeddings for the dog breeds. Depending on your needs, you can use a different built-in store or create [a custom one](../how-to/use_custom_similarity_store.md). ## Using the Similarity Index diff --git a/docs/quickstart/quickstart2.md b/docs/quickstart/quickstart2.md index 8f85f093..81165890 100644 --- a/docs/quickstart/quickstart2.md +++ b/docs/quickstart/quickstart2.md @@ -60,18 +60,18 @@ Next, let's define a store that will store the country names and can be used to ```python from dbally.similarity import FaissStore -from dbally.embedding_client.openai import OpenAiEmbeddingClient +from dbally.embeddings.litellm import LiteLLMEmbeddingClient country_store = FaissStore( index_dir="./similarity_indexes", index_name="country_similarity", - embedding_client=OpenAiEmbeddingClient( + embedding_client=LiteLLMEmbeddingClient( api_key="your-api-key", ) ) ``` -In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `OpenAiEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key. +In this example, we used the `FaissStore` store, which employs the `faiss` library for fast similarity search. We also used the `LiteLLMEmbeddingClient` to get the semantic embeddings for the country names. Replace `your-api-key` with your OpenAI API key. Finally, let's define the similarity index: diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 9ba8a5ef..4b012339 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -11,7 +11,7 @@ from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex -from dbally.embeddings.openai import OpenAiEmbeddingClient +from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') @@ -30,7 +30,7 @@ store=FaissStore( index_dir="./similarity_indexes", index_name="country_similarity", - embedding_client=OpenAiEmbeddingClient( + embedding_client=LiteLLMEmbeddingClient( api_key=os.environ["OPENAI_API_KEY"], ) ), diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index cc4e3e74..d03c2c9f 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -11,7 +11,7 @@ from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex -from dbally.embeddings.openai import OpenAiEmbeddingClient +from dbally.embeddings.litellm import LiteLLMEmbeddingClient from dbally.llms.litellm import LiteLLM engine = create_engine('sqlite:///candidates.db') @@ -31,7 +31,7 @@ store=FaissStore( index_dir="./similarity_indexes", index_name="country_similarity", - embedding_client=OpenAiEmbeddingClient( + embedding_client=LiteLLMEmbeddingClient( api_key=os.environ["OPENAI_API_KEY"], ) ), diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index a947eb97..258c2f79 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -12,6 +12,12 @@ from ._main import create_collection from ._types import NOT_GIVEN, NotGiven from .collection import Collection +from .embeddings._exceptions import ( + EmbeddingConnectionError, + EmbeddingError, + EmbeddingResponseError, + EmbeddingStatusError, +) from .llms.clients._exceptions import LLMConnectionError, LLMError, LLMResponseError, LLMStatusError __all__ = [ @@ -25,6 +31,10 @@ "DataFrameBaseView", "ExecutionResult", "DbAllyError", + "EmbeddingError", + "EmbeddingConnectionError", + "EmbeddingResponseError", + "EmbeddingStatusError", "LLMError", "LLMConnectionError", "LLMResponseError", @@ -32,3 +42,16 @@ "NotGiven", "NOT_GIVEN", ] + +# Update the __module__ attribute for exported symbols so that +# error messages point to this module instead of the module +# it was originally defined in, e.g. +# dbally._exceptions.LLMError -> dbally.LLMError +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + try: + __locals[__name].__module__ = "dbally" + except (TypeError, AttributeError): + # Some of our exported symbols are builtins which we can't set attributes for. + pass diff --git a/src/dbally/embeddings/__init__.py b/src/dbally/embeddings/__init__.py index e2cfb278..67fe9b7f 100644 --- a/src/dbally/embeddings/__init__.py +++ b/src/dbally/embeddings/__init__.py @@ -1,4 +1,4 @@ from .base import EmbeddingClient -from .openai import OpenAiEmbeddingClient +from .litellm import LiteLLMEmbeddingClient -__all__ = ["EmbeddingClient", "OpenAiEmbeddingClient"] +__all__ = ["EmbeddingClient", "LiteLLMEmbeddingClient"] diff --git a/src/dbally/embeddings/_exceptions.py b/src/dbally/embeddings/_exceptions.py new file mode 100644 index 00000000..37c24f3c --- /dev/null +++ b/src/dbally/embeddings/_exceptions.py @@ -0,0 +1,39 @@ +from .._exceptions import DbAllyError + + +class EmbeddingError(DbAllyError): + """ + Base class for all exceptions raised by the EmbeddingClient. + """ + + def __init__(self, message: str) -> None: + super().__init__(message) + self.message = message + + +class EmbeddingConnectionError(EmbeddingError): + """ + Raised when there is an error connecting to the embedding API. + """ + + def __init__(self, message: str = "Connection error.") -> None: + super().__init__(message) + + +class EmbeddingStatusError(EmbeddingError): + """ + Raised when an API response has a status code of 4xx or 5xx. + """ + + def __init__(self, message: str, status_code: int) -> None: + super().__init__(message) + self.status_code = status_code + + +class EmbeddingResponseError(EmbeddingError): + """ + Raised when an API response has an invalid schema. + """ + + def __init__(self, message: str = "Data returned by API invalid for expected schema.") -> None: + super().__init__(message) diff --git a/src/dbally/embeddings/base.py b/src/dbally/embeddings/base.py index a582aa9e..4c757251 100644 --- a/src/dbally/embeddings/base.py +++ b/src/dbally/embeddings/base.py @@ -1,15 +1,16 @@ -# disable args docstring check as args are documented in OpenAI API docs -import abc +from abc import ABC, abstractmethod from typing import List -class EmbeddingClient(metaclass=abc.ABCMeta): - """Abstract client for creating text embeddings.""" +class EmbeddingClient(ABC): + """ + Abstract client for creating text embeddings. + """ - @abc.abstractmethod + @abstractmethod async def get_embeddings(self, data: List[str]) -> List[List[float]]: """ - For a given list of strings returns a list of embeddings. + Creates embeddings for the given strings. Args: data: List of strings to get embeddings for. diff --git a/src/dbally/embeddings/litellm.py b/src/dbally/embeddings/litellm.py new file mode 100644 index 00000000..6e312c56 --- /dev/null +++ b/src/dbally/embeddings/litellm.py @@ -0,0 +1,85 @@ +from typing import Dict, List, Optional + +try: + import litellm + + HAVE_LITELLM = True +except ImportError: + HAVE_LITELLM = False + +from dbally.embeddings.base import EmbeddingClient + +from ._exceptions import EmbeddingConnectionError, EmbeddingResponseError, EmbeddingStatusError + + +class LiteLLMEmbeddingClient(EmbeddingClient): + """ + Client for creating text embeddings using LiteLLM API. + """ + + def __init__( + self, + model: str = "text-embedding-3-small", + options: Optional[Dict] = None, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + api_version: Optional[str] = None, + ) -> None: + """ + Constructs the LiteLLMEmbeddingClient. + + Args: + model: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/embedding/supported_embedding)\ + to be used. Default is "text-embedding-3-small". + options: Additional options to pass to the LiteLLM API. + api_base: The API endpoint you want to call the model with. + api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, + for more information, follow the instructions for your specific vendor in the\ + [LiteLLM documentation](https://docs.litellm.ai/docs/embedding/supported_embedding). + api_version: The API version for the call. + + Raises: + ImportError: If the litellm package is not installed. + """ + if not HAVE_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + + super().__init__() + self.model = model + self.options = options or {} + self.api_base = api_base + self.api_key = api_key + self.api_version = api_version + + async def get_embeddings(self, data: List[str]) -> List[List[float]]: + """ + Creates embeddings for the given strings. + + Args: + data: List of strings to get embeddings for. + + Returns: + List of embeddings for the given strings. + + Raises: + EmbeddingConnectionError: If there is a connection error with the embedding API. + EmbeddingStatusError: If the embedding API returns an error status code. + EmbeddingResponseError: If the embedding API response is invalid. + """ + try: + response = await litellm.aembedding( + input=data, + model=self.model, + api_base=self.api_base, + api_key=self.api_key, + api_version=self.api_version, + **self.options, + ) + except litellm.openai.APIConnectionError as exc: + raise EmbeddingConnectionError() from exc + except litellm.openai.APIStatusError as exc: + raise EmbeddingStatusError(exc.message, exc.status_code) from exc + except litellm.openai.APIResponseValidationError as exc: + raise EmbeddingResponseError() from exc + + return [embedding["embedding"] for embedding in response.data] diff --git a/src/dbally/embeddings/openai.py b/src/dbally/embeddings/openai.py deleted file mode 100644 index f3f482e5..00000000 --- a/src/dbally/embeddings/openai.py +++ /dev/null @@ -1,52 +0,0 @@ -from typing import Any, Dict, List, Optional - -from dbally.embeddings.base import EmbeddingClient - - -class OpenAiEmbeddingClient(EmbeddingClient): - """ - Client for creating text embeddings using OpenAI API. - """ - - def __init__(self, api_key: str, model: str = "text-embedding-3-small", openai_options: Optional[Dict] = None): - """ - Initializes the OpenAiEmbeddingClient. - - Args: - api_key: The OpenAI API key. - model: The model to use for embeddings. - openai_options: Additional options to pass to the OpenAI API. - """ - super().__init__() - self.api_key = api_key - self.model = model - self.openai_options = openai_options - - try: - from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel - except ImportError as exc: - raise ImportError("You need to install openai package to use GPT models") from exc - - self._openai = AsyncOpenAI(api_key=self.api_key) - - async def get_embeddings(self, data: List[str]) -> List[List[float]]: - """ - For a given list of strings returns a list of embeddings. - - Args: - data: List of strings to get embeddings for. - - Returns: - List of embeddings for the given strings. - """ - kwargs: Dict[str, Any] = { - "model": self.model, - } - if self.openai_options: - kwargs.update(self.openai_options) - - response = await self._openai.embeddings.create( - input=data, - **kwargs, - ) - return [embedding.embedding for embedding in response.data] diff --git a/src/dbally/llms/clients/litellm.py b/src/dbally/llms/clients/litellm.py index 82752b1d..d24a58c2 100644 --- a/src/dbally/llms/clients/litellm.py +++ b/src/dbally/llms/clients/litellm.py @@ -1,8 +1,13 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Union -import litellm -from openai import APIConnectionError, APIResponseValidationError, APIStatusError +try: + import litellm + + HAVE_LITELLM = True +except ImportError: + HAVE_LITELLM = False + from dbally.data_models.audit import LLMEvent from dbally.llms.clients.base import LLMClient, LLMOptions @@ -53,7 +58,13 @@ def __init__( base_url: Base URL of the LLM API. api_key: API key used to authenticate with the LLM API. api_version: API version of the LLM API. + + Raises: + ImportError: If the litellm package is not installed. """ + if not HAVE_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + super().__init__(model_name) self.base_url = base_url self.api_key = api_key @@ -91,17 +102,17 @@ async def call( api_key=self.api_key, api_version=self.api_version, response_format=response_format, - **options.dict(), # type: ignore + **options.dict(), ) - except APIConnectionError as exc: + except litellm.openai.APIConnectionError as exc: raise LLMConnectionError() from exc - except APIStatusError as exc: + except litellm.openai.APIStatusError as exc: raise LLMStatusError(exc.message, exc.status_code) from exc - except APIResponseValidationError as exc: + except litellm.openai.APIResponseValidationError as exc: raise LLMResponseError() from exc event.completion_tokens = response.usage.completion_tokens event.prompt_tokens = response.usage.prompt_tokens event.total_tokens = response.usage.total_tokens - return response.choices[0].message.content # type: ignore + return response.choices[0].message.content diff --git a/src/dbally/llms/litellm.py b/src/dbally/llms/litellm.py index 9b295214..f6b6727d 100644 --- a/src/dbally/llms/litellm.py +++ b/src/dbally/llms/litellm.py @@ -1,7 +1,12 @@ from functools import cached_property from typing import Dict, Optional -from litellm import token_counter +try: + import litellm + + HAVE_LITELLM = True +except ImportError: + HAVE_LITELLM = False from dbally.llms.base import LLM from dbally.llms.clients.litellm import LiteLLMClient, LiteLLMOptions @@ -25,18 +30,24 @@ def __init__( api_version: Optional[str] = None, ) -> None: """ - Construct a new LiteLLM instance. + Constructs a new LiteLLM instance. Args: - model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used, - default is "gpt-3.5-turbo". + model_name: Name of the [LiteLLM supported model](https://docs.litellm.ai/docs/providers) to be used.\ + Default is "gpt-3.5-turbo". default_options: Default options to be used. base_url: Base URL of the LLM API. api_key: API key to be used. API key to be used. If not specified, an environment variable will be used, for more information, follow the instructions for your specific vendor in the\ [LiteLLM documentation](https://docs.litellm.ai/docs/providers). api_version: API version to be used. If not specified, the default version will be used. + + Raises: + ImportError: If the litellm package is not installed. """ + if not HAVE_LITELLM: + raise ImportError("You need to install litellm package to use LiteLLM models") + super().__init__(model_name, default_options) self.base_url = base_url self.api_key = api_key @@ -56,7 +67,7 @@ def _client(self) -> LiteLLMClient: def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: """ - Count tokens in the messages using a specified model. + Counts tokens in the messages using a specified model. Args: messages: Messages to count tokens for. @@ -65,4 +76,6 @@ def count_tokens(self, messages: ChatFormat, fmt: Dict[str, str]) -> int: Returns: Number of tokens in the messages. """ - return sum(token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages) + return sum( + litellm.token_counter(model=self.model_name, text=message["content"].format(**fmt)) for message in messages + ) From 0105ce4273a1ea09f74947d9fd9f18f6b49c7590 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 21 May 2024 12:28:56 +0200 Subject: [PATCH 3/4] add embedding model name in docs --- docs/how-to/use_chromadb_store.md | 19 +++++++++++-------- docs/how-to/use_custom_similarity_fetcher.md | 10 ++++++---- docs/quickstart/quickstart2.md | 5 +++-- docs/quickstart/quickstart2_code.py | 3 ++- docs/quickstart/quickstart3_code.py | 1 + 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/docs/how-to/use_chromadb_store.md b/docs/how-to/use_chromadb_store.md index 4184ee33..6e9d5a58 100644 --- a/docs/how-to/use_chromadb_store.md +++ b/docs/how-to/use_chromadb_store.md @@ -7,7 +7,7 @@ To use Chromadb with db-ally you need to install the chromadb extension -```python +```bash pip install dbally[chromadb] ``` @@ -39,22 +39,27 @@ Next, you can either use [one of dbally embedding clients][dbally.embeddings.Emb from dbally.embeddings.litellm import LiteLLMEmbeddingClient embedding_client=LiteLLMEmbeddingClient( - api_key="your-api-key", - ) - + model="text-embedding-3-small", # to use openai embedding model + api_key="your-api-key", +) ``` or [Chromadb embedding functions](https://docs.trychroma.com/embeddings) -``` +```python from chromadb.utils import embedding_functions + embedding_client = embedding_functions.DefaultEmbeddingFunction() ``` to define your [`ChromadbStore`][dbally.similarity.ChromadbStore]. ```python -store = ChromadbStore(index_name="myChromaIndex", chroma_client=chroma_client, embedding_function=embedding_client) +store = ChromadbStore( + index_name="myChromaIndex", + chroma_client=chroma_client, + embedding_function=embedding_client, +) ``` After this setup, you can initialize the SimilarityIndex @@ -63,8 +68,6 @@ After this setup, you can initialize the SimilarityIndex from typing import Annotated country_similarity = SimilarityIndex(store, DummyCountryFetcher()) - - ``` and [update it and find the closest matches in the same way as in built-in similarity indices](./use_custom_similarity_store.md/#using-the-similarity-index) . diff --git a/docs/how-to/use_custom_similarity_fetcher.md b/docs/how-to/use_custom_similarity_fetcher.md index 8bdc77e7..b79b6554 100644 --- a/docs/how-to/use_custom_similarity_fetcher.md +++ b/docs/how-to/use_custom_similarity_fetcher.md @@ -41,11 +41,13 @@ from dbally.similarity.store import FaissStore breeds_similarity = SimilarityIndex( fetcher=DogBreedsFetcher(), store=FaissStore( - index_dir="./similarity_indexes", - index_name="breeds_similarity", + index_dir="./similarity_indexes", + index_name="breeds_similarity", + ), embedding_client=LiteLLMEmbeddingClient( - api_key="your-api-key", - ) + model="text-embedding-3-small", # to use openai embedding model + api_key=os.environ["OPENAI_API_KEY"], + ), ) ``` diff --git a/docs/quickstart/quickstart2.md b/docs/quickstart/quickstart2.md index 81165890..cb839bff 100644 --- a/docs/quickstart/quickstart2.md +++ b/docs/quickstart/quickstart2.md @@ -66,8 +66,9 @@ country_store = FaissStore( index_dir="./similarity_indexes", index_name="country_similarity", embedding_client=LiteLLMEmbeddingClient( - api_key="your-api-key", - ) + model="text-embedding-3-small", # to use openai embedding model + api_key=os.environ["OPENAI_API_KEY"], + ), ) ``` diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index 4b012339..d330470a 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -31,8 +31,9 @@ index_dir="./similarity_indexes", index_name="country_similarity", embedding_client=LiteLLMEmbeddingClient( + model="text-embedding-3-small", # to use openai embedding model api_key=os.environ["OPENAI_API_KEY"], - ) + ), ), ) diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index d03c2c9f..461cf723 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -32,6 +32,7 @@ index_dir="./similarity_indexes", index_name="country_similarity", embedding_client=LiteLLMEmbeddingClient( + model="text-embedding-3-small", # to use openai embeddings model api_key=os.environ["OPENAI_API_KEY"], ) ), From 6962c3f764d307902d648ebf9b6a72c7c7b348ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Pstr=C4=85g?= Date: Tue, 21 May 2024 12:33:33 +0200 Subject: [PATCH 4/4] fix typos --- docs/quickstart/quickstart3_code.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 461cf723..f0e385b1 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -32,9 +32,9 @@ index_dir="./similarity_indexes", index_name="country_similarity", embedding_client=LiteLLMEmbeddingClient( - model="text-embedding-3-small", # to use openai embeddings model + model="text-embedding-3-small", # to use openai embedding model api_key=os.environ["OPENAI_API_KEY"], - ) + ), ), )