diff --git a/README.md b/README.md index 292cae18..0d1dbea1 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ This is a basic implementation of a db-ally view for an example HR application, ```python from dbally import decorators, SqlAlchemyBaseView, create_collection +from dbally.llm_client.openai_client import OpenAIClient from sqlalchemy import create_engine class CandidateView(SqlAlchemyBaseView): @@ -52,7 +53,8 @@ class CandidateView(SqlAlchemyBaseView): return Candidate.country == country engine = create_engine('sqlite:///candidates.db') -my_collection = create_collection("collection_name") +llm = OpenAIClient(model_name="gpt-3.5-turbo") +my_collection = create_collection("collection_name", llm) my_collection.add(CandidateView, lambda: CandidateView(engine)) my_collection.ask("Find candidates from United States") diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index da20857d..5f939c2e 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -24,6 +24,7 @@ from dbally.collection import Collection from dbally.data_models.prompts.iql_prompt_template import default_iql_template from dbally.data_models.prompts.view_selector_prompt_template import default_view_selector_template +from dbally.llm_client.openai_client import OpenAIClient from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError @@ -81,13 +82,12 @@ async def evaluate(cfg: DictConfig) -> Any: engine = create_engine(benchmark_cfg.pg_connection_string + f"/{cfg.db_name}") - if "gpt" in cfg.model_name: - dbally.use_openai_llm( - model_name="gpt-4", - openai_api_key=benchmark_cfg.openai_api_key, - ) + llm_client = OpenAIClient( + model_name="gpt-4", + api_key=benchmark_cfg.openai_api_key, + ) - db = dbally.create_collection(cfg.db_name) + db = dbally.create_collection(cfg.db_name, llm_client) for view_name in cfg.view_names: view = VIEW_REGISTRY[ViewName(view_name)] diff --git 
a/docs/concepts/collections.md b/docs/concepts/collections.md index 158bbb55..7371899c 100644 --- a/docs/concepts/collections.md +++ b/docs/concepts/collections.md @@ -3,7 +3,7 @@ At its core, a collection groups together multiple [views](views.md). Once you've defined your views, the next step is to register them within a collection. Here's how you might do it: ```python -my_collection = dbally.create_collection("collection_name") +my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) my_collection.add(ExampleView) my_collection.add(RecipesView) ``` @@ -11,7 +11,7 @@ my_collection.add(RecipesView) Sometimes, view classes might need certain arguments when they're instantiated. In these instances, you'll want to register your view with a builder function that takes care of supplying these arguments. For instance, with views that rely on SQLAlchemy, you'll typically need to pass a database engine object like so: ```python -my_collection = dbally.create_collection("collection_name") +my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) engine = sqlalchemy.create_engine("sqlite://") my_collection.add(ExampleView, lambda: ExampleView(engine)) my_collection.add(RecipesView, lambda: RecipesView(engine)) diff --git a/docs/how-to/create_custom_event_handler.md b/docs/how-to/create_custom_event_handler.md index 217b5f71..46c2ab32 100644 --- a/docs/how-to/create_custom_event_handler.md +++ b/docs/how-to/create_custom_event_handler.md @@ -113,13 +113,18 @@ class FileEventHandler(EventHandler[TextIOWrapper, datetime]): ## Registering our event handler -To use our event handler, we need to register it with the db-ally `use_event_handler` function. 
+To use our event handler, we need to pass it to the collection when creating it: ```python import dbally +from dbally.llm_client.openai_client import OpenAIClient -dbally.use_event_handler(FileEventHandler()) +my_collection = dbally.create_collection( + "collection_name", + llm_client=OpenAIClient(), + event_handlers=[FileEventHandler()], +) ``` -Now you can test your event handler by running a query and checking the logs directory for the log files. +Now you can test your event handler by running a query against the collection and checking the logs directory for the log files. diff --git a/docs/how-to/custom_views.md b/docs/how-to/custom_views.md index 29ad7ff5..d76c2199 100644 --- a/docs/how-to/custom_views.md +++ b/docs/how-to/custom_views.md @@ -151,7 +151,7 @@ import abc from typing import Callable, Any, Iterable from dbally.iql import IQLQuery -from dbally.data_models.execution_result import ExecutionResult +from dbally.data_models.execution_result import ViewExecutionResult @abc.abstractmethod def get_data(self) -> Iterable: @@ -159,10 +159,10 @@ def get_data(self) -> Iterable: Returns the full data to be filtered. """ -def execute(self, dry_run: bool = False) -> ExecutionResult: +def execute(self, dry_run: bool = False) -> ViewExecutionResult: filtered_data = list(filter(self._filter, self.get_data())) - return ExecutionResult(results=filtered_data, context={}) + return ViewExecutionResult(results=filtered_data, context={}) ``` The `execute` function gets the data (by calling the `get_data` method) and applies the combined filters to it. We're using the [`filter`](https://docs.python.org/3/library/functions.html#filter) function from Python's standard library to accomplish this. The filtered data is then returned as a list. @@ -216,10 +216,11 @@ Finally, we can use the `CandidatesView` just like any other view in db-ally. 
We ```python import asyncio import dbally -from dbally import CLIEventHandler +from dbally.llm_client.openai_client import OpenAIClient async def main(): - collection = dbally.create_collection("recruitment") + llm = OpenAIClient(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/docs/how-to/custom_views_code.py b/docs/how-to/custom_views_code.py index e6846dda..5ea4fb8e 100644 --- a/docs/how-to/custom_views_code.py +++ b/docs/how-to/custom_views_code.py @@ -9,12 +9,8 @@ from dbally import decorators, MethodsBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.iql import IQLQuery, syntax -from dbally.data_models.execution_result import ExecutionResult - -dbally.use_openai_llm( - openai_api_key=os.environ["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo", -) +from dbally.data_models.execution_result import ViewExecutionResult +from dbally.llm_client.openai_client import OpenAIClient @dataclass class Candidate: @@ -65,11 +61,11 @@ async def build_filter_node(self, node: syntax.Node) -> Callable[[Any], bool]: return lambda x: not child(x) raise ValueError(f"Unsupported grammar: {node}") - def execute(self, dry_run: bool = False) -> ExecutionResult: + def execute(self, dry_run: bool = False) -> ViewExecutionResult: print(self._filter) filtered_data = list(filter(self._filter, self.get_data())) - return ExecutionResult(results=filtered_data, context={}) + return ViewExecutionResult(results=filtered_data, context={}) class CandidateView(FilteredIterableBaseView): def get_data(self) -> Iterable: @@ -103,8 +99,9 @@ def from_country(self, country: str) -> Callable[[Candidate], bool]: return lambda x: x.country == country async def main(): - collection = dbally.create_collection("recruitment") - dbally.use_event_handler(CLIEventHandler()) + llm = 
OpenAIClient(model_name="gpt-3.5-turbo") + event_handlers = [CLIEventHandler()] + collection = dbally.create_collection("recruitment", llm, event_handlers=event_handlers) collection.add(CandidateView) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/docs/how-to/log_runs_to_langsmith.md b/docs/how-to/log_runs_to_langsmith.md index 54237f01..12dea590 100644 --- a/docs/how-to/log_runs_to_langsmith.md +++ b/docs/how-to/log_runs_to_langsmith.md @@ -21,13 +21,18 @@ This guide aims to demonstrate the process of logging the executions of db-ally ## Logging runs to LangSmith -Enabling LangSmith integration can be done by registering a prepared [EventHandler](../reference/event_handlers/index.md) using the `dbally.use_event_handler` method. +Enabling LangSmith integration can be done by passing a prepared [EventHandler](../reference/event_handlers/index.md) when creating a db-ally collection: ```python import dbally from dbally.audit.event_handlers.langsmith_event_handler import LangSmithEventHandler +from dbally.llm_client.openai_client import OpenAIClient -dbally.use_event_handler(LangSmithEventHandler(api_key="your_api_key")) +my_collection = dbally.create_collection( + "collection_name", + llm_client=OpenAIClient(), + event_handlers=[LangSmithEventHandler(api_key="your_api_key")], +) ``` -After this, all the runs of db-ally will be logged to LangSmith. +After this, all the queries against the collection will be logged to LangSmith. 
diff --git a/docs/how-to/pandas_views.md b/docs/how-to/pandas_views.md index 4e5ea606..7db4d192 100644 --- a/docs/how-to/pandas_views.md +++ b/docs/how-to/pandas_views.md @@ -74,8 +74,10 @@ To use the view, you need to create a [Collection](../concepts/collections.md) a ```python import dbally +from dbally.llm_client.openai_client import OpenAIClient -collection = dbally.create_collection("recruitment") +llm = OpenAIClient(model_name="gpt-3.5-turbo") +collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA)) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/docs/how-to/pandas_views_code.py b/docs/how-to/pandas_views_code.py index a562e9fa..fe17f232 100644 --- a/docs/how-to/pandas_views_code.py +++ b/docs/how-to/pandas_views_code.py @@ -8,11 +8,8 @@ from dbally import decorators, DataFrameBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llm_client.openai_client import OpenAIClient -dbally.use_openai_llm( - openai_api_key=os.environ["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo", -) class CandidateView(DataFrameBaseView): """ @@ -49,8 +46,8 @@ def senior_data_scientist_position(self) -> pd.Series: ]) async def main(): - collection = dbally.create_collection("recruitment") - dbally.use_event_handler(CLIEventHandler()) + llm = OpenAIClient(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(CANDIDATE_DATA)) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/docs/how-to/sql_views.md b/docs/how-to/sql_views.md index a8e66917..61c3343b 100644 --- a/docs/how-to/sql_views.md +++ b/docs/how-to/sql_views.md @@ -5,7 +5,7 @@ db-ally is a Python library that allows you to use natural language to query var ## 
Views The majority of the db-ally's codebase is independent of any particular kind of data source. The part that is specific to a data source is the view. A [view](../concepts/views.md) is a class that defines how to interact with a data source. It contains methods that define how to retrieve data from the data source and how to filter the data in response to natural language queries. -There are several methods for creating a view that connects to a SQL database, including creating a custom view from scratch. However, in most cases the easiest will be to use the `SqlAlchemyBaseView` class provided by db-ally. This class is designed to work with [SQLAlchemy](https://www.sqlalchemy.org/), a popular SQL toolkit and Object-Relational Mapping (ORM) library for Python. To define your view, you will need to produce a class that inherits from `SqlAlchemyBaseView`and implement the `get_select` method, which returns a [SQLAlchemy `Select`](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select) object: +There are several methods for creating a view that connects to a SQL database, including [creating a custom view from scratch](./custom_views.md). However, in most cases the easiest will be to use the [`SqlAlchemyBaseView`][dbally.SqlAlchemyBaseView] class provided by db-ally. This class is designed to work with [SQLAlchemy](https://www.sqlalchemy.org/), a popular SQL toolkit and Object-Relational Mapping (ORM) library for Python. To define your view, you will need to produce a class that inherits from `SqlAlchemyBaseView` and implement the `get_select` method, which returns a [SQLAlchemy `Select`](https://docs.sqlalchemy.org/en/20/core/selectable.html#sqlalchemy.sql.expression.Select) object: ```python from dbally import SqlAlchemyBaseView @@ -84,7 +84,9 @@ engine = create_engine('sqlite:///candidates.db') Once you have defined your view and created an engine, you can register the view with db-ally. 
You do this by creating a collection and adding the view to it: ```python -my_collection = dbally.create_collection("collection_name") +from dbally.llm_client.openai_client import OpenAIClient + +my_collection = dbally.create_collection("collection_name", llm_client=OpenAIClient()) my_collection.add(CandidateView, lambda: CandidateView(engine)) ``` diff --git a/docs/how-to/update_similarity_indexes.md b/docs/how-to/update_similarity_indexes.md index f996b9fa..dc773e3b 100644 --- a/docs/how-to/update_similarity_indexes.md +++ b/docs/how-to/update_similarity_indexes.md @@ -61,8 +61,9 @@ If you have a [collection](../concepts/collections.md) and want to update Simila ```python -from db_ally import create_collection +from dbally import create_collection +from dbally.llm_client.openai_client import OpenAIClient -my_collection = create_collection("collection_name") +my_collection = create_collection("collection_name", llm_client=OpenAIClient()) # ... add views to the collection diff --git a/docs/how-to/use_custom_llm.md b/docs/how-to/use_custom_llm.md deleted file mode 100644 index e69de29b..00000000 diff --git a/docs/quickstart/index.md b/docs/quickstart/index.md index 5897902a..1096f02d 100644 --- a/docs/quickstart/index.md +++ b/docs/quickstart/index.md @@ -9,8 +9,8 @@ We will cover the following topics: - [Installation](#installation) - [Database Configuration](#configuring-the-database) -- [OpenAI Access Configuration](#configuring-openai-access) - [View Definition](#defining-the-views) +- [OpenAI Access Configuration](#configuring-openai-access) - [Collection Definition](#defining-the-collection) - [Query Execution](#running-the-query) @@ -50,19 +50,6 @@ Base.prepare(autoload_with=engine) Candidate = Base.classes.candidates ``` -## OpenAI Access Configuration - -To use OpenAI's GPT, configure db-ally and provide your OpenAI API key: - -```python -import dbally - -dbally.use_openai_llm( - openai_api_key="...", - model_name="gpt-3.5-turbo", -) -``` - ## View Definition To use db-ally, define the views 
you want to use. A [view](../concepts/views.md) is a class that specifies what to select from the database and includes methods that the AI model can use to filter rows. These methods are known as "filters". @@ -112,15 +99,27 @@ By setting up these filters, you enable the LLM to fetch candidates while option !!! note The `from_country` filter defined above supports only exact matches, which is not always ideal. Thankfully, db-ally comes with a solution for this problem - Similarity Indexes, which can be used to find the most similar value from the ones available. Refer to [Quickstart Part 2: Semantic Similarity](./quickstart2.md) for an example of using semantic similarity when filtering candidates by country. +## OpenAI Access Configuration + +To use OpenAI's GPT, configure db-ally and provide your OpenAI API key: + +```python +from dbally.llm_client.openai_client import OpenAIClient + +llm = OpenAIClient(model_name="gpt-3.5-turbo", api_key="...") +``` + +Replace `...` with your OpenAI API key. Alternatively, you can set the `OPENAI_API_KEY` environment variable with your API key and omit the `api_key` parameter altogether. + ## Collection Definition -Next, create a db-ally collection. A [collection](../concepts/collections.md) is an object where you register views and execute queries. +Next, create a db-ally collection. A [collection](../concepts/collections.md) is an object where you register views and execute queries. It also requires an AI model to use for generating [IQL queries](../concepts/iql.md) (in this case, the GPT model defined above). 
```python import dbally async def main(): - collection = dbally.create_collection("recruitment") + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) ``` diff --git a/docs/quickstart/quickstart2_code.py b/docs/quickstart/quickstart2_code.py index b83bf4c8..9fa8cd4a 100644 --- a/docs/quickstart/quickstart2_code.py +++ b/docs/quickstart/quickstart2_code.py @@ -12,6 +12,7 @@ from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embedding_client.openai import OpenAiEmbeddingClient +from dbally.llm_client.openai_client import OpenAIClient engine = create_engine('sqlite:///candidates.db') @@ -20,11 +21,6 @@ Candidate = Base.classes.candidates -dbally.use_openai_llm( - openai_api_key=os.environ["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo", -) - country_similarity = SimilarityIndex( fetcher=SimpleSqlAlchemyFetcher( engine, @@ -77,8 +73,8 @@ def from_country(self, country: Annotated[str, country_similarity]) -> sqlalchem async def main(): await country_similarity.update() - collection = dbally.create_collection("recruitment") - dbally.use_event_handler(CLIEventHandler()) + llm = OpenAIClient(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) result = await collection.ask("Find someone from the United States with more than 2 years of experience.") diff --git a/docs/quickstart/quickstart3_code.py b/docs/quickstart/quickstart3_code.py index 1f408488..1d474e8b 100644 --- a/docs/quickstart/quickstart3_code.py +++ b/docs/quickstart/quickstart3_code.py @@ -10,9 +10,9 @@ import pandas as pd from dbally import decorators, SqlAlchemyBaseView, DataFrameBaseView, ExecutionResult -from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler from 
dbally.similarity import SimpleSqlAlchemyFetcher, FaissStore, SimilarityIndex from dbally.embedding_client.openai import OpenAiEmbeddingClient +from dbally.llm_client.openai_client import OpenAIClient engine = create_engine('sqlite:///candidates.db') @@ -21,10 +21,6 @@ Candidate = Base.classes.candidates -dbally.use_openai_llm( - openai_api_key=os.environ["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo", -) country_similarity = SimilarityIndex( fetcher=SimpleSqlAlchemyFetcher( @@ -126,8 +122,8 @@ def display_results(result: ExecutionResult): async def main(): await country_similarity.update() - collection = dbally.create_collection("recruitment") - # dbally.use_event_handler(CLIEventHandler()) + llm = OpenAIClient(model_name="gpt-3.5-turbo") + collection = dbally.create_collection("recruitment", llm) collection.add(CandidateView, lambda: CandidateView(engine)) collection.add(JobView, lambda: JobView(jobs_data)) diff --git a/docs/quickstart/quickstart_code.py b/docs/quickstart/quickstart_code.py index fa5795dc..77d9d87e 100644 --- a/docs/quickstart/quickstart_code.py +++ b/docs/quickstart/quickstart_code.py @@ -9,6 +9,7 @@ from dbally import decorators, SqlAlchemyBaseView from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler +from dbally.llm_client.openai_client import OpenAIClient engine = create_engine('sqlite:///candidates.db') @@ -18,11 +19,6 @@ Candidate = Base.classes.candidates -dbally.use_openai_llm( - openai_api_key=os.environ["OPENAI_API_KEY"], - model_name="gpt-3.5-turbo", -) - class CandidateView(SqlAlchemyBaseView): """ A view for retrieving candidates from the database. 
@@ -58,8 +54,9 @@ def from_country(self, country: str) -> sqlalchemy.ColumnElement: return Candidate.country == country async def main(): - collection = dbally.create_collection("recruitment") - dbally.use_event_handler(CLIEventHandler()) + llm = OpenAIClient(model_name="gpt-3.5-turbo") + + collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()]) collection.add(CandidateView, lambda: CandidateView(engine)) result = await collection.ask("Find me French candidates suitable for a senior data scientist position.") diff --git a/docs/reference/event_handlers/index.md b/docs/reference/event_handlers/index.md index efcf580c..ae69bc0d 100644 --- a/docs/reference/event_handlers/index.md +++ b/docs/reference/event_handlers/index.md @@ -7,7 +7,7 @@ db-ally provides an `EventHandler` abstract class that can be used to log the ru ## Lifecycle -Each run of [dbally.Collection.ask][dbally.Collection.ask] will trigger all instances of EventHandler registered using [`dbally.use_event_handler`][dbally.use_event_handler]. +Each run of [dbally.Collection.ask][dbally.Collection.ask] will trigger all instances of EventHandler that were passed to the Collection's constructor (or the [dbally.create_collection][dbally.create_collection] function). 1. `EventHandler.request_start` is called with [RequestStart][dbally.data_models.audit.RequestStart], it can return a context object that will be passed to next calls. 
diff --git a/docs/reference/index.md b/docs/reference/index.md index 464cceb9..504e5122 100644 --- a/docs/reference/index.md +++ b/docs/reference/index.md @@ -2,7 +2,3 @@ ::: dbally.create_collection - -::: dbally.use_openai_llm - -::: dbally.use_event_handler diff --git a/docs/reference/llm/index.md b/docs/reference/llm/index.md index e649b185..437fc577 100644 --- a/docs/reference/llm/index.md +++ b/docs/reference/llm/index.md @@ -5,7 +5,7 @@ Concrete implementations for specific LLMs, like OpenAILLMClient, can be found in this section of our documentation. -[`LLMClient` configuration options]((./llm_options.md)) include: template, format, event tracker, and optional generation parameters like +[`LLMClient` configuration options][dbally.data_models.llm_options.LLMOptions] include: template, format, event tracker, and optional generation parameters like frequency_penalty, max_tokens, and temperature. It constructs prompts using the [`PromptBuilder`](./prompt_builder.md) instance. diff --git a/examples/recruiting.py b/examples/recruiting.py index 2aa987fa..db51274f 100644 --- a/examples/recruiting.py +++ b/examples/recruiting.py @@ -99,10 +99,11 @@ async def recruiting_example(db_description: str, benchmark: Benchmark = example benchmark (Benchmark, optional): Benchmark containing set of questions. Defaults to example_benchmark. 
""" - dbally.use_openai_llm() - dbally.use_event_handler(CLIEventHandler()) - - recruitment_db = dbally.create_collection("recruitment") + recruitment_db = dbally.create_collection( + "recruitment", + llm_client=OpenAIClient(), + event_handlers=[CLIEventHandler()], + ) recruitment_db.add(RecruitmentView, lambda: RecruitmentView(ENGINE)) event_tracker = EventTracker() diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index d8c4a9b4..2305bf34 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -8,14 +8,12 @@ from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from .__version__ import __version__ -from ._main import create_collection, use_event_handler, use_openai_llm +from ._main import create_collection from .collection import Collection __all__ = [ "__version__", "create_collection", - "use_openai_llm", - "use_event_handler", "decorators", "MethodsBaseView", "SqlAlchemyBaseView", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index 2de140b1..08f430ab 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -4,46 +4,15 @@ from .collection import Collection from .iql_generator.iql_generator import IQLGenerator from .llm_client.base import LLMClient -from .llm_client.openai_client import OpenAIClient from .nl_responder.nl_responder import NLResponder from .view_selection.base import ViewSelector from .view_selection.llm_view_selector import LLMViewSelector -default_llm_client: Optional[LLMClient] = None -default_event_handlers: List[EventHandler] = [] - - -def use_openai_llm( - model_name: str = "gpt-3.5-turbo", - openai_api_key: Optional[str] = None, -) -> None: - """ - Set the default LLM client to the [OpenAIClient](llm/openai.md). - - Args: - model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used. - openai_api_key: OpenAI's API key. 
If None OPENAI_API_KEY environment variable will be used" - """ - global default_llm_client # pylint: disable=W0603 - default_llm_client = OpenAIClient(model_name=model_name, api_key=openai_api_key) - - -def use_event_handler(event_handler: EventHandler) -> None: - """ - Add the given [event handler](event_handlers/index.md) to the list of default event handlers that \ - are used by all collections. - - Args: - event_handler: [event handler](event_handlers/index.md) - """ - global default_event_handlers # pylint: disable=W0602 - default_event_handlers.append(event_handler) - def create_collection( name: str, + llm_client: LLMClient, event_handlers: Optional[List[EventHandler]] = None, - llm_client: Optional[LLMClient] = None, view_selector: Optional[ViewSelector] = None, iql_generator: Optional[IQLGenerator] = None, nl_responder: Optional[NLResponder] = None, @@ -60,19 +29,19 @@ def create_collection( ```python from dbally import create_collection - from dbally.audit.event_handlers.cli import CLIEventHandler + from dbally.llm_client.openai_client import OpenAIClient - collection = create_collection("my_collection", event_handlers=[CLIEventHandler()]) + collection = create_collection("my_collection", llm_client=OpenAIClient()) ``` Args: name: Name of the collection is available for [Event handlers](event_handlers/index.md) and is\ used to distinguish different db-ally runs. - event_handlers: Event handlers used by the collection during query executions. Can be used to\ - log events as [CLIEventHandler](event_handlers/cli.md) or to validate system performance as\ - [LangSmithEventHandler](event_handlers/langsmith.md). llm_client: LLM client used by the collection to generate views and respond to natural language\ - queries. If None, the default LLM client will be used. + queries. + event_handlers: Event handlers used by the collection during query executions. 
Can be used to\ + log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance as\ + [LangSmithEventHandler](event_handlers/langsmith_handler.md). view_selector: View selector used by the collection to select the best view for the given query.\ If None, a new instance of [LLMViewSelector][dbally.view_selection.llm_view_selector.LLMViewSelector]\ will be used. @@ -88,15 +57,10 @@ def create_collection( Raises: ValueError: if default LLM client is not configured """ - llm_client = llm_client or default_llm_client - - if not llm_client: - raise ValueError("LLM client is not configured. Pass the llm_client argument or set the default llm client") - view_selector = view_selector or LLMViewSelector(llm_client=llm_client) iql_generator = iql_generator or IQLGenerator(llm_client=llm_client) nl_responder = nl_responder or NLResponder(llm_client=llm_client) - event_handlers = event_handlers or default_event_handlers + event_handlers = event_handlers or [] return Collection( name, diff --git a/src/dbally/audit/event_handlers/cli_event_handler.py b/src/dbally/audit/event_handlers/cli_event_handler.py index e3ae6e97..fba1d3be 100644 --- a/src/dbally/audit/event_handlers/cli_event_handler.py +++ b/src/dbally/audit/event_handlers/cli_event_handler.py @@ -26,7 +26,7 @@ class CLIEventHandler(EventHandler): import dbally from dbally.audit.event_handlers.cli_event_handler import CLIEventHandler - dbally.use_event_handler(CLIEventHandler()) + my_collection = dbally.create_collection("my_collection", llm, event_handlers=[CLIEventHandler()]) ``` After using `CLIEventHandler`, during every `Collection.ask` execution you will see output similar to the one below: diff --git a/src/dbally/collection.py b/src/dbally/collection.py index 5d396b36..f3c05c7f 100644 --- a/src/dbally/collection.py +++ b/src/dbally/collection.py @@ -63,8 +63,8 @@ def __init__( iql_generator: Objects that translates natural language to the\ [Intermediate Query Language 
(IQL)](../concepts/iql.md) event_handlers: Event handlers used by the collection during query executions. Can be used\ - to log events as [CLIEventHandler](event_handlers/cli.md) or to validate system performance\ - as [LangSmithEventHandler](event_handlers/langsmith.md). + to log events as [CLIEventHandler](event_handlers/cli_handler.md) or to validate system performance\ + as [LangSmithEventHandler](event_handlers/langsmith_handler.md). nl_responder: Object that translates RAW response from db-ally into natural language. n_retries: IQL generator may produce invalid IQL. If this is the case this argument specifies\ how many times db-ally will try to regenerate it. Previous try with the error message is\ diff --git a/src/dbally/llm_client/base.py b/src/dbally/llm_client/base.py index b8d07411..3e4abf9e 100644 --- a/src/dbally/llm_client/base.py +++ b/src/dbally/llm_client/base.py @@ -16,7 +16,7 @@ class LLMClient(abc.ABC): It accepts parameters including the template, format, event tracker, and optional generation parameters like frequency_penalty, max_tokens, and temperature - (the full list of options is provided by the [`LLMOptions` class](llm_options.md)). + (the full list of options is provided by the [`LLMOptions` class][dbally.data_models.llm_options.LLMOptions]). It constructs a prompt using the `PromptBuilder` instance and generates text using the `self.call` method. """ diff --git a/src/dbally/llm_client/openai_client.py b/src/dbally/llm_client/openai_client.py index ade6a6d6..3b730d6d 100644 --- a/src/dbally/llm_client/openai_client.py +++ b/src/dbally/llm_client/openai_client.py @@ -11,9 +11,13 @@ class OpenAIClient(LLMClient): `OpenAIClient` is a class designed to interact with OpenAI's language model (LLM) endpoints, particularly for the GPT models. + Args: + model_name: Name of the [OpenAI's model](https://platform.openai.com/docs/models) to be used, + default is "gpt-3.5-turbo". + api_key: OpenAI's API key. 
If None OPENAI_API_KEY environment variable will be used """ - def __init__(self, model_name: str, api_key: Optional[str] = None) -> None: + def __init__(self, model_name: str = "gpt-3.5-turbo", api_key: Optional[str] = None) -> None: try: from openai import AsyncOpenAI # pylint: disable=import-outside-toplevel except ImportError as exc: diff --git a/src/dbally/similarity/faiss_store.py b/src/dbally/similarity/faiss_store.py index 24704106..7839861f 100644 --- a/src/dbally/similarity/faiss_store.py +++ b/src/dbally/similarity/faiss_store.py @@ -30,7 +30,7 @@ def __init__( max_distance: The maximum distance between two text embeddings to be considered similar. embedding_client: The client to use for creating text embeddings. index_type: The type of Faiss index to use. Defaults to faiss.IndexFlatL2. See - [https://github.com/facebookresearch/faiss/wiki/Faiss-indexes](Faiss wiki) for more information. + [Faiss wiki](https://github.com/facebookresearch/faiss/wiki/Faiss-indexes) for more information. 
""" super().__init__() self.index_dir = index_dir diff --git a/tests/unit/test_view_selector.py b/tests/unit/test_view_selector.py index 748d0bca..09a7f1d3 100644 --- a/tests/unit/test_view_selector.py +++ b/tests/unit/test_view_selector.py @@ -1,39 +1,39 @@ # mypy: disable-error-code="empty-body" +# pylint: disable=missing-return-doc +from typing import Dict from unittest.mock import AsyncMock, Mock import pytest import dbally from dbally.audit.event_tracker import EventTracker +from dbally.llm_client.base import LLMClient from dbally.view_selection.llm_view_selector import LLMViewSelector +from .mocks import MockLLMClient from .test_collection import MockView1, MockView2 @pytest.fixture -def llm_client(): - llm_client = Mock() - llm_client.text_generation = AsyncMock(return_value="MockView1") - return llm_client +def llm_client() -> LLMClient: + """Return a mock LLM client.""" + client = Mock() + client.text_generation = AsyncMock(return_value="MockView1") + return client @pytest.fixture -def event_tracker(): - return EventTracker() - - -@pytest.fixture -def views(): - dbally.use_openai_llm(openai_api_key="sk-fake") - mock_collection = dbally.create_collection("mock_collection") +def views() -> Dict[str, str]: + """Return a map of view names + view descriptions to be used in the test.""" + mock_collection = dbally.create_collection("mock_collection", llm_client=MockLLMClient()) mock_collection.add(MockView1) mock_collection.add(MockView2) return mock_collection.list() @pytest.mark.asyncio -async def test_view_selection(llm_client, event_tracker, views): +async def test_view_selection(llm_client: LLMClient, views: Dict[str, str]): view_selector = LLMViewSelector(llm_client) - view = await view_selector.select_view("Mock question?", views, event_tracker) + view = await view_selector.select_view("Mock question?", views, event_tracker=EventTracker()) assert view == "MockView1"