diff --git a/datamimic_ce/clients/database_client.py b/datamimic_ce/clients/database_client.py index 7bad57a..ed8cea2 100644 --- a/datamimic_ce/clients/database_client.py +++ b/datamimic_ce/clients/database_client.py @@ -31,7 +31,7 @@ def get_by_page_with_type(self, table_name: str, pagination: DataSourcePaginatio """ @abstractmethod - def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination) -> list: + def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination | None) -> list: """ Get cyclic data from database """ diff --git a/datamimic_ce/contexts/setup_context.py b/datamimic_ce/contexts/setup_context.py index 0c0ca75..1a4429b 100644 --- a/datamimic_ce/contexts/setup_context.py +++ b/datamimic_ce/contexts/setup_context.py @@ -67,8 +67,8 @@ def __init__( self._max_count = 1000 self._validate = False self._default_error_handler = None - self._default_separator = default_separator - self._default_locale = default_locale + self._default_separator = default_separator or "," + self._default_locale: str = default_locale or "en_US" self._default_dataset = default_dataset self._default_null = None self._default_script = "py" diff --git a/datamimic_ce/tasks/key_task.py b/datamimic_ce/tasks/key_task.py index 3bd762a..609cc2f 100644 --- a/datamimic_ce/tasks/key_task.py +++ b/datamimic_ce/tasks/key_task.py @@ -81,7 +81,7 @@ def execute(self, ctx: Context) -> None: if hasattr(self._statement, "sub_statements"): for stmt in self._statement.sub_statements: task = task_util_cls.get_task_by_statement(root_ctx, stmt) - if isinstance(task, ElementTask): + if isinstance(task, ElementTask) and isinstance(ctx, GenIterContext): attributes.update(task.generate_xml_attribute(ctx)) else: raise ValueError( diff --git a/datamimic_ce/tasks/key_variable_task.py b/datamimic_ce/tasks/key_variable_task.py index e749937..5188c51 100644 --- a/datamimic_ce/tasks/key_variable_task.py +++ b/datamimic_ce/tasks/key_variable_task.py @@ -51,7 +51,7 @@ def __init__( self._element_tag = "key" if isinstance(statement, KeyStatement) else "variable" self._statement = statement - self._generator = None + self._generator: WeightedDataSource | None = None self._pagination = pagination self._converter_list = TaskUtil.create_converter_list(ctx, statement.converter) @@ -112,8 +112,9 @@ def _determine_generation_mode(self, ctx: SetupContext): if not source.endswith("wgt.csv"): raise ValueError(f"Data source of attribute '{self._statement.name}' must be type of: 'wgt.csv'") separator = self._statement.separator or ctx.default_separator - # TODO: mypy issue[assignment] - self._generator = WeightedDataSource(file_path=ctx.descriptor_dir / source, separator=separator) + self._generator: WeightedDataSource = WeightedDataSource( + file_path=ctx.descriptor_dir / source, separator=separator + ) self._mode = self._GENERATOR_MODE elif self._statement.pattern is not None: self._mode = self._PATTERN_MODE @@ -142,8 +143,8 @@ def _generate_value(self, ctx: Context): if self._mode == self._SCRIPT_MODE: try: - # TODO: mypy issue `self._statement.script` maybe None - value = ctx.evaluate_python_expression(self._statement.script) + if self._statement.script is not None: + value = ctx.evaluate_python_expression(self._statement.script) except Exception as e: if self._statement.default_value is not None: value = ( @@ -174,7 +175,7 @@ def _generate_value(self, ctx: Context): # TODO: mypy issue: `self.statement.string` maybe None value = TaskUtil.evaluate_variable_concat_prefix_suffix( context=ctx, - expr=self.statement.string, + expr=self.statement.string or "", prefix=self._prefix, suffix=self._suffix, ) diff --git a/datamimic_ce/tasks/variable_task.py b/datamimic_ce/tasks/variable_task.py index 8881105..59f835d 100644 --- a/datamimic_ce/tasks/variable_task.py +++ b/datamimic_ce/tasks/variable_task.py @@ -130,10 +130,9 @@ def __init__( len_data = ctx.data_source_len.get(statement.full_name) if len_data is None: len_data = client.count_query_length(selector) - # TODO: mypy issue handle when cyclic is None, pagination is None file_data = client.get_cyclic_data( selector, - statement.cyclic, + statement.cyclic or False, len_data, pagination, )