diff --git a/benchmark/dbally_benchmark/e2e_benchmark.py b/benchmark/dbally_benchmark/e2e_benchmark.py index 6ea01c36..da20857d 100644 --- a/benchmark/dbally_benchmark/e2e_benchmark.py +++ b/benchmark/dbally_benchmark/e2e_benchmark.py @@ -21,7 +21,7 @@ from sqlalchemy import create_engine import dbally -from dbally._collection import Collection +from dbally.collection import Collection from dbally.data_models.prompts.iql_prompt_template import default_iql_template from dbally.data_models.prompts.view_selector_prompt_template import default_view_selector_template from dbally.utils.errors import NoViewFoundError, UnsupportedQueryError diff --git a/docs/reference/collection.md b/docs/reference/collection.md index 1a0cdd44..85a0b75c 100644 --- a/docs/reference/collection.md +++ b/docs/reference/collection.md @@ -5,4 +5,6 @@ ::: dbally.Collection -::: dbally.data_models.execution_result.ExecutionResult \ No newline at end of file +::: dbally.data_models.execution_result.ExecutionResult + +::: dbally.collection.IndexUpdateError \ No newline at end of file diff --git a/src/dbally/__init__.py b/src/dbally/__init__.py index 8c773013..d8c4a9b4 100644 --- a/src/dbally/__init__.py +++ b/src/dbally/__init__.py @@ -8,8 +8,8 @@ from dbally.views.sqlalchemy_base import SqlAlchemyBaseView from .__version__ import __version__ -from ._collection import Collection from ._main import create_collection, use_event_handler, use_openai_llm +from .collection import Collection __all__ = [ "__version__", diff --git a/src/dbally/_main.py b/src/dbally/_main.py index b32fb88b..c54ece2d 100644 --- a/src/dbally/_main.py +++ b/src/dbally/_main.py @@ -1,7 +1,7 @@ from typing import List, Optional -from ._collection import Collection from .audit.event_handlers.base import EventHandler +from .collection import Collection from .iql_generator.iql_generator import IQLGenerator from .llm_client.base import LLMClient from .llm_client.openai_client import OpenAIClient diff --git a/src/dbally/assistants/openai.py b/src/dbally/assistants/openai.py index e9555402..4f715787 100644 --- a/src/dbally/assistants/openai.py +++ b/src/dbally/assistants/openai.py @@ -4,8 +4,8 @@ from openai.types.beta.threads import RequiredActionFunctionToolCall -from dbally._collection import Collection from dbally.assistants.base import AssistantAdapter, FunctionCallingError, FunctionCallState +from dbally.collection import Collection from dbally.utils.errors import UnsupportedQueryError _DBALLY_INFO = "Dbally has access to the following database views: " diff --git a/src/dbally/_collection.py b/src/dbally/collection.py similarity index 87% rename from src/dbally/_collection.py rename to src/dbally/collection.py index 153c97b8..c36be374 100644 --- a/src/dbally/_collection.py +++ b/src/dbally/collection.py @@ -18,6 +18,23 @@ from dbally.views.base import AbstractBaseView +class IndexUpdateError(Exception): + """ + Exception for when updating any of the Collection's similarity indexes fails. + + Provides a dictionary mapping failed indexes to their + respective exceptions as the `failed_indexes` attribute. + """ + + def __init__(self, message: str, failed_indexes: Dict[AbstractSimilarityIndex, Exception]) -> None: + """ + Args: + failed_indexes: Dictionary mapping failed indexes to their respective exceptions. + """ + self.failed_indexes = failed_indexes + super().__init__(message) + + class Collection: """ Collection is a container for a set of views that can be used by db-ally to answer user questions. @@ -250,7 +267,21 @@ def get_similarity_indexes(self) -> Dict[AbstractSimilarityIndex, List[Tuple[str async def update_similarity_indexes(self) -> None: """ Update all similarity indexes from all views in the collection. + + Raises: + IndexUpdateError: if updating any of the indexes fails. The exception provides `failed_indexes` attribute, + a dictionary mapping failed indexes to their respective exceptions. Indexes not present in + the dictionary were updated successfully. """ indexes = self.get_similarity_indexes() update_corutines = [index.update() for index in indexes] - await asyncio.gather(*update_corutines) + results = await asyncio.gather(*update_corutines, return_exceptions=True) + failed_indexes = { + index: exception for index, exception in zip(indexes, results) if isinstance(exception, Exception) + } + if failed_indexes: + failed_locations = [loc for index in failed_indexes for loc in indexes[index]] + description = ", ".join( + f"{view_name}.{method_name}.{param_name}" for view_name, method_name, param_name in failed_locations + ) + raise IndexUpdateError(f"Failed to update similarity indexes for {description}", failed_indexes) diff --git a/tests/unit/similarity/test_detector.py b/tests/unit/similarity/test_detector.py index 39f093b1..5c5a406c 100644 --- a/tests/unit/similarity/test_detector.py +++ b/tests/unit/similarity/test_detector.py @@ -78,7 +78,7 @@ def test_detector_with_module_not_found(): """ with pytest.raises(SimilarityIndexDetectorException) as exc: SimilarityIndexDetector.from_path("not_found") - assert exc.value.message == "Module not_found not found." + assert exc.value.message == "Module not_found not found." def test_detector_with_empty_module(): @@ -97,7 +97,7 @@ def test_detector_with_view_not_found(): detector = SimilarityIndexDetector.from_path("sample_module.submodule:NotFoundView") with pytest.raises(SimilarityIndexDetectorException) as exc: detector.list_views() - assert exc.value.message == "View NotFoundView not found." + assert exc.value.message == "View NotFoundView not found in module sample_module.submodule." def test_detector_with_method_not_found(): @@ -107,7 +107,7 @@ def test_detector_with_method_not_found(): detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.not_found") with pytest.raises(SimilarityIndexDetectorException) as exc: detector.list_indexes() - assert exc.value.message == "Filter method not_found not found in view FooView." + assert exc.value.message == "Filter method not_found not found in view FooView." def test_detector_with_argument_not_found(): @@ -117,4 +117,4 @@ def test_detector_with_argument_not_found(): detector = SimilarityIndexDetector.from_path("sample_module.submodule:FooView.method_bar.not_found") with pytest.raises(SimilarityIndexDetectorException) as exc: detector.list_indexes() - assert exc.value.message == "Argument not_found not found in method FooView.method_bar." + assert exc.value.message == "Argument not_found not found in method method_bar." diff --git a/tests/unit/test_collection.py b/tests/unit/test_collection.py index 034f7250..e8b6639f 100644 --- a/tests/unit/test_collection.py +++ b/tests/unit/test_collection.py @@ -6,7 +6,7 @@ import pytest from typing_extensions import Annotated -from dbally._collection import Collection +from dbally.collection import Collection, IndexUpdateError from dbally.iql._exceptions import IQLError from dbally.utils.errors import NoViewFoundError from dbally.views.base import ExposedFunction, MethodParamWithTyping, ViewExecutionResult @@ -280,7 +280,7 @@ async def test_ask_feedback_loop(collection_feedback: Collection) -> None: ValueError("err3"), ValueError("err4"), ] - with patch("dbally._collection.IQLQuery.parse") as mock_iql_query: + with patch("dbally.collection.IQLQuery.parse") as mock_iql_query: mock_iql_query.side_effect = errors await collection_feedback.ask("Mock question") @@ -406,3 +406,34 @@ async def test_update_similarity_indexes( await collection.update_similarity_indexes() assert foo_index.update_count == 1 assert bar_index.update_count == 1 + + +async def test_update_similarity_indexes_error( + similarity_classes: Tuple[MockSimilarityIndex, MockSimilarityIndex, Type[MockViewBase], Type[MockViewBase]], + collection: Collection, +) -> None: + """ + Tests that the update_similarity_indexes method does not raise an exception when the update method of the similarity + indexes raises an exception + """ + ( + foo_index, + bar_index, + MockViewWithSimilarity, # pylint: disable=invalid-name + MockViewWithSimilarity2, # pylint: disable=invalid-name + ) = similarity_classes + collection.add(MockViewWithSimilarity) + collection.add(MockViewWithSimilarity2) + + foo_exception = ValueError("foo") + foo_index.update = AsyncMock(side_effect=foo_exception) # type: ignore + with pytest.raises(IndexUpdateError) as e: + await collection.update_similarity_indexes() + assert ( + str(e.value) == "Failed to update similarity indexes for MockViewWithSimilarity.test_filter.dog, " + "MockViewWithSimilarity2.test_filter.monkey" + ) + assert e.value.failed_indexes == { + foo_index: foo_exception, + } + assert bar_index.update_count == 1