Skip to content

Commit

Permalink
some renames and stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
micpst committed Sep 23, 2024
1 parent d7026d4 commit e8271ac
Show file tree
Hide file tree
Showing 18 changed files with 59 additions and 163 deletions.
21 changes: 12 additions & 9 deletions src/dbally/collection/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@
import textwrap
import time
from collections import defaultdict
from typing import Callable, Dict, Iterable, List, Optional, Type, TypeVar
from typing import Callable, Dict, List, Optional, Type, TypeVar

import dbally
from dbally.audit.event_handlers.base import EventHandler
from dbally.audit.event_tracker import EventTracker
from dbally.audit.events import FallbackEvent, RequestEnd, RequestStart
from dbally.collection.exceptions import IndexUpdateError, NoViewFoundError
from dbally.collection.results import ExecutionResult, ViewExecutionResult
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.iql_generator.prompt import UnsupportedQueryError
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
Expand Down Expand Up @@ -228,7 +228,7 @@ async def _ask_view(
event_tracker: EventTracker,
llm_options: Optional[LLMOptions],
dry_run: bool,
contexts: Iterable[BaseCallerContext],
contexts: List[Context],
) -> ViewExecutionResult:
"""
Ask the selected view to provide an answer to the question.
Expand All @@ -247,11 +247,11 @@ async def _ask_view(
view_result = await selected_view.ask(
query=question,
llm=self._llm,
contexts=contexts,
event_tracker=event_tracker,
n_retries=self.n_retries,
dry_run=dry_run,
llm_options=llm_options,
contexts=contexts,
)
return view_result

Expand Down Expand Up @@ -298,9 +298,11 @@ def get_all_event_handlers(self) -> List[EventHandler]:
return self._event_handlers
return list(set(self._event_handlers).union(self._fallback_collection.get_all_event_handlers()))

# pylint: disable=too-many-arguments
async def _handle_fallback(
self,
question: str,
contexts: Optional[List[Context]],
dry_run: bool,
return_natural_response: bool,
llm_options: Optional[LLMOptions],
Expand All @@ -322,7 +324,6 @@ async def _handle_fallback(
Returns:
The result from the fallback collection.
"""
if not self._fallback_collection:
raise caught_exception
Expand All @@ -337,6 +338,7 @@ async def _handle_fallback(
async with event_tracker.track_event(fallback_event) as span:
result = await self._fallback_collection.ask(
question=question,
contexts=contexts,
dry_run=dry_run,
return_natural_response=return_natural_response,
llm_options=llm_options,
Expand All @@ -348,10 +350,10 @@ async def _handle_fallback(
async def ask(
self,
question: str,
contexts: Optional[List[Context]] = None,
dry_run: bool = False,
return_natural_response: bool = False,
llm_options: Optional[LLMOptions] = None,
contexts: Optional[Iterable[BaseCallerContext]] = None,
event_tracker: Optional[EventTracker] = None,
) -> ExecutionResult:
"""
Expand All @@ -366,14 +368,14 @@ async def ask(
Args:
question: question posed using natural language representation e.g\
"What job offers for Data Scientists do we have?"
"What job offers for Data Scientists do we have?"
contexts: list of context objects, each being an instance of
a subclass of Context. May contain contexts irrelevant for the currently processed query.
dry_run: if True, only generate the query without executing it
return_natural_response: if True (and dry_run is False as natural response requires query results),
the natural response will be included in the answer
llm_options: options to use for the LLM client. If provided, these options will be merged with the default
options provided to the LLM client, prioritizing option values other than NOT_GIVEN
contexts: An iterable (typically a list) of context objects, each being an instance of
a subclass of BaseCallerContext. May contain contexts irrelevant for the currently processed query.
event_tracker: Event tracker object for given ask.
Returns:
Expand Down Expand Up @@ -433,6 +435,7 @@ async def ask(
if self._fallback_collection:
result = await self._handle_fallback(
question=question,
contexts=contexts,
dry_run=dry_run,
return_natural_response=return_natural_response,
llm_options=llm_options,
Expand Down
11 changes: 11 additions & 0 deletions src/dbally/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from abc import ABC
from typing import ClassVar


class Context(ABC):
"""
Base class for all contexts that are used to pass additional knowledge about the caller environment to the view.
"""

type_name: ClassVar[str] = "Context"
alias_name: ClassVar[str] = "CONTEXT"
3 changes: 0 additions & 3 deletions src/dbally/context/__init__.py

This file was deleted.

75 changes: 0 additions & 75 deletions src/dbally/context/_utils.py

This file was deleted.

17 changes: 0 additions & 17 deletions src/dbally/context/context.py

This file was deleted.

23 changes: 0 additions & 23 deletions src/dbally/context/exceptions.py

This file was deleted.

4 changes: 2 additions & 2 deletions src/dbally/iql/_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.iql import syntax
from dbally.iql._exceptions import (
IQLArgumentParsingError,
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
self,
source: str,
allowed_functions: List[ExposedFunction],
allowed_contexts: Optional[List[BaseCallerContext]] = None,
allowed_contexts: Optional[List[Context]] = None,
event_tracker: Optional[EventTracker] = None,
) -> None:
self.source = source
Expand Down
4 changes: 2 additions & 2 deletions src/dbally/iql/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ._processor import IQLAggregationProcessor, IQLFiltersProcessor, IQLProcessor, RootT

if TYPE_CHECKING:
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.views.exposed_functions import ExposedFunction


Expand All @@ -33,7 +33,7 @@ async def parse(
cls,
source: str,
allowed_functions: List["ExposedFunction"],
allowed_contexts: Optional[List["BaseCallerContext"]] = None,
allowed_contexts: Optional[List["Context"]] = None,
event_tracker: Optional[EventTracker] = None,
) -> Self:
"""
Expand Down
8 changes: 4 additions & 4 deletions src/dbally/iql_generator/iql_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Generic, List, Optional, TypeVar, Union

from dbally.audit.event_tracker import EventTracker
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.iql import IQLError, IQLQuery
from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery
from dbally.iql_generator.prompt import (
Expand Down Expand Up @@ -67,7 +67,7 @@ async def __call__(
question: str,
filters: List[ExposedFunction],
aggregations: List[ExposedFunction],
contexts: List[BaseCallerContext],
contexts: List[Context],
examples: List[FewShotExample],
llm: LLM,
event_tracker: Optional[EventTracker] = None,
Expand Down Expand Up @@ -146,7 +146,7 @@ async def __call__(
*,
question: str,
methods: List[ExposedFunction],
contexts: List[BaseCallerContext],
contexts: List[Context],
examples: List[FewShotExample],
llm: LLM,
event_tracker: Optional[EventTracker] = None,
Expand Down Expand Up @@ -265,7 +265,7 @@ async def __call__(
*,
question: str,
methods: List[ExposedFunction],
contexts: List[BaseCallerContext],
contexts: List[Context],
examples: List[FewShotExample],
llm: LLM,
llm_options: Optional[LLMOptions] = None,
Expand Down
8 changes: 4 additions & 4 deletions src/dbally/iql_generator/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Optional

from dbally.audit.event_tracker import EventTracker
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.exceptions import DbAllyError
from dbally.iql._query import IQLAggregationQuery, IQLFiltersQuery
from dbally.prompt.elements import FewShotExample
Expand All @@ -21,7 +21,7 @@ class UnsupportedQueryError(DbAllyError):
async def _iql_filters_parser(
response: str,
allowed_functions: List[ExposedFunction],
allowed_contexts: List[BaseCallerContext],
allowed_contexts: List[Context],
event_tracker: Optional[EventTracker] = None,
) -> IQLFiltersQuery:
"""
Expand Down Expand Up @@ -53,7 +53,7 @@ async def _iql_filters_parser(
async def _iql_aggregation_parser(
response: str,
allowed_functions: List[ExposedFunction],
allowed_contexts: List[BaseCallerContext],
allowed_contexts: List[Context],
event_tracker: Optional[EventTracker] = None,
) -> IQLAggregationQuery:
"""
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(
*,
question: str,
methods: List[ExposedFunction],
contexts: List[BaseCallerContext],
contexts: List[Context],
examples: Optional[List[FewShotExample]] = None,
) -> None:
"""
Expand Down
6 changes: 3 additions & 3 deletions src/dbally/views/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from dbally.audit.event_tracker import EventTracker
from dbally.collection.results import ViewExecutionResult
from dbally.context.context import BaseCallerContext
from dbally.context import Context
from dbally.llms.base import LLM
from dbally.llms.clients.base import LLMOptions
from dbally.prompt.elements import FewShotExample
Expand All @@ -25,7 +25,7 @@ async def ask(
self,
query: str,
llm: LLM,
contexts: Optional[List[BaseCallerContext]] = None,
contexts: Optional[List[Context]] = None,
event_tracker: Optional[EventTracker] = None,
n_retries: int = 3,
dry_run: bool = False,
Expand All @@ -38,7 +38,7 @@ async def ask(
query: The natural language query to execute.
llm: The LLM used to execute the query.
contexts: An iterable (typically a list) of context objects, each being
an instance of a subclass of BaseCallerContext.
an instance of a subclass of Context.
event_tracker: The event tracker used to audit the query execution.
n_retries: The number of retries to execute the query in case of errors.
dry_run: If True, the query will not be used to fetch data from the datasource.
Expand Down
Loading

0 comments on commit e8271ac

Please sign in to comment.