Skip to content

Commit

Permalink
Handle exceptions during index update
Browse files Browse the repository at this point in the history
  • Loading branch information
ludwiktrammer committed Apr 11, 2024
1 parent 5026807 commit 22f099f
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 12 deletions.
2 changes: 1 addition & 1 deletion benchmark/dbally_benchmark/e2e_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from sqlalchemy import create_engine

import dbally
from dbally._collection import Collection
from dbally.collection import Collection
from dbally.data_models.prompts.iql_prompt_template import default_iql_template
from dbally.data_models.prompts.view_selector_prompt_template import default_view_selector_template
from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError
Expand Down
4 changes: 3 additions & 1 deletion docs/reference/collection.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@

::: dbally.Collection

::: dbally.data_models.execution_result.ExecutionResult
::: dbally.data_models.execution_result.ExecutionResult

::: dbally.collection.IndexUpdateError
2 changes: 1 addition & 1 deletion src/dbally/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from dbally.views.sqlalchemy_base import SqlAlchemyBaseView

from .__version__ import __version__
from ._collection import Collection
from ._main import create_collection, use_event_handler, use_openai_llm
from .collection import Collection

__all__ = [
"__version__",
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/_main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import List, Optional

from ._collection import Collection
from .audit.event_handlers.base import EventHandler
from .collection import Collection
from .iql_generator.iql_generator import IQLGenerator
from .llm_client.base import LLMClient
from .llm_client.openai_client import OpenAIClient
Expand Down
2 changes: 1 addition & 1 deletion src/dbally/assistants/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

from openai.types.beta.threads import RequiredActionFunctionToolCall

from dbally._collection import Collection
from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState
from dbally.collection import Collection
from dbally.utils.errors import UnsupportedQueryError

_DBALLY_INFO = "Dbally has access to the following database views: "
Expand Down
33 changes: 32 additions & 1 deletion src/dbally/_collection.py → src/dbally/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
from dbally.views.base import AbstractBaseView


class IndexUpdateError(Exception):
"""
Exception for when updating any of the Collection's similarity indexes fails.
Provides a dictionary mapping failed indexes to their
respective exceptions as the `failed_indexes` attribute.
"""

def __init__(self, message: str, failed_indexes: Dict[AbstractSimilarityIndex, Exception]) -> None:
"""
Args:
failed_indexes: Dictionary mapping failed indexes to their respective exceptions.
"""
self.failed_indexes = failed_indexes
super().__init__(message)


class Collection:
"""
Collection is a container for a set of views that can be used by db-ally to answer user questions.
Expand Down Expand Up @@ -250,7 +267,21 @@ def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[Tuple[str
async def update_similarity_indexes(self) -> None:
"""
Update all similarity indexes from all views in the collection.
Raises:
IndexUpdateError: if updating any of the indexes fails. The exception provides `failed_indexes` attribute,
a dictionary mapping failed indexes to their respective exceptions. Indexes not present in
the dictionary were updated successfully.
"""
indexes = self.get_similarity_indexes()
update_corutines = [index.update() for index in indexes]
await asyncio.gather(*update_corutines)
results = await asyncio.gather(*update_corutines, return_exceptions=True)
failed_indexes = {
index: exception for index, exception in zip(indexes, results) if isinstance(exception, Exception)
}
if failed_indexes:
failed_locations = [loc for index in failed_indexes for loc in indexes[index]]
description = ", ".join(
f"{view_name}.{method_name}.{param_name}" for view_name, method_name, param_name in failed_locations
)
raise IndexUpdateError(f"Failed to update similarity indexes for {description}", failed_indexes)
8 changes: 4 additions & 4 deletions tests/unit/similarity/test_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_detector_with_module_not_found():
"""
with pytest.raises(SimilarityIndexDetectorException) as exc:
SimilarityIndexDetector.from_path("not_found")
assert exc.value.message == "Module not_found not found."
assert exc.value.message == "Module not_found not found."


def test_detector_with_empty_module():
Expand All @@ -97,7 +97,7 @@ def test_detector_with_view_not_found():
detector = SimilarityIndexDetector.from_path("sample_module.submodule:NotFoundView")
with pytest.raises(SimilarityIndexDetectorException) as exc:
detector.list_views()
assert exc.value.message == "View NotFoundView not found."
assert exc.value.message == "View NotFoundView not found in module sample_module.submodule."


def test_detector_with_method_not_found():
Expand All @@ -107,7 +107,7 @@ def test_detector_with_method_not_found():
detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.not_found")
with pytest.raises(SimilarityIndexDetectorException) as exc:
detector.list_indexes()
assert exc.value.message == "Filter method not_found not found in view FooView."
assert exc.value.message == "Filter method not_found not found in view FooView."


def test_detector_with_argument_not_found():
Expand All @@ -117,4 +117,4 @@ def test_detector_with_argument_not_found():
detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.method_bar.not_found")
with pytest.raises(SimilarityIndexDetectorException) as exc:
detector.list_indexes()
assert exc.value.message == "Argument not_found not found in method FooView.method_bar."
assert exc.value.message == "Argument not_found not found in method method_bar."
35 changes: 33 additions & 2 deletions tests/unit/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest
from typing_extensions import Annotated

from dbally._collection import Collection
from dbally.collection import Collection, IndexUpdateError
from dbally.iql._exceptions import IQLError
from dbally.utils.errors import NoViewFoundError
from dbally.views.base import ExposedFunction, MethodParamWithTyping, ViewExecutionResult
Expand Down Expand Up @@ -280,7 +280,7 @@ async def test_ask_feedback_loop(collection_feedback: Collection) -> None:
ValueError("err3"),
ValueError("err4"),
]
with patch("dbally._collection.IQLQuery.parse") as mock_iql_query:
with patch("dbally.collection.IQLQuery.parse") as mock_iql_query:
mock_iql_query.side_effect = errors

await collection_feedback.ask("Mock question")
Expand Down Expand Up @@ -406,3 +406,34 @@ async def test_update_similarity_indexes(
await collection.update_similarity_indexes()
assert foo_index.update_count == 1
assert bar_index.update_count == 1


async def test_update_similarity_indexes_error(
similarity_classes: Tuple[MockSimilarityIndex, MockSimilarityIndex, Type[MockViewBase], Type[MockViewBase]],
collection: Collection,
) -> None:
"""
Tests that the update_similarity_indexes method does not raise an exception when the update method of the similarity
indexes raises an exception
"""
(
foo_index,
bar_index,
MockViewWithSimilarity, # pylint: disable=invalid-name
MockViewWithSimilarity2, # pylint: disable=invalid-name
) = similarity_classes
collection.add(MockViewWithSimilarity)
collection.add(MockViewWithSimilarity2)

foo_exception = ValueError("foo")
foo_index.update = AsyncMock(side_effect=foo_exception) # type: ignore
with pytest.raises(IndexUpdateError) as e:
await collection.update_similarity_indexes()
assert (
str(e.value) == "Failed to update similarity indexes for MockViewWithSimilarity.test_filter.dog, "
"MockViewWithSimilarity2.test_filter.monkey"
)
assert e.value.failed_indexes == {
foo_index: foo_exception,
}
assert bar_index.update_count == 1

0 comments on commit 22f099f

Please sign in to comment.