Fixed type check errors
ckunki committed Nov 22, 2024
1 parent 8cf706c commit b2e61fe
Showing 16 changed files with 290 additions and 192 deletions.
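
Several of the hunks below import UninitializedAttributeError from exasol.analytics.utils.errors, a module whose contents are not part of the visible diff. A minimal sketch of what such an exception class could look like, assuming it is a plain Exception subclass (hypothetical, not taken from the repository):

# Hypothetical sketch of exasol/analytics/utils/errors.py; the real module is not shown in this diff.
class UninitializedAttributeError(Exception):
    """Raised when an attribute is read before it has been assigned a value."""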
@@ -1,8 +1,6 @@
import dataclasses

from exasol_bucketfs_utils_python.abstract_bucketfs_location import (
AbstractBucketFSLocation,
)
import exasol.bucketfs as bfs # type: ignore[import-untyped]

from exasol.analytics.query_handler.graph.stage.sql.input_output import (
SQLStageInputOutput,
@@ -13,5 +11,5 @@
@dataclasses.dataclass(frozen=True, eq=True)
class SQLStageGraphExecutionInput:
input: SQLStageInputOutput
result_bucketfs_location: AbstractBucketFSLocation # should this be bfs.path.PathLike?
result_bucketfs_location: bfs.path.PathLike
sql_stage_graph: SQLStageGraph
@@ -21,6 +21,7 @@
)
from exasol.analytics.query_handler.query_handler import QueryHandler
from exasol.analytics.query_handler.result import Continue, Finish
from exasol.analytics.utils.errors import UninitializedAttributeError


class ResultHandlerReturnValue(enum.Enum):
@@ -58,20 +59,33 @@ def __init__(
self._current_query_handler_context: Optional[ScopeQueryHandlerContext] = None
self._create_current_query_handler()

def _check_is_valid(self):
if self._current_query_handler is None:
raise RuntimeError("No current query handler set.")

def get_current_query_handler(
self,
self
) -> QueryHandler[List[SQLStageInputOutput], SQLStageInputOutput]:
self._check_is_valid()
return self._current_query_handler
value = self._current_query_handler
if value is None:
raise RuntimeError("No current query handler set.")
return value

@property
def current_query_handler_context(self) -> ScopeQueryHandlerContext:
value = self._current_query_handler_context
if value is None:
raise UninitializedAttributeError("current query handler context is None.")
return value

@property
def current_stage(self) -> SQLStage:
value = self._current_stage
if value is None:
raise UninitializedAttributeError("current stage is None.")
return value

def handle_result(
self, result: Union[Continue, Finish[SQLStageInputOutput]]
) -> ResultHandlerReturnValue:
self._check_is_valid()
# _check_is_valid()
self.get_current_query_handler()
if isinstance(result, Finish):
return self._handle_finished_result(result)
elif isinstance(result, Continue):
@@ -90,7 +104,7 @@ def _handle_finished_result(
return self._try_to_move_to_next_stage()

def _try_to_move_to_next_stage(self) -> ResultHandlerReturnValue:
self._current_query_handler_context.release()
self.current_query_handler_context.release()
if self._is_not_last_stage():
self._move_to_next_stage()
return ResultHandlerReturnValue.CONTINUE_PROCESSING
@@ -123,12 +137,12 @@ def _create_current_query_handler(self):
result_bucketfs_location=result_bucketfs_location,
sql_stage_inputs=stage_inputs,
)
self._current_query_handler = self._current_stage.create_train_query_handler(
self._current_query_handler = self.current_stage.create_train_query_handler(
stage_input, self._current_query_handler_context
)

def _add_result_to_successors(self, result: SQLStageInputOutput):
successors = self._sql_stage_graph.successors(self._current_stage)
successors = self._sql_stage_graph.successors(self.current_stage)
if len(successors) == 0:
raise RuntimeError("Programming error")
self._add_result_to_inputs_of_successors(result, successors)
@@ -146,7 +160,7 @@ def _add_result_to_reference_counting_bag(
object_proxies = find_object_proxies(result)
for object_proxy in object_proxies:
if object_proxy not in self._reference_counting_bag:
self._current_query_handler_context.transfer_object_to(
self.current_query_handler_context.transfer_object_to(
object_proxy, self._query_handler_context
)
for _ in successors:
@@ -160,7 +174,7 @@ def _transfer_ownership_of_result_to_query_result_handler(self, result):
object_proxy
)
else:
self._current_query_handler_context.transfer_object_to(
self.current_query_handler_context.transfer_object_to(
object_proxy, self._query_handler_context
)

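The hunks above replace direct reads of Optional attributes with guard properties that raise as soon as the attribute is still None. A minimal standalone sketch of that pattern, using invented names (Example, handler) rather than the project's classes:

from typing import Optional

class Example:
    def __init__(self) -> None:
        self._handler: Optional[str] = None

    @property
    def handler(self) -> str:
        # Copying into a local variable lets mypy narrow Optional[str] to str
        # after the None check, so the declared return type needs no Optional.
        value = self._handler
        if value is None:
            raise RuntimeError("handler is not initialized")  # the diff uses UninitializedAttributeError here
        return value

Callers then read self.handler instead of self._handler and either get a well-typed value or an immediate, descriptive error.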
5 changes: 5 additions & 0 deletions exasol/analytics/query_handler/python_query_handler_runner.py
@@ -17,6 +17,7 @@
from exasol.analytics.query_handler.result import Continue, Finish
from exasol.analytics.query_handler.udf.runner.state import QueryHandlerRunnerState
from exasol.analytics.sql_executor.interface import SQLExecutor
from exasol.analytics.utils.errors import UninitializedAttributeError

LOGGER = logging.getLogger(__file__)

@@ -116,6 +117,10 @@ def _release_and_create_query_handler_context_of_input_query(self):
def _wrap_return_query(
self, input_query: SelectQueryWithColumnDefinition
) -> Tuple[str, str]:
if self._state.input_query_query_handler_context is None:
raise UninitializedAttributeError(
"Attribute _state.input_query_query_handler_context is None"
)
temporary_view_name = (
self._state.input_query_query_handler_context.get_temporary_view_name()
)
@@ -51,7 +51,7 @@ def _range(self, num_rows: Union[int, str]) -> range:
if isinstance(num_rows, int):
return range(num_rows - 1)
if num_rows == "all":
return range(len(self.data) - 1)
return range(len(self._data) - 1)
raise ValueError(
f'num_rows must be an int or str "all" but is {num_rows}'
)
58 changes: 30 additions & 28 deletions exasol/analytics/query_handler/udf/runner/udf.py
@@ -1,9 +1,9 @@
import dataclasses
import importlib
import json
import logging
import traceback
from collections import OrderedDict
from dataclasses import dataclass, field
from enum import Enum, auto
from io import BytesIO
from typing import Any, List, Optional, Tuple
@@ -29,6 +29,7 @@
SchemaName,
UDFNameBuilder,
)
from exasol.analytics.utils.errors import UninitializedAttributeError


def create_bucketfs_location_from_conn_object(bfs_conn_obj) -> bfs.path.PathLike:
@@ -51,18 +52,17 @@ def read_via_joblib(location: bfs.path.PathLike) -> Any:
return joblib.load(buffer)


class IllegalStateException(Exception):
"""
A method of class QueryHandlerRunnerUDF accesses an attribute that has
not been initialized before, e.g. in method run().
"""

@dataclasses.dataclass
@dataclass
class UDFParameter:
iter_num: int
temporary_bfs_location_conn: str
temporary_bfs_location_directory: str
temporary_name_prefix: str
#
# The following attributes are only needed in the first iteration and
# hence need to be Optional, see method
# QueryHandlerRunnerUDF._get_parameter().
#
temporary_schema_name: Optional[str] = None
python_class_name: Optional[str] = None
python_class_module: Optional[str] = None
@@ -75,13 +75,13 @@ class QueryHandlerStatus(Enum):
ERROR = auto()


@dataclasses.dataclass
@dataclass
class UDFResult:
input_query_view: Optional[str] = None
input_query: Optional[str] = None
final_result: Any = {}
query_list: List[Query] = []
cleanup_query_list: List[Query] = []
final_result: str = "{}"
query_list: List[Query] = field(default_factory=list)
cleanup_query_list: List[Query] = field(default_factory=list)
status: QueryHandlerStatus = QueryHandlerStatus.CONTINUE
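
The UDFResult change above also fixes mutable defaults: dataclasses reject a bare list as a default value because it would be shared across instances, so field(default_factory=list) is required. A small generic illustration (Result is an invented name, not the project's UDFResult):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Result:
    # `items: List[str] = []` would raise ValueError at class-definition time;
    # default_factory builds a fresh list for every instance instead.
    items: List[str] = field(default_factory=list)

first, second = Result(), Result()
first.items.append("x")
assert second.items == []  # each instance owns its own list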


@@ -95,15 +95,13 @@ def __init__(self, exa):
@property
def bucketfs_location(self) -> bfs.path.PathLike:
if self._bucketfs_location is None:
raise IllegalStateException(
"attribute _bucketfs_location is not initialized"
)
raise UninitializedAttributeError("BucketFS location is undefined.")
return self._bucketfs_location

@property
def parameter(self) -> UDFParameter:
if self._parameter is None:
raise IllegalStateException("attribute _parameter is not initialized")
raise UninitializedAttributeError("Parameter is undefined.")
return self._parameter

def run(self, ctx) -> None:
@@ -134,12 +132,12 @@ def run(self, ctx) -> None:

def handle_exception(self, ctx, current_state: QueryHandlerRunnerState):
stacktrace = traceback.format_exc()
logging.exception("Catched exception, starting cleanup.")
logging.exception("Caught exception, starting cleanup.")
try:
self.release_query_handler_context(current_state)
except:
logging.exception(
"Catched exception during handling cleanup of another exception"
"Caught exception during handling cleanup of another exception"
)
cleanup_queries = (
current_state.top_level_query_handler_context.cleanup_released_object_proxies()
@@ -193,6 +191,8 @@ def handle_query_handler_result_continue(
query_handler_result.input_query.output_columns
)
self.release_and_create_query_handler_context_if_input_query(current_state)
if current_state.input_query_query_handler_context is None:
raise UninitializedAttributeError("Current state has no input query handler context.")
udf_result.input_query_view, udf_result.input_query = self._wrap_return_query(
current_state.input_query_query_handler_context,
query_handler_result.input_query,
@@ -229,14 +229,12 @@ def _get_parameter(self, ctx) -> UDFParameter:
temporary_name_prefix=ctx[3],
)

def _create_bucketfs_location(self):
bucketfs_connection_obj = self.exa.get_connection(
def _create_bucketfs_location(self) -> bfs.path.PathLike:
bfscon = self.exa.get_connection(
self.parameter.temporary_bfs_location_conn
)
bucketfs_location_from_con = create_bucketfs_location_from_conn_object(
bucketfs_connection_obj
)
self.bucketfs_location = bucketfs_location_from_con.joinpath(
bfs_location = create_bucketfs_location_from_conn_object(bfscon)
return bfs_location.joinpath(
self.parameter.temporary_bfs_location_directory
).joinpath(self.parameter.temporary_name_prefix)

@@ -247,6 +245,7 @@ def _create_state_or_load_latest_state(self) -> QueryHandlerRunnerState:
query_handler_state = self._create_state()
return query_handler_state

@property
def _query_handler_factory(self) -> UDFQueryHandlerFactory:
module_name = self.parameter.python_class_module
if not module_name:
@@ -260,22 +259,25 @@ def _query_handler_factory(self) -> UDFQueryHandlerFactory:
"UDFQueryHandler parameters must define a factory class"
)
factory = getattr(module, class_name)
if factory:
if not factory:
raise ValueError(
f'class "{class_name}" not found in module "{module_name}"'
)
return factory
return factory()

def _create_state(self) -> QueryHandlerRunnerState:
connection_lookup = UDFConnectionLookup(self.exa)
if self.parameter.temporary_schema_name is None:
raise UninitializedAttributeError("Temporary schema name is undefined.")
context = TopLevelQueryHandlerContext(
self.bucketfs_location,
self.parameter.temporary_name_prefix,
self.parameter.temporary_schema_name,
connection_lookup,
)
query_handler_obj = self._query_handler_factory().create(
self.parameter.parameter or "", context
str_parameter = self.parameter.parameter or ""
query_handler_obj = self._query_handler_factory.create(
str_parameter, context,
)
query_handler_state = QueryHandlerRunnerState(
top_level_query_handler_context=context,
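_query_handler_factory above is now a property that imports the configured module and instantiates the factory class. A generic sketch of that importlib-based lookup, with invented names (load_factory, my_pkg.handlers) used purely for illustration:

import importlib

def load_factory(module_name: str, class_name: str):
    # Import the module dynamically and fetch the factory class by name.
    module = importlib.import_module(module_name)
    factory = getattr(module, class_name, None)
    if factory is None:
        raise ValueError(f'class "{class_name}" not found in module "{module_name}"')
    return factory()  # instantiate, mirroring the `return factory()` change above

# handler_factory = load_factory("my_pkg.handlers", "MyQueryHandlerFactory")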
11 changes: 3 additions & 8 deletions exasol/analytics/schema/column_builder.py
@@ -1,4 +1,5 @@
from typing import Union
from typeguard import TypeCheckError

from exasol.analytics.schema.column import Column
from exasol.analytics.schema.column_name import ColumnName
@@ -11,12 +12,6 @@ def __init__(self, column: Union[Column, None] = None):
(None, None) if column is None
else (column.name, column.type)
)
# if column is not None:
# self._name = column.name
# self._type = column.type
# else:
# self._name = None
# self._type = None

def with_name(self, name: ColumnName) -> "ColumnBuilder":
self._name = name
@@ -28,8 +23,8 @@ def with_type(self, type: ColumnType) -> "ColumnBuilder":

def build(self) -> Column:
if self._name is None:
raise ValueError("name must not be None")
raise TypeCheckError("name must not be None")
if self._type is None:
raise ValueError("type must not be None")
raise TypeCheckError("type must not be None")
column = Column(self._name, self._type)
return column
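
The builder hunks in this and the following schema files swap ValueError for typeguard's TypeCheckError, so incomplete builder input surfaces as the same exception type that typeguard's runtime checks raise. A usage sketch (the try/except wrapper is illustrative, not code from the repository):

from typeguard import TypeCheckError

from exasol.analytics.schema.column_builder import ColumnBuilder

try:
    ColumnBuilder().build()  # no name was set, so build() raises
except TypeCheckError as exc:
    print(f"invalid column definition: {exc}")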
5 changes: 3 additions & 2 deletions exasol/analytics/schema/column_name_builder.py
@@ -1,5 +1,6 @@
from typing import Optional

from typeguard import TypeCheckError
from exasol.analytics.schema.column_name import ColumnName
from exasol.analytics.schema.table_like_name import TableLikeName

@@ -38,8 +39,8 @@ def with_table_like_name(

def build(self) -> ColumnName:
if self._name is None:
raise ValueError("name must not be None")
name = self.create(str(self._name), table_like_name=self._table_like_name)
raise TypeCheckError("name must not be None")
name = self.create(self._name, table_like_name=self._table_like_name)
return name

@staticmethod
5 changes: 3 additions & 2 deletions exasol/analytics/schema/table_builder.py
@@ -3,6 +3,7 @@
from exasol.analytics.schema.column import Column
from exasol.analytics.schema.table import Table
from exasol.analytics.schema.table_name import TableName
from typeguard import TypeCheckError


class TableBuilder:
@@ -22,8 +23,8 @@ def with_columns(self, columns: List[Column]) -> "TableBuilder":

def build(self) -> Table:
if self._name is None:
raise ValueError("name must not be None")
raise TypeCheckError("name must not be None")
if not self._columns:
raise ValueError("there must be at least one column")
raise TypeCheckError("there must be at least one column")
table = Table(self._name, self._columns)
return table
3 changes: 2 additions & 1 deletion exasol/analytics/schema/table_name_builder.py
@@ -3,6 +3,7 @@
from exasol.analytics.schema.schema_name import SchemaName
from exasol.analytics.schema.table_name import TableName
from exasol.analytics.schema.table_name_impl import TableNameImpl
from typeguard import TypeCheckError


class TableNameBuilder:
@@ -39,7 +40,7 @@ def with_schema_name(self, schema_name: SchemaName) -> "TableNameBuilder":

def build(self) -> TableName:
if self._name is None:
raise ValueError("name must not be None")
raise TypeCheckError("name must not be None")
return self.create(self._name, self._schema_name)

@staticmethod
3 changes: 2 additions & 1 deletion exasol/analytics/schema/udf_name_builder.py
@@ -3,6 +3,7 @@
from exasol.analytics.schema.schema_name import SchemaName
from exasol.analytics.schema.udf_name import UDFName
from exasol.analytics.schema.udf_name_impl import UDFNameImpl
from typeguard import TypeCheckError


class UDFNameBuilder:
@@ -39,7 +40,7 @@ def with_schema_name(self, schema_name: SchemaName) -> "UDFNameBuilder":

def build(self) -> UDFName:
if self._name is None:
raise ValueError("name must not be None")
raise TypeCheckError("name must not be None")
return self.create(self._name, self._schema_name)

@staticmethod