Skip to content

Commit

Permalink
Refactoring/#231 rename train query handler (#236)
Browse files Browse the repository at this point in the history
* Renamed method SQLStage.create_train_query_handler() to create_query_handler()

* Renamed assert_stage_train_query_handler_created to assert_stage_query_handler_created

* Removed prefix "train_" from remaining attributes, mocks, and variables
  • Loading branch information
ckunki authored Dec 4, 2024
1 parent 09f1196 commit 43261bc
Show file tree
Hide file tree
Showing 12 changed files with 260 additions and 259 deletions.
3 changes: 2 additions & 1 deletion doc/changes/changes_0.2.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@ Code name:
* #217: Rename dataflow abstraction files
* #219: Applied PTB checks and fixes
* #221: Fixed mypy warnings
* #233: Upgraded pydantic to version 2
* #233: Upgraded pydantic to version 2
* #231: Renamed `TrainQueryHandler` to `SQLStageQueryHandler`
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage
from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
SQLStageTrainQueryHandlerInput,
SQLStageQueryHandlerInput,
)
from exasol.analytics.query_handler.query_handler import QueryHandler
from exasol.analytics.query_handler.result import Continue, Finish
Expand Down Expand Up @@ -135,11 +135,11 @@ def _create_current_query_handler(self):
result_bucketfs_location = self._result_bucketfs_location.joinpath(
str(self._current_stage_index)
)
stage_input = SQLStageTrainQueryHandlerInput(
stage_input = SQLStageQueryHandlerInput(
result_bucketfs_location=result_bucketfs_location,
sql_stage_inputs=stage_inputs,
)
self._current_query_handler = self._checked_current_stage.create_train_query_handler(
self._current_query_handler = self._checked_current_stage.create_query_handler(
stage_input, self._current_qh_context
)

Expand Down
6 changes: 3 additions & 3 deletions exasol/analytics/query_handler/graph/stage/sql/sql_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from exasol.analytics.query_handler.context.scope import ScopeQueryHandlerContext
from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
SQLStageQueryHandler,
SQLStageTrainQueryHandlerInput,
SQLStageQueryHandlerInput,
)
from exasol.analytics.query_handler.graph.stage.stage import Stage


class SQLStage(Stage):
@abc.abstractmethod
def create_train_query_handler(
def create_query_handler(
self,
stage_input: SQLStageTrainQueryHandlerInput,
stage_input: SQLStageQueryHandlerInput,
query_handler_context: ScopeQueryHandlerContext,
) -> SQLStageQueryHandler:
pass
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def is_empty(obj: Sized):


@dataclasses.dataclass(eq=True)
class SQLStageTrainQueryHandlerInput:
class SQLStageQueryHandlerInput:
sql_stage_inputs: List[SQLStageInputOutput]
result_bucketfs_location: AbstractBucketFSLocation

Expand All @@ -27,6 +27,6 @@ def __post_init__(self):


class SQLStageQueryHandler(
QueryHandler[SQLStageTrainQueryHandlerInput, SQLStageInputOutput], ABC
QueryHandler[SQLStageQueryHandlerInput, SQLStageInputOutput], ABC
):
pass
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
SQLStageInputOutput,
)
from exasol.analytics.query_handler.graph.stage.sql.sql_stage_query_handler import (
SQLStageTrainQueryHandlerInput,
SQLStageQueryHandlerInput,
)
from tests.utils.mock_cast import mock_cast
from tests.unit_tests.sql_stage_graph.stage_graph_execution_query_handler.state_test_setup import (
Expand Down Expand Up @@ -38,12 +38,12 @@ def assert_parent_query_handler_context_not_called(test_setup: TestSetup):

def assert_stage_not_called(test_setup: TestSetup, *, stage_index: int):
stage_setup = test_setup.stage_setups[stage_index]
mock_cast(stage_setup.stage.create_train_query_handler).assert_not_called()
assert stage_setup.train_query_handler.mock_calls == []
mock_cast(stage_setup.stage.create_query_handler).assert_not_called()
assert stage_setup.query_handler.mock_calls == []
assert stage_setup.child_query_handler_context.mock_calls == []


def assert_stage_train_query_handler_created(
def assert_stage_query_handler_created(
test_setup: TestSetup, *, stage_index: int, stage_inputs: List[SQLStageInputOutput]
):
stage_setup = test_setup.stage_setups[stage_index]
Expand All @@ -56,20 +56,20 @@ def assert_stage_train_query_handler_created(
result_bucketfs_location = test_setup.stage_setups[
stage_index
].result_bucketfs_location
stage_input = SQLStageTrainQueryHandlerInput(
stage_input = SQLStageQueryHandlerInput(
result_bucketfs_location=result_bucketfs_location, sql_stage_inputs=stage_inputs
)
mock_cast(stage_setup.stage.create_train_query_handler).assert_called_once_with(
mock_cast(stage_setup.stage.create_query_handler).assert_called_once_with(
stage_input, stage_setup.child_query_handler_context
)
assert stage_setup.train_query_handler.mock_calls == []
assert stage_setup.query_handler.mock_calls == []
assert stage_setup.child_query_handler_context.mock_calls == []


def assert_release_on_query_handler_context_for_stage(
test_setup: TestSetup, *, stage_index: int
):
stage_setup = test_setup.stage_setups[stage_index]
assert stage_setup.train_query_handler.mock_calls == []
assert stage_setup.query_handler.mock_calls == []
mock_cast(stage_setup.child_query_handler_context.release).assert_called_once()
mock_cast(stage_setup.stage.create_train_query_handler).assert_not_called()
mock_cast(stage_setup.stage.create_query_handler).assert_not_called()
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from tests.utils.mock_cast import mock_cast

MockScopeQueryHandlerContext = Union[ScopeQueryHandlerContext, MagicMock]
MockSQLStageTrainQueryHandler = Union[SQLStageQueryHandler, MagicMock]
MockSQLStageQueryHandler = Union[SQLStageQueryHandler, MagicMock]
MockQueryHandlerResult = Union[Continue, Finish, MagicMock]
MockSQLStageInputOutput = Union[SQLStageInputOutput, MagicMock]
MockSQLStage = Union[SQLStage, MagicMock]
Expand All @@ -47,14 +47,14 @@
class StageSetup:
index: int
child_query_handler_context: MockScopeQueryHandlerContext
train_query_handler: MockSQLStageTrainQueryHandler
query_handler: MockSQLStageQueryHandler
stage: MockSQLStage
results: List[MockQueryHandlerResult]
result_bucketfs_location: MockBucketFSLocation

def reset_mock(self):
self.child_query_handler_context.reset_mock()
self.train_query_handler.reset_mock()
self.query_handler.reset_mock()
self.stage.reset_mock()
self.result_bucketfs_location.reset_mock()
for result in self.results:
Expand Down Expand Up @@ -173,15 +173,15 @@ def create_mocks_for_stage(
result: List[MockQueryHandlerResult] = [
create_autospec(result_prototype) for result_prototype in result_prototypes
]
train_query_handler: MockSQLStageTrainQueryHandler = create_autospec(QueryHandler)
sql_stage.create_train_query_handler.return_value = train_query_handler
query_handler: MockSQLStageQueryHandler = create_autospec(QueryHandler)
sql_stage.create_query_handler.return_value = query_handler
mock_result_bucketfs_location: MockBucketFSLocation = create_autospec(
AbstractBucketFSLocation
)
return StageSetup(
index=stage_index,
child_query_handler_context=child_scoped_query_handler_context,
train_query_handler=train_query_handler,
query_handler=query_handler,
stage=sql_stage,
results=result,
result_bucketfs_location=mock_result_bucketfs_location,
Expand Down
Loading

0 comments on commit 43261bc

Please sign in to comment.