Skip to content

Commit

Permalink
Refactor database client and task classes for improved type handling …
Browse files Browse the repository at this point in the history
…and optional parameters

- Updated the `get_cyclic_data` method in `DatabaseClient` to allow `pagination` to be optional.
- Enhanced `SetupContext` to provide default values for `default_separator` and `default_locale` if not specified.
- Refined type hints in `KeyVariableTask` for `_generator` and improved handling of optional parameters in various methods.
- Adjusted `VariableTask` to ensure `statement.cyclic` defaults to `False` when not provided.

These changes enhance type safety and flexibility across the database client and task management system.
  • Loading branch information
ake2l committed Dec 23, 2024
1 parent d4ec496 commit 39c69a5
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 12 deletions.
2 changes: 1 addition & 1 deletion datamimic_ce/clients/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def get_by_page_with_type(self, table_name: str, pagination: DataSourcePaginatio
"""

@abstractmethod
def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination) -> list:
def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination | None) -> list:
"""
Get cyclic data from database
"""
Expand Down
4 changes: 2 additions & 2 deletions datamimic_ce/contexts/setup_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ def __init__(
self._max_count = 1000
self._validate = False
self._default_error_handler = None
self._default_separator = default_separator
self._default_locale = default_locale
self._default_separator = default_separator or ","
self._default_locale: str = default_locale or "en_US"
self._default_dataset = default_dataset
self._default_null = None
self._default_script = "py"
Expand Down
2 changes: 1 addition & 1 deletion datamimic_ce/tasks/key_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def execute(self, ctx: Context) -> None:
if hasattr(self._statement, "sub_statements"):
for stmt in self._statement.sub_statements:
task = task_util_cls.get_task_by_statement(root_ctx, stmt)
if isinstance(task, ElementTask):
if isinstance(task, ElementTask) and isinstance(ctx, GenIterContext):
attributes.update(task.generate_xml_attribute(ctx))
else:
raise ValueError(
Expand Down
13 changes: 7 additions & 6 deletions datamimic_ce/tasks/key_variable_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(

self._element_tag = "key" if isinstance(statement, KeyStatement) else "variable"
self._statement = statement
self._generator = None
self._generator: WeightedDataSource | None = None
self._pagination = pagination
self._converter_list = TaskUtil.create_converter_list(ctx, statement.converter)

Expand Down Expand Up @@ -112,8 +112,9 @@ def _determine_generation_mode(self, ctx: SetupContext):
if not source.endswith("wgt.csv"):
raise ValueError(f"Data source of attribute '{self._statement.name}' must be type of: 'wgt.csv'")
separator = self._statement.separator or ctx.default_separator
# TODO: mypy issue[assignment]
self._generator = WeightedDataSource(file_path=ctx.descriptor_dir / source, separator=separator)
self._generator: WeightedDataSource = WeightedDataSource(
file_path=ctx.descriptor_dir / source, separator=separator
)
self._mode = self._GENERATOR_MODE
elif self._statement.pattern is not None:
self._mode = self._PATTERN_MODE
Expand Down Expand Up @@ -142,8 +143,8 @@ def _generate_value(self, ctx: Context):

if self._mode == self._SCRIPT_MODE:
try:
# TODO: mypy issue `self._statement.script` maybe None
value = ctx.evaluate_python_expression(self._statement.script)
if self._statement.script is not None:
value = ctx.evaluate_python_expression(self._statement.script)
except Exception as e:
if self._statement.default_value is not None:
value = (
Expand Down Expand Up @@ -174,7 +175,7 @@ def _generate_value(self, ctx: Context):
# TODO: mypy issue: `self.statement.string` maybe None
value = TaskUtil.evaluate_variable_concat_prefix_suffix(
context=ctx,
expr=self.statement.string,
expr=self.statement.string or "",
prefix=self._prefix,
suffix=self._suffix,
)
Expand Down
3 changes: 1 addition & 2 deletions datamimic_ce/tasks/variable_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,9 @@ def __init__(
len_data = ctx.data_source_len.get(statement.full_name)
if len_data is None:
len_data = client.count_query_length(selector)
# TODO: mypy issue handle when cyclic is None, pagination is None
file_data = client.get_cyclic_data(
selector,
statement.cyclic,
statement.cyclic or False,
len_data,
pagination,
)
Expand Down

0 comments on commit 39c69a5

Please sign in to comment.