Feature/#232 redesign sql stage input output (#237)
* #232: Redesigned SQLStageInputOutput
* Added docstring to class Dataset
* Changed the key type of dict `Dependencies` to `object`
ckunki authored Dec 5, 2024
1 parent 43261bc commit 223a124
Showing 9 changed files with 76 additions and 354 deletions.
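
At a glance, the redesign replaces the single `dataset` attribute of `SQLStageInputOutput` (which bundled several `DataPartition`s) with a dict of named `Dataset`s on the new subclass `MultiDatasetSQLStageInputOutput`. A minimal before/after sketch, assuming `table_like` and the `Column` values are built as in the tests below; the enum keys `PartitionName` and `DatasetName` are illustrative stand-ins for the test's `TestDatasetPartitionName`/`TestDatasetName`:

    # Before (removed in this commit): partitions inside one Dataset
    dataset = Dataset(
        data_partitions={PartitionName.TRAIN: DataPartition(table_like=table_like)},
        sample_columns=[sample_column],
        target_columns=[target_column],
        identifier_columns=[identifier_column],
    )
    stage_io = SQLStageInputOutput(dataset=dataset)

    # After: one Dataset per table, stored under a user-defined key
    dataset = Dataset(
        table_like=table_like,
        columns=[sample_column, target_column],
        identifier_columns=[identifier_column],
    )
    stage_io = MultiDatasetSQLStageInputOutput(datasets={DatasetName.TRAIN: dataset})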
doc/changes/changes_0.2.0.md (3 changes: 2 additions & 1 deletion)
@@ -1,4 +1,4 @@
-# 0.2.0 - 2024-12-03
+# 0.2.0 - 2024-12-06
 
 Code name:
 
@@ -22,3 +22,4 @@ Code name:
 * #221: Fixed mypy warnings
 * #233: Upgraded pydantic to version 2
 * #231: Renamed `TrainQueryHandler` to `SQLStageQueryHandler`
+* #232: Redesigned `SQLStageInputOutput`
exasol/analytics/query_handler/graph/stage/sql/data_partition.py (26 changes: 0 additions & 26 deletions)

This file was deleted.

exasol/analytics/query_handler/graph/stage/sql/dataset.py (62 changes: 13 additions & 49 deletions)
@@ -1,56 +1,20 @@
-import dataclasses
-from enum import Enum
-from typing import Dict, List, Tuple, Union
+from dataclasses import dataclass, field
+from typing import List
 
-from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition
-from exasol.analytics.schema import Column
-from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types
+from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies
+from exasol.analytics.schema import Column, TableLike
 
-DataPartitionName = Union[Enum, Tuple[Enum, int]]
-
-
-@dataclasses.dataclass(frozen=True)
+@dataclass(frozen=True)
 class Dataset:
     """
-    A Dataset consists of multiple data partitions and column lists which indicate the identifier,
-    sample and target columns, The data paritions can be used to describe train and test sets.
-    """
-    data_partitions: Dict[DataPartitionName, DataPartition]
+    A Dataset consists of a TableLike, column lists indicating the
+    identifier and other columns, and optional dependencies.
+
+    The TableLike refers to a database table containing the actual data that
+    can be used for instance in training or testing.
+    """
+    table_like: TableLike
     identifier_columns: List[Column]
-    sample_columns: List[Column]
-    target_columns: List[Column]
-
-    def __post_init__(self):
-        check_dataclass_types(self)
-        self._check_table_name()
-        self._check_columns()
-
-    def _check_table_name(self):
-        all_table_like_names = {
-            data_partition.table_like.name
-            for data_partition in self.data_partitions.values()
-        }
-        if len(all_table_like_names) != len(self.data_partitions):
-            raise ValueError(
-                "The names of table likes of the data partitions should be different."
-            )
-
-    def _check_columns(self):
-        all_columns = {
-            column
-            for data_partition in self.data_partitions.values()
-            for column in data_partition.table_like.columns
-        }
-        all_data_partition_have_same_columns = all(
-            len(data_partition.table_like.columns) == len(all_columns)
-            for data_partition in self.data_partitions.values()
-        )
-        if not all_data_partition_have_same_columns:
-            raise ValueError("Not all data partitions have the same columns.")
-        if not all_columns.issuperset(self.sample_columns):
-            raise ValueError("Not all sample columns in data partitions.")
-        if not all_columns.issuperset(self.target_columns):
-            raise ValueError("Not all target columns in data partitions.")
-        if not all_columns.issuperset(self.identifier_columns):
-            raise ValueError("Not all identifier columns in data partitions.")
+    columns: List[Column]
+    dependencies: Dependencies = field(default_factory=dict)
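
A minimal usage sketch of the redesigned `Dataset`. The `Column` instances and the `TableLike` are assumed to be created via the builders in `exasol.analytics.schema` (e.g. `TableBuilder`, as in the test change below); `dependencies` may be omitted thanks to its default:

    from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset

    dataset = Dataset(
        table_like=table_like,                   # table containing the actual data
        identifier_columns=[identifier_column],  # columns identifying a row
        columns=[sample_column, target_column],  # remaining payload columns
    )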
exasol/analytics/query_handler/graph/stage/sql/dependency.py (14 changes: 8 additions & 6 deletions)
@@ -11,16 +11,18 @@
 @dataclasses.dataclass(frozen=True)
 class Dependency:
     """
-    This class represents that a object depends on something which in fact can depend on something else.
-    That exactly this means is user defined.
-    For example, this could represent that a view depends on a certain table.
+    An instance of this class represents a node in a dependency graph.
+    The exact meaning of a dependency is user-defined. For example, a
+    dependency could express that a database view depends on a particular
+    table.
     """
 
     object: Any
     dependencies: Dict[Enum, "Dependency"] = dataclasses.field(default_factory=dict)
     """
-    Dependency can have their own dependencies. For example, a view which depends on another view
-    which in fact then consists of table.
+    Each dependency can again have subsequent dependencies. For example, a
+    view can depend on another view which in fact then consists of table.
     """
 
     def __post_init__(self):
@@ -34,4 +36,4 @@ def __post_init__(self):
             raise TypeCheckError(f"Field 'dependencies' has wrong type: {e}")
 
 
-Dependencies = Dict[Enum, Dependency]
+Dependencies = Dict[object, Dependency]
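
Note that only the top-level `Dependencies` alias is relaxed to `object` keys; the nested `dependencies` field of `Dependency` itself keeps its `Dict[Enum, "Dependency"]` type and runtime check. A small sketch under that reading (the enum name and the string values are illustrative):

    import enum

    from exasol.analytics.query_handler.graph.stage.sql.dependency import (
        Dependencies,
        Dependency,
    )

    class SubDependencyName(enum.Enum):
        BASE_TABLE = enum.auto()

    view_dependency = Dependency(
        object="MY_SCHEMA.MY_VIEW",  # `object` is typed Any
        dependencies={
            SubDependencyName.BASE_TABLE: Dependency(object="MY_SCHEMA.MY_TABLE"),
        },
    )
    # After this commit, top-level keys are no longer restricted to Enum members:
    dependencies: Dependencies = {"view": view_dependency}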
exasol/analytics/query_handler/graph/stage/sql/input_output.py (18 changes: 8 additions & 10 deletions)
@@ -1,11 +1,11 @@
-import dataclasses
+from dataclasses import dataclass, field
+from typing import Dict, Protocol
 
 from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset
 from exasol.analytics.query_handler.graph.stage.sql.dependency import Dependencies
-from exasol.analytics.utils.data_classes_runtime_type_check import check_dataclass_types
 
 
-@dataclasses.dataclass(frozen=True, eq=True)
+@dataclass(frozen=True)
 class SQLStageInputOutput:
     """
     A SQLStageInputOutput is used as input and output between the SQLStageQueryHandler.
@@ -14,12 +14,10 @@ class SQLStageInputOutput:
     For example, a dependency could be a table which the previous stage computed and
     the subsequent one uses.
     """
-
-    dataset: Dataset
-    dependencies: Dependencies = dataclasses.field(default_factory=dict)
-    """
-    This contains user-defined dependencies which the previous stage wants to communicate to the subsequent stage.
-    """
-
-    def __post_init__(self):
-        check_dataclass_types(self)
+    pass
+
+
+@dataclass(frozen=True)
+class MultiDatasetSQLStageInputOutput(SQLStageInputOutput):
+    datasets: Dict[object, Dataset]
+    dependencies: Dependencies = field(default_factory=dict)
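
`SQLStageInputOutput` is now an empty base class, so stages exchange concrete data via subclasses such as `MultiDatasetSQLStageInputOutput`. A minimal sketch, assuming `dataset` is a `Dataset` as constructed above and using an enum key in the style of the tests:

    import enum

    class DatasetName(enum.Enum):
        TRAIN = enum.auto()

    stage_io = MultiDatasetSQLStageInputOutput(
        datasets={DatasetName.TRAIN: dataset},
    )
    # A subsequent stage looks its input up by key:
    table_like = stage_io.datasets[DatasetName.TRAIN].table_like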
@@ -14,7 +14,6 @@
 from exasol.analytics.query_handler.context.top_level_query_handler_context import (
     TopLevelQueryHandlerContext,
 )
-from exasol.analytics.query_handler.graph.stage.sql.data_partition import DataPartition
 from exasol.analytics.query_handler.graph.stage.sql.dataset import Dataset
 from exasol.analytics.query_handler.graph.stage.sql.execution.input import (
     SQLStageGraphExecutionInput,
@@ -24,6 +23,7 @@
 )
 from exasol.analytics.query_handler.graph.stage.sql.input_output import (
     SQLStageInputOutput,
+    MultiDatasetSQLStageInputOutput,
 )
 from exasol.analytics.query_handler.graph.stage.sql.sql_stage import SQLStage
 from exasol.analytics.query_handler.graph.stage.sql.sql_stage_graph import SQLStageGraph
@@ -77,10 +77,8 @@ def __init__(
         self.input_table_like_name: Optional[str] = None
 
     def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
-        dataset = self._parameter.sql_stage_inputs[0].dataset
-        input_table_like = dataset.data_partitions[
-            TestDatasetPartitionName.TRAIN
-        ].table_like
+        datasets = self._parameter.sql_stage_inputs[0].datasets
+        input_table_like = datasets[TestDatasetName.TRAIN].table_like
         # This tests also, if temporary table names are still valid
         self.input_table_like_name = input_table_like.name.fully_qualified
 
@@ -110,8 +108,8 @@ def __init__(
         self.query_result: Optional[QueryResult] = None
 
     def start(self) -> Union[Continue, Finish[SQLStageInputOutput]]:
-        dataset = self._parameter.sql_stage_inputs[0].dataset
-        table_like = dataset.data_partitions[TestDatasetPartitionName.TRAIN].table_like
+        datasets = self._parameter.sql_stage_inputs[0].datasets
+        table_like = datasets[TestDatasetName.TRAIN].table_like
         table_like_name = table_like.name
         table_like_columns = table_like.columns
         select_query_with_column_definition = SelectQueryWithColumnDefinition(
@@ -166,7 +164,7 @@ def __hash__(self):
         return self._index
 
 
-class TestDatasetPartitionName(enum.Enum):
+class TestDatasetName(enum.Enum):
     __test__ = False
     TRAIN = enum.auto()
 
@@ -202,16 +200,13 @@ def create_stage_input_output(table_name: TableName):
         target_column,
     ]
     table_like = TableBuilder().with_name(table_name).with_columns(columns).build()
-    data_partition = DataPartition(
-        table_like=table_like,
-    )
     dataset = Dataset(
-        data_partitions={TestDatasetPartitionName.TRAIN: data_partition},
-        target_columns=[target_column],
-        sample_columns=[sample_column],
+        table_like=table_like,
+        columns=[target_column, sample_column],
         identifier_columns=[identifier_column],
     )
-    stage_input_output = SQLStageInputOutput(dataset=dataset)
+    datasets = {TestDatasetName.TRAIN: dataset}
+    stage_input_output = MultiDatasetSQLStageInputOutput(datasets=datasets)
     return stage_input_output
 
 
@@ -289,7 +284,7 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     assert (
         isinstance(result, Finish)
         and isinstance(result.result, SQLStageInputOutput)
-        and result.result.dataset == test_setup.stage_input_output.dataset
+        and result.result.datasets == test_setup.stage_input_output.datasets
        and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
         == 0
     )
@@ -345,7 +340,7 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     assert (
         isinstance(result, Finish)
         and isinstance(result.result, SQLStageInputOutput)
-        and result.result.dataset == test_setup.stage_input_output.dataset
+        and result.result.datasets == test_setup.stage_input_output.datasets
         and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
         == 0
     )
@@ -406,22 +401,22 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     assert (
         isinstance(result, Finish)
        and isinstance(result.result, SQLStageInputOutput)
-        and result.result.dataset != test_setup.stage_input_output.dataset
+        and result.result.datasets != test_setup.stage_input_output.datasets
         and isinstance(
             stage_1_query_handler,
             StartOnlyCreateNewOutputTestSQLStageQueryHandler,
         )
-        and result.result.dataset
-        == stage_1_query_handler.stage_input_output.dataset
+        and result.result.datasets
+        == stage_1_query_handler.stage_input_output.datasets
         and stage_1_query_handler.input_table_like_name is not None
         and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
         == 0
     )
 
     if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput):
         with not_raises(Exception):
-            name = result.result.dataset.data_partitions[
-                TestDatasetPartitionName.TRAIN
+            name = result.result.datasets[
+                TestDatasetName.TRAIN
             ].table_like.name
 
     test_setup.child_query_handler_context.release()
@@ -478,7 +473,7 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
     assert (
         isinstance(result, Finish)
         and isinstance(result.result, SQLStageInputOutput)
-        and result.result.dataset != test_setup.stage_input_output.dataset
+        and result.result.datasets != test_setup.stage_input_output.datasets
         and isinstance(
             stage_1_query_handler,
             StartOnlyCreateNewOutputTestSQLStageQueryHandler,
@@ -487,8 +482,8 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
             stage_2_query_handler,
             StartOnlyCreateNewOutputTestSQLStageQueryHandler,
         )
-        and result.result.dataset
-        == stage_2_query_handler.stage_input_output.dataset
+        and result.result.datasets
+        == stage_2_query_handler.stage_input_output.datasets
         and stage_1_query_handler.input_table_like_name is not None
         and stage_2_query_handler.input_table_like_name is not None
         and len(top_level_query_handler_context_mock.cleanup_released_object_proxies())
@@ -497,9 +492,7 @@ def act(test_setup: TestSetup) -> Union[Continue, Finish[SQLStageInputOutput]]:
 
     if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput):
         with not_raises(Exception):
-            name = result.result.dataset.data_partitions[
-                TestDatasetPartitionName.TRAIN
-            ].table_like.name
+            name = result.result.datasets[TestDatasetName.TRAIN].table_like.name
 
     test_setup.child_query_handler_context.release()
     assert (
@@ -611,9 +604,7 @@ def act(
 
     if isinstance(result, Finish) and isinstance(result.result, SQLStageInputOutput):
         with not_raises(Exception):
-            name = result.result.dataset.data_partitions[
-                TestDatasetPartitionName.TRAIN
-            ].table_like.name
+            name = result.result.datasets[TestDatasetName.TRAIN].table_like.name
 
     test_setup.child_query_handler_context.release()
     assert (
tests/unit_tests/sql_stage_graph/test_data_partition.py (59 changes: 0 additions & 59 deletions)

This file was deleted.
