Skip to content

Commit

Permalink
#76: Added get_connection to query_handler_context (#77)
Browse files Browse the repository at this point in the history
  • Loading branch information
tkilias authored Oct 10, 2022
1 parent fbf9444 commit 2869dff
Show file tree
Hide file tree
Showing 16 changed files with 608 additions and 286 deletions.
Original file line number Diff line number Diff line change
@@ -1,9 +1,33 @@
from abc import abstractmethod
import enum
from abc import abstractmethod, ABC

from exasol_advanced_analytics_framework.query_handler.context.query_handler_context import QueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.context.proxy.object_proxy import ObjectProxy


class Connection(ABC):

@property
@abstractmethod
def name(self) -> str:
"""Name of the connection object"""

@property
@abstractmethod
def address(self) -> str:
"""Address of the connection object"""

@property
@abstractmethod
def user(self) -> str:
"""User of the connection object"""

@property
@abstractmethod
def password(self) -> str:
"""Password of the connection object"""


class ScopeQueryHandlerContext(QueryHandlerContext):
@abstractmethod
def release(self):
Expand Down Expand Up @@ -32,3 +56,6 @@ def transfer_object_to(self, object_proxy: ObjectProxy,
handler is always responsible for the transfer.
"""
pass

def get_connection(self, name: str) -> Connection:
pass
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import textwrap
import traceback
from abc import ABC
from typing import Set, List
from typing import Set, List, Callable

from exasol_bucketfs_utils_python.abstract_bucketfs_location import AbstractBucketFSLocation
from exasol_data_science_utils_python.schema.schema_name import SchemaName
Expand All @@ -17,7 +17,7 @@
from exasol_advanced_analytics_framework.query_handler.context.proxy.table_name_proxy import TableNameProxy
from exasol_advanced_analytics_framework.query_handler.context.proxy.view_name_proxy import ViewNameProxy
from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \
ScopeQueryHandlerContext
ScopeQueryHandlerContext, Connection
from exasol_advanced_analytics_framework.query_handler.query.query import Query


Expand Down Expand Up @@ -57,12 +57,17 @@ def get_all_not_released_contexts(self):
return result


ConnectionLookup = Callable[[str], Connection]


class _ScopeQueryHandlerContextBase(ScopeQueryHandlerContext, ABC):
def __init__(self,
temporary_bucketfs_location: AbstractBucketFSLocation,
temporary_db_object_name_prefix: str,
temporary_schema_name: str,
connection_lookup: ConnectionLookup,
global_temporary_object_counter: TemporaryObjectCounter):
self._connection_lookup = connection_lookup
self._global_temporary_object_counter = global_temporary_object_counter
self._temporary_schema_name = temporary_schema_name
self._temporary_bucketfs_location = temporary_bucketfs_location
Expand Down Expand Up @@ -149,6 +154,7 @@ def get_child_query_handler_context(self) -> ScopeQueryHandlerContext:
new_temporary_bucketfs_location,
self._get_temporary_db_object_name(),
self._temporary_schema_name,
self._connection_lookup,
self._global_temporary_object_counter
)
self._child_query_handler_context_list.append(child_query_handler_context)
Expand Down Expand Up @@ -216,16 +222,21 @@ def _release_object(self, object_proxy: ObjectProxy):
self._owned_object_proxies.remove(object_proxy)
self._invalid_object_proxies.add(object_proxy)

def get_connection(self, name: str) -> Connection:
return self._connection_lookup(name)


class TopLevelQueryHandlerContext(_ScopeQueryHandlerContextBase):
def __init__(self,
temporary_bucketfs_location: AbstractBucketFSLocation,
temporary_db_object_name_prefix: str,
temporary_schema_name: str,
connection_lookup: ConnectionLookup,
global_temporary_object_counter: TemporaryObjectCounter = TemporaryObjectCounter()):
super().__init__(temporary_bucketfs_location,
temporary_db_object_name_prefix,
temporary_schema_name,
connection_lookup,
global_temporary_object_counter)

def _release_object(self, object_proxy: ObjectProxy):
Expand Down Expand Up @@ -271,10 +282,12 @@ def __init__(self, parent: _ScopeQueryHandlerContextBase,
temporary_bucketfs_location: AbstractBucketFSLocation,
temporary_db_object_name_prefix: str,
temporary_schema_name: str,
connection_lookup: ConnectionLookup,
global_temporary_object_counter: TemporaryObjectCounter):
super().__init__(temporary_bucketfs_location,
temporary_db_object_name_prefix,
temporary_schema_name,
connection_lookup,
global_temporary_object_counter)
self.__parent = parent

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \
ScopeQueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.context.top_level_query_handler_context import \
TopLevelQueryHandlerContext
TopLevelQueryHandlerContext, ConnectionLookup
from exasol_advanced_analytics_framework.query_handler.query.query import Query
from exasol_advanced_analytics_framework.query_handler.query.select_query import SelectQueryWithColumnDefinition
from exasol_advanced_analytics_framework.query_handler.query_handler import QueryHandler
Expand All @@ -33,7 +33,8 @@ def __init__(self,
query_handler = query_handler_factory(parameter, top_level_query_handler_context)
self._state = QueryHandlerRunnerState(
top_level_query_handler_context=top_level_query_handler_context,
query_handler=query_handler
query_handler=query_handler,
connection_lookup=None
)

def run(self) -> ResultType:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,19 @@
from exasol_data_science_utils_python.schema.column import \
Column

from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import ScopeQueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.context.top_level_query_handler_context import TopLevelQueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import \
ScopeQueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.context.top_level_query_handler_context import \
TopLevelQueryHandlerContext
from exasol_advanced_analytics_framework.query_handler.query_handler \
import QueryHandler
from exasol_advanced_analytics_framework.udf_framework.udf_connection_lookup import UDFConnectionLookup


@dataclass()
class QueryHandlerRunnerState:
top_level_query_handler_context: TopLevelQueryHandlerContext
query_handler: QueryHandler
connection_lookup: UDFConnectionLookup
input_query_query_handler_context: Optional[ScopeQueryHandlerContext] = None
input_query_output_columns: Optional[List[Column]] = None
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import UDFQueryResult
from exasol_advanced_analytics_framework.udf_framework.query_handler_runner_state \
import QueryHandlerRunnerState
from exasol_advanced_analytics_framework.udf_framework.udf_connection_lookup import UDFConnectionLookup


@dataclasses.dataclass
Expand Down Expand Up @@ -187,21 +188,29 @@ def _create_state_or_load_latest_state(self) -> QueryHandlerRunnerState:
query_handler_state = self._create_state()
return query_handler_state

def _create_state(self):
context = TopLevelQueryHandlerContext(self.bucketfs_location,
self.parameter.temporary_name_prefix,
self.parameter.temporary_schema_name)
def _create_state(self) -> QueryHandlerRunnerState:
connection_lookup = UDFConnectionLookup(self.exa)
context = TopLevelQueryHandlerContext(
self.bucketfs_location,
self.parameter.temporary_name_prefix,
self.parameter.temporary_schema_name,
connection_lookup
)
module = importlib.import_module(self.parameter.python_class_module)
query_handler_factory_class = getattr(module, self.parameter.python_class_name)
query_handler_obj = query_handler_factory_class().create(self.parameter.parameters, context)
query_handler_state = QueryHandlerRunnerState(
top_level_query_handler_context=context,
query_handler=query_handler_obj)
query_handler=query_handler_obj,
connection_lookup=connection_lookup
)
return query_handler_state

def _load_latest_state(self):
def _load_latest_state(self) -> QueryHandlerRunnerState:
state_file_bucketfs_path = self._generate_state_file_bucketfs_path()
query_handler_state = self.bucketfs_location.read_file_from_bucketfs_via_joblib(str(state_file_bucketfs_path))
query_handler_state: QueryHandlerRunnerState = \
self.bucketfs_location.read_file_from_bucketfs_via_joblib(str(state_file_bucketfs_path))
query_handler_state.connection_lookup.exa = self.exa
return query_handler_state

def _save_current_state(self, current_state: QueryHandlerRunnerState) -> None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from exasol_advanced_analytics_framework.query_handler.context.scope_query_handler_context import Connection


class UDFConnection(Connection):

def __init__(self, name: str, udf_connection):
self._udf_connection = udf_connection
self._name = name

@property
def name(self) -> str:
return self._name

@property
def address(self) -> str:
return self._udf_connection.address

@property
def user(self) -> str:
return self._udf_connection.user

@property
def password(self) -> str:
return self._udf_connection.password


class UDFConnectionLookup:
def __init__(self, exa):
self.exa = exa

def __getstate__(self):
result = self.__dict__.copy()
del result["exa"]
return result

def __call__(self, name: str):
udf_connection = self.exa.get_connection(name)
return UDFConnection(name, udf_connection)
Loading

0 comments on commit 2869dff

Please sign in to comment.