Skip to content

Commit

Permalink
feat: allow collection fallbacks (#59)
Browse files Browse the repository at this point in the history
  • Loading branch information
karllu3 authored Jul 22, 2024
1 parent ea687f8 commit a2ef774
Show file tree
Hide file tree
Showing 12 changed files with 574 additions and 57 deletions.
27 changes: 27 additions & 0 deletions docs/concepts/collections.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,33 @@ my_collection.ask("Find me Italian recipes for soups")

In this scenario, the LLM first determines the most suitable view to address the query, and then that view is used to pull the relevant data.

Sometimes, the selected view does not match question (LLM select wrong view) and will raise an error. In such situations, the fallback collections can be used.
This will cause a next view selection, but from the fallback collection.

```python
llm = LiteLLM(model_name="gpt-3.5-turbo")
user_collection = dbally.create_collection("candidates", llm)
user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine))
user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine()))
user_collection.add(CandidateView, lambda: (candidate_view_with_similarity_store.engine))

fallback_collection = dbally.create_collection("freeform candidates", llm)
fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine))
user_collection.set_fallback(fallback_collection)
```
The fallback collection process the same question with declared set of views. The fallback collection could be chained.

```python
second_fallback_collection = dbally.create_collection("recruitment", llm)
second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine))

fallback_collection.set_fallback(second_fallback_collection)

```




!!! info
The result of a query is an [`ExecutionResult`][dbally.collection.results.ExecutionResult] object, which contains the data fetched by the view. It contains a `results` attribute that holds the actual data, structured as a list of dictionaries. The exact structure of these dictionaries depends on the view that was used to fetch the data, which can be obtained by looking at the `view_name` attribute of the `ExecutionResult` object.

Expand Down
3 changes: 2 additions & 1 deletion examples/recruiting/candidate_view_with_similarity_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from sqlalchemy.ext.automap import automap_base
from typing_extensions import Annotated

from dbally import SqlAlchemyBaseView, decorators
from dbally.embeddings.litellm import LiteLLMEmbeddingClient
from dbally.similarity import FaissStore, SimilarityIndex, SimpleSqlAlchemyFetcher
from dbally.views import decorators
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

engine = create_engine("sqlite:///examples/recruiting/data/candidates.db")

Expand Down
42 changes: 42 additions & 0 deletions examples/recruiting/candidates_freeform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# pylint: disable=missing-return-doc, missing-param-doc, missing-function-docstring
from typing import List

from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base

from dbally.views.freeform.text2sql import BaseText2SQLView, ColumnConfig, TableConfig

engine = create_engine("sqlite:///examples/recruiting/data/candidates.db")

_Base = automap_base()
_Base.prepare(autoload_with=engine)
_Candidate = _Base.classes.candidates


class CandidateFreeformView(BaseText2SQLView):
"""
A view for retrieving candidates from the database.
"""

def get_tables(self) -> List[TableConfig]:
"""
Get the tables used by the view.
Returns:
A list of tables.
"""
return [
TableConfig(
name="candidates",
columns=[
ColumnConfig("name", "TEXT"),
ColumnConfig("country", "TEXT"),
ColumnConfig("years_of_experience", "INTEGER"),
ColumnConfig("position", "TEXT"),
ColumnConfig("university", "TEXT"),
ColumnConfig("skills", "TEXT"),
ColumnConfig("tags", "TEXT"),
ColumnConfig("id", "INTEGER PRIMARY KEY"),
],
),
]
2 changes: 1 addition & 1 deletion examples/recruiting/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def is_available_within_months( # pylint: disable=W0602, C0116, W9011
end = start + relativedelta(months=months)
return Candidate.available_from.between(start, end)

def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011
def list_few_shots(self) -> List[FewShotExample]: # pylint: disable=W9011, C0116
return [
FewShotExample(
"Which candidates studied at University of Toronto?",
Expand Down
36 changes: 36 additions & 0 deletions examples/visualize_fallback_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# pylint: disable=missing-function-docstring
import asyncio

from recruiting import candidate_view_with_similarity_store, candidates_freeform
from recruiting.candidate_view_with_similarity_store import CandidateView
from recruiting.candidates_freeform import CandidateFreeformView
from recruiting.cypher_text2sql_view import SampleText2SQLViewCyphers, create_freeform_memory_engine
from recruiting.db import ENGINE as recruiting_engine
from recruiting.views import RecruitmentView

import dbally
from dbally.audit import CLIEventHandler, OtelEventHandler
from dbally.gradio import create_gradio_interface
from dbally.llms.litellm import LiteLLM


async def main():
llm = LiteLLM(model_name="gpt-3.5-turbo")
user_collection = dbally.create_collection("candidates", llm)
user_collection.add(CandidateView, lambda: CandidateView(candidate_view_with_similarity_store.engine))
user_collection.add(SampleText2SQLViewCyphers, lambda: SampleText2SQLViewCyphers(create_freeform_memory_engine()))

fallback_collection = dbally.create_collection("freeform candidates", llm, event_handlers=[OtelEventHandler()])
fallback_collection.add(CandidateFreeformView, lambda: CandidateFreeformView(candidates_freeform.engine))

second_fallback_collection = dbally.create_collection("recruitment", llm, event_handlers=[CLIEventHandler()])
second_fallback_collection.add(RecruitmentView, lambda: RecruitmentView(recruiting_engine))

user_collection.set_fallback(fallback_collection).set_fallback(second_fallback_collection)

gradio_interface = await create_gradio_interface(user_collection=user_collection)
gradio_interface.launch()


if __name__ == "__main__":
asyncio.run(main())
3 changes: 1 addition & 2 deletions src/dbally/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import TYPE_CHECKING, List

from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.exceptions import NoViewFoundError
from dbally.collection.results import ExecutionResult
from dbally.views import decorators
from dbally.views.methods_base import MethodsBaseView
Expand Down Expand Up @@ -40,7 +40,6 @@
"EmbeddingConnectionError",
"EmbeddingResponseError",
"EmbeddingStatusError",
"IndexUpdateError",
"LLMError",
"LLMConnectionError",
"LLMResponseError",
Expand Down
25 changes: 20 additions & 5 deletions src/dbally/audit/event_handlers/cli_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
pprint = print # type: ignore

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent
from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent

_RICH_FORMATING_KEYWORD_SET = {"green", "orange", "grey", "bold", "cyan"}
_RICH_FORMATING_PATTERN = rf"\[.*({'|'.join(_RICH_FORMATING_KEYWORD_SET)}).*\]"
Expand Down Expand Up @@ -94,6 +94,18 @@ async def event_start(self, event: Event, request_context: None) -> None:
f"[cyan bold]STORE: [grey53]{event.store}\n"
f"[cyan bold]FETCHER: [grey53]{event.fetcher}\n"
)
elif isinstance(event, FallbackEvent):
self._print_syntax(
f"[grey53]\n=======================================\n"
"[grey53]=======================================\n"
f"[orange bold]Fallback event starts \n"
f"[orange bold]Triggering collection: [grey53]{event.triggering_collection_name}\n"
f"[orange bold]Triggering view name: [grey53]{event.triggering_view_name}\n"
f"[orange bold]Error description: [grey53]{event.error_description}\n"
f"[orange bold]Fallback collection name: [grey53]{event.fallback_collection_name}\n"
"[grey53]=======================================\n"
"[grey53]=======================================\n"
)

# pylint: disable=unused-argument
async def event_end(self, event: Optional[Event], request_context: None, event_context: None) -> None:
Expand Down Expand Up @@ -123,8 +135,11 @@ async def request_end(self, output: RequestEnd, request_context: Optional[dict]
output: The output of the request.
request_context: Optional context passed from request_start method
"""
self._print_syntax("[green bold]REQUEST OUTPUT:")
self._print_syntax(f"Number of rows: {len(output.result.results)}")
if output.result:
self._print_syntax("[green bold]REQUEST OUTPUT:")
self._print_syntax(f"Number of rows: {len(output.result.results)}")

if "sql" in output.result.context:
self._print_syntax(f"{output.result.context['sql']}", "psql")
if "sql" in output.result.context:
self._print_syntax(f"{output.result.context['sql']}", "psql")
else:
self._print_syntax("[red bold]No results found")
7 changes: 5 additions & 2 deletions src/dbally/audit/event_handlers/otel_event_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from opentelemetry.util.types import AttributeValue

from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.events import Event, LLMEvent, RequestEnd, RequestStart, SimilarityEvent
from dbally.audit.events import Event, FallbackEvent, LLMEvent, RequestEnd, RequestStart, SimilarityEvent

TRACER_NAME = "db-ally.events"
FORBIDDEN_CONTEXT_KEYS = {"filter_mask"}
Expand Down Expand Up @@ -172,8 +172,11 @@ async def event_start(self, event: Event, request_context: SpanHandler) -> SpanH
.set("db-ally.similarity.fetcher", event.fetcher)
.set_input("db-ally.similarity.input", event.input_value)
)
if isinstance(event, FallbackEvent):
with self._new_child_span(request_context, "fallback") as span:
return self._handle_span(span).set("db-ally.error_description", event.error_description)

raise ValueError(f"Unsuported event: {type(event)}")
raise ValueError(f"Unsupported event: {type(event)}")

async def event_end(self, event: Optional[Event], request_context: SpanHandler, event_context: SpanHandler) -> None:
"""
Expand Down
12 changes: 12 additions & 0 deletions src/dbally/audit/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,18 @@ class SimilarityEvent(Event):
output_value: Optional[str] = None


@dataclass
class FallbackEvent(Event):
"""
FallbackEvent is fired when a processed view/collection raise an exception.
"""

triggering_collection_name: str
triggering_view_name: str
fallback_collection_name: str
error_description: str


@dataclass
class RequestStart:
"""
Expand Down
Loading

0 comments on commit a2ef774

Please sign in to comment.