fix variable_task, nested_key_task

tunglxfast committed Dec 13, 2024
1 parent 65bd272 commit 26af513

Showing 10 changed files with 92 additions and 80 deletions.
datamimic_ce/clients/database_client.py (4 changes: 3 additions & 1 deletion)

@@ -31,7 +31,9 @@ def get_by_page_with_type(self, table_name: str, pagination: DataSourcePagination
         """
 
     @abstractmethod
-    def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination) -> list:
+    def get_cyclic_data(
+        self, query: str, data_len: int, pagination: DataSourcePagination, cyclic: bool | None = False
+    ) -> list:
         """
         Get cyclic data from database
         """
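The same reordering recurs in the two concrete clients that follow: `cyclic` moves behind the required parameters and gains a default of `False` (hinted as `bool | None` on the abstract base, plain `bool` in the implementations). A rough sketch of the effect at a call site, assuming a hypothetical `client` implementing DatabaseClient and a `page` DataSourcePagination instance (illustrative, not repo code):

# `cyclic` can now be omitted and falls back to False.
rows = client.get_cyclic_data("SELECT * FROM users", data_len=10, pagination=page)
# Opt in to wrap-around behaviour explicitly:
looped = client.get_cyclic_data("SELECT * FROM users", data_len=10, pagination=page, cyclic=True)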
datamimic_ce/clients/mongodb_client.py (4 changes: 3 additions & 1 deletion)

@@ -35,7 +35,9 @@ def _create_connection(self) -> MongoClient:
 
         return MongoClient(**ars)
 
-    def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination) -> list:
+    def get_cyclic_data(
+        self, query: str, data_len: int, pagination: DataSourcePagination, cyclic: bool = False
+    ) -> list:
         """
         Get cyclic data from query
         """
datamimic_ce/clients/rdbms_client.py (4 changes: 3 additions & 1 deletion)

@@ -375,7 +375,9 @@ def insert(self, table_name: str, data_list: list):
         except Exception as err:
             raise RuntimeError(f"Error when write data to RDBMS. Message: {err}") from err
 
-    def get_cyclic_data(self, query: str, cyclic: bool, data_len: int, pagination: DataSourcePagination) -> list:
+    def get_cyclic_data(
+        self, query: str, data_len: int, pagination: DataSourcePagination, cyclic: bool = False
+    ) -> list:
         """
         Get cyclic data from relational database
         """
datamimic_ce/exporters/xml_exporter.py (1 change: 1 addition & 0 deletions)

@@ -123,6 +123,7 @@ def _sanitize_record(data: dict[str, Any]) -> dict[str, Any]:
         Returns:
             dict: Sanitized data with string values and attribute prefixes.
         """
+
         def sanitize_value(recursion_data: Any) -> Any:
             if isinstance(recursion_data, dict):
                 sanitized = {}
datamimic_ce/tasks/generate_task.py (10 changes: 6 additions & 4 deletions)

@@ -16,7 +16,7 @@
 from collections.abc import Callable
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Literal, Dict
+from typing import Literal, Dict, Optional
 
 import dill
 import xmltodict

@@ -407,12 +407,12 @@ def _load_csv_file(
     ctx: SetupContext,
     file_path: Path,
     separator: str,
-    cyclic: bool,
     start_idx: int,
     end_idx: int,
     source_scripted: bool,
     prefix: str,
     suffix: str,
+    cyclic: Optional[bool] = False,
 ) -> list[dict]:
     """
     Load CSV content from file with skip and limit.

@@ -442,7 +442,9 @@ def _load_csv_file(
         return result
 
 
-def _load_json_file(task_id: str, file_path: Path, cyclic: bool, start_idx: int, end_idx: int) -> list[dict]:
+def _load_json_file(
+    task_id: str, file_path: Path, start_idx: int, end_idx: int, cyclic: Optional[bool] = False
+) -> list[dict]:
     """
     Load JSON content from file using skip and limit.

@@ -476,7 +478,7 @@ def _load_json_file(task_id: str, file_path: Path, cyclic: bool, start_idx: int,
     return DataSourceUtil.get_cyclic_data_list(data=data, cyclic=cyclic, pagination=pagination)
 
 
-def _load_xml_file(file_path: Path, cyclic: bool, start_idx: int, end_idx: int) -> list[dict]:
+def _load_xml_file(file_path: Path, start_idx: int, end_idx: int, cyclic: Optional[bool] = False) -> list[dict]:
     """
     Load XML content from file using skip and limit.
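The three file loaders get the same treatment as the clients: `cyclic` becomes a trailing `Optional[bool]` defaulting to `False`, so callers that never cycle can drop the argument. A hedged sketch of the new call shape (paths and indices are made up):

from pathlib import Path

# Illustrative calls only; cyclic defaults to False in the first one.
records = _load_json_file("task-1", Path("data/users.json"), start_idx=0, end_idx=50)
wrapped = _load_xml_file(Path("data/users.xml"), start_idx=0, end_idx=50, cyclic=True)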
datamimic_ce/tasks/key_task.py (26 changes: 15 additions & 11 deletions)

@@ -45,15 +45,18 @@ def __init__(
     def pre_execute(self, ctx: Context):
         if self.statement.generator is not None and "SequenceTableGenerator" in self.statement.generator:
             sequence_table_generator = GeneratorUtil(ctx).create_generator(
-                self._statement.generator, self._statement, self._pagination
+                str(self._statement.generator), self._statement, self._pagination
             )
             sequence_table_generator.pre_execute(ctx)
 
     @property
     def statement(self) -> KeyStatement:
-        return self._statement
+        if isinstance(self._statement, KeyStatement):
+            return self._statement
+        else:
+            raise TypeError("Expected an KeyStatement")
 
-    def execute(self, ctx: GenIterContext):
+    def execute(self, ctx: GenIterContext):  # type: ignore
         """
         Generate data for element "attribute"
         If 'type' element is not specified, then default type of generated data is string

@@ -77,14 +80,15 @@ def execute(self, ctx: GenIterContext):
             value = self._convert_generated_value(value)
 
         attributes = {}
-        for stmt in self._statement.sub_statements:
-            task = task_util_cls.get_task_by_statement(root_ctx, stmt)
-            if isinstance(task, ElementTask):
-                attributes.update(task.generate_xml_attribute(ctx))
-            else:
-                raise ValueError(
-                    f"Cannot execute subtask {task.__class__.__name__} of <key> '{self.statement.name}'"
-                )
+        if hasattr(self._statement, "sub_statements"):
+            for stmt in self._statement.sub_statements:
+                task = task_util_cls.get_task_by_statement(root_ctx, stmt)
+                if isinstance(task, ElementTask):
+                    attributes.update(task.generate_xml_attribute(ctx))
+                else:
+                    raise ValueError(
+                        f"Cannot execute subtask {task.__class__.__name__} of <key> '{self.statement.name}'"
+                    )
 
         result = value if len(attributes) == 0 else {"#text": value, **attributes}
         # Add field "attribute" into current product
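The reworked `statement` property is a standard type-narrowing pattern: the attribute is stored under a broader base type, and `isinstance` both guards at runtime and lets mypy treat the return value as `KeyStatement`. A stripped-down sketch of the pattern with stand-in classes (not repo code):

class Statement: ...
class KeyStatement(Statement): ...

class Holder:
    def __init__(self, statement: Statement) -> None:
        self._statement = statement  # typed as the broad base class

    @property
    def statement(self) -> KeyStatement:
        # isinstance() narrows Statement to KeyStatement for the checker
        # and fails loudly if the wrong statement type was injected.
        if isinstance(self._statement, KeyStatement):
            return self._statement
        raise TypeError("Expected a KeyStatement")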
datamimic_ce/tasks/nested_key_task.py (54 changes: 28 additions & 26 deletions)

@@ -32,14 +32,14 @@ def __init__(
         self._default_value = statement.default_value
         self._descriptor_dir = ctx.root.descriptor_dir
         self._class_factory_util = class_factory_util
-        self._sub_tasks = None
+        self._sub_tasks: list | None = None
         self._converter_list = class_factory_util.get_task_util_cls().create_converter_list(ctx, statement.converter)
 
     @property
     def statement(self) -> NestedKeyStatement:
         return self._statement
 
-    def execute(self, parent_context: GenIterContext):
+    def execute(self, parent_context: GenIterContext):  # type: ignore
         """
         Generate data for element "nestedKey"
         :param parent_context:

@@ -97,18 +97,19 @@ def _execute_generate(self, parent_context: GenIterContext) -> None:
         nestedkey_type = self._statement.type
         if nestedkey_type == DATA_TYPE_LIST:
             nestedkey_len = self._determine_nestedkey_length(context=parent_context)
-            self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=nestedkey_len)
             value = []
-            # Generate data for each nestedkey record
-            for _ in range(nestedkey_len):
-                # Create sub-context for each list element creation
-                ctx = GenIterContext(parent_context, self._statement.name)
-                generated_value = self._try_execute_sub_tasks(ctx)
-                value.append(generated_value)
+            if nestedkey_len:
+                self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=nestedkey_len)
+                # Generate data for each nestedkey record
+                for _ in range(nestedkey_len):
+                    # Create sub-context for each list element creation
+                    ctx = GenIterContext(parent_context, str(self._statement.name))
+                    generated_value = self._try_execute_sub_tasks(ctx)
+                    value.append(generated_value)
         elif nestedkey_type == DATA_TYPE_DICT:
             self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=1)
             # Create sub-context for nestedkey creation
-            ctx = GenIterContext(parent_context, self._statement.name)
+            ctx = GenIterContext(parent_context, str(self._statement.name))
             value = self._try_execute_sub_tasks(ctx)
         else:
             # Load value from current product then assign to nestedkey if type is not defined,

@@ -119,7 +120,7 @@
             )
             nestedkey_data = parent_context.current_product[self._statement.name]
             self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=len(nestedkey_data))
-            ctx = GenIterContext(parent_context, self._statement.name)
+            ctx = GenIterContext(parent_context, str(self._statement.name))
             ctx.current_product = nestedkey_data
             value = self._try_execute_sub_tasks(ctx)

@@ -178,19 +179,20 @@ def _try_execute_sub_tasks(self, ctx: GenIterContext) -> dict:
         """
         attributes = {}
         # Try to execute sub_tasks
-        for sub_task in self._sub_tasks:
-            try:
-                if isinstance(sub_task, ElementTask):
-                    attributes.update(sub_task.generate_xml_attribute(ctx))
-                else:
-                    sub_task.execute(ctx)
-            except StopIteration:
-                # Stop generating data if one of datasource reach the end
-                logger.info(
-                    f"Data generator sub-task {sub_task.__class__.__name__} '{sub_task.statement.name}' "
-                    f"has already reached the end"
-                )
-                break
+        if self._sub_tasks:
+            for sub_task in self._sub_tasks:
+                try:
+                    if isinstance(sub_task, ElementTask):
+                        attributes.update(sub_task.generate_xml_attribute(ctx))
+                    else:
+                        sub_task.execute(ctx)
+                except StopIteration:
+                    # Stop generating data if one of datasource reach the end
+                    logger.info(
+                        f"Data generator sub-task {sub_task.__class__.__name__} '{sub_task.statement.name}' "
+                        f"has already reached the end"
+                    )
+                    break
         ctx.current_product = self._post_convert(ctx.current_product)
         return {**ctx.current_product, **attributes}

@@ -291,7 +293,7 @@ def _modify_nestedkey_data_dict(self, parent_context: GenIterContext, value: dict
         """
         self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=1)
         # Create sub-context for nestedkey creation
-        ctx = GenIterContext(parent_context, self._statement.name)
+        ctx = GenIterContext(parent_context, str(self._statement.name))
         ctx.current_product = copy.copy(value)
         modified_value = self._try_execute_sub_tasks(ctx)
         return modified_value

@@ -321,7 +323,7 @@ def _modify_nestedkey_data_list(self, parent_context: GenIterContext, value: list
             self._lazy_init_sub_tasks(parent_context=parent_context, nestedkey_length=nestedkey_len)
             # Modify each nestedkey of the data
             for idx in range(nestedkey_len):
-                ctx = GenIterContext(parent_context, self._statement.name)
+                ctx = GenIterContext(parent_context, str(self._statement.name))
                 ctx.current_product = iterate_value[idx]
 
                 # Ensure current_product is a dictionary
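Both guards introduced here (`if nestedkey_len:` and `if self._sub_tasks:`) follow from the new `list | None` annotation on `_sub_tasks`: iterating an `Optional` list fails type checking until the value is narrowed, and the truthiness test does exactly that. A minimal illustration under those assumptions (stand-in class, not repo code):

class Demo:
    def __init__(self) -> None:
        self._sub_tasks: list | None = None  # populated lazily later

    def run_all(self) -> None:
        # Without this guard, `for task in self._sub_tasks` is a type error
        # (and a runtime TypeError while _sub_tasks is still None).
        if self._sub_tasks:
            for task in self._sub_tasks:
                print(task)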
datamimic_ce/tasks/reference_task.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@
 
 
 class ReferenceTask(Task):
-    def __init__(self, statement: ReferenceStatement, pagination: DataSourcePagination):
+    def __init__(self, statement: ReferenceStatement, pagination: DataSourcePagination | None):
         self._statement = statement
         self._pagination = pagination
         self._iterator = None
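Widening `pagination` to `DataSourcePagination | None` presumably legitimises callers that construct the task without pagination; a hypothetical example where `ref_stmt` is a placeholder (not repo code):

task = ReferenceTask(statement=ref_stmt, pagination=None)  # accepted under the widened annotation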
datamimic_ce/tasks/task_util.py (46 changes: 22 additions & 24 deletions)

@@ -5,7 +5,7 @@
 # For questions and support, contact: [email protected]
 
 import re
-from typing import Any
+from typing import Any, Optional, Union
 
 from datamimic_ce.clients.mongodb_client import MongoDBClient
 from datamimic_ce.clients.rdbms_client import RdbmsClient

@@ -148,28 +148,28 @@ def evaluate_file_script_template(ctx: Context, datas: Any, prefix: str, suffix:
         e.g. '{1+3}' -> 4
         """
         if isinstance(datas, dict):
-            result = {}
+            dict_result = {}
             for key, json_value in datas.items():
                 if isinstance(json_value, dict | list):
                     value = TaskUtil.evaluate_file_script_template(ctx, json_value, prefix, suffix)
                 elif isinstance(json_value, str):
                     value = TaskUtil._evaluate_script_value(ctx, json_value, prefix, suffix)
                 else:
                     value = json_value
-                result.update({key: value})
-            return result
+                dict_result.update({key: value})
+            return dict_result
         elif isinstance(datas, list):
-            result = []
+            list_result: list[Any] = []
             for value in datas:
                 if isinstance(value, list):
-                    result.extend(TaskUtil.evaluate_file_script_template(ctx, value, prefix, suffix))
+                    list_result.extend(TaskUtil.evaluate_file_script_template(ctx, value, prefix, suffix))
                 elif isinstance(value, dict):
-                    result.append(TaskUtil.evaluate_file_script_template(ctx, value, prefix, suffix))
+                    list_result.append(TaskUtil.evaluate_file_script_template(ctx, value, prefix, suffix))
                 elif isinstance(value, str):
-                    result.append(TaskUtil._evaluate_script_value(ctx, value, prefix, suffix))
+                    list_result.append(TaskUtil._evaluate_script_value(ctx, value, prefix, suffix))
                 else:
-                    result.append(value)
-            return result
+                    list_result.append(value)
+            return list_result
         elif isinstance(datas, str):
             return TaskUtil._evaluate_script_value(ctx, datas, prefix, suffix)
         else:

@@ -190,7 +190,7 @@ def _evaluate_script_value(ctx: Context, data: str, prefix: str, suffix: str):
         is_whole_source_script = data[0] == "{" and data[-1] == "}"
         if is_whole_source_script:
             match = re.search(r"^{(.*)}$", data)
-            return ctx.evaluate_python_expression(match.group(1))
+            return ctx.evaluate_python_expression(match.group(1)) if match is not None else None
 
         return TaskUtil.evaluate_variable_concat_prefix_suffix(ctx, data, prefix, suffix)

@@ -294,7 +294,7 @@ def gen_task_load_data_from_source(
         """
         build_from_source = True
         root_context = context.root
-        source_data = []
+        source_data: Union[dict, list] = []
 
         # get prefix and suffix
         setup_ctx = context.root if not isinstance(context, SetupContext) else context

@@ -309,21 +309,21 @@
                 ctx=context,
                 file_path=root_context.descriptor_dir / source_str,
                 separator=separator,
-                cyclic=stmt.cyclic,
                 start_idx=load_start_idx,
                 end_idx=load_end_idx,
                 source_scripted=source_scripted,
                 prefix=prefix,
                 suffix=suffix,
+                cyclic=stmt.cyclic,
             )
         # Load data from JSON
         elif source_str.endswith(".json"):
             source_data = _load_json_file(
                 root_context.task_id,
                 root_context.descriptor_dir / source_str,
-                stmt.cyclic,
                 load_start_idx,
                 load_end_idx,
+                stmt.cyclic,
             )
             # if sourceScripted then evaluate python expression in json
             if source_scripted:

@@ -336,10 +336,7 @@
         # Load data from XML
         elif source_str.endswith(".xml"):
             source_data = _load_xml_file(
-                root_context.descriptor_dir / source_str,
-                stmt.cyclic,
-                load_start_idx,
-                load_end_idx,
+                root_context.descriptor_dir / source_str, load_start_idx, load_end_idx, stmt.cyclic
             )
             # if sourceScripted then evaluate python expression in json
             if source_scripted:

@@ -379,14 +376,15 @@
                     source_data = client.get_by_page_with_query(original_query=selector, pagination=load_pagination)
                 else:
                     source_data = client.get_by_page_with_type(
-                        table_name=stmt.type or stmt.name, pagination=load_pagination
+                        table_name=stmt.type or stmt.name,
+                        pagination=load_pagination,  # type: ignore
                     )
             else:
                 raise ValueError(f"Cannot load data from client: {type(client).__name__}")
         else:
             raise ValueError(f"cannot find data source {source_str} for iterate task")
 
-        return source_data, build_from_source
+        return source_data, build_from_source  # type: ignore
 
     # @staticmethod
     # def consume_minio_after_page_processing(stmt, context: Context) -> None:

@@ -441,13 +439,13 @@ def consume_product_by_page(
         # Create exporters cache in root context if it doesn't exist
         if not hasattr(root_context, "_task_exporters"):
             # Using task_id to namespace the cache
-            root_context._task_exporters = {}
+            root_context._task_exporters = {}  # type: ignore  # skip mypy check
 
         # Create a unique cache key incorporating task_id and statement details
         cache_key = f"{root_context.task_id}_{stmt.name}_{stmt.storage_id}_{stmt}"
 
         # Get or create exporters
-        if cache_key not in root_context._task_exporters:
+        if cache_key not in root_context._task_exporters:  # type: ignore  # skip mypy check
             # Create the consumer set once
             consumer_set = stmt.targets.copy()
             # consumer_set.add(EXPORTER_PREVIEW) deactivating preview exporter for multi-process

@@ -465,14 +463,14 @@
             )
 
             # Cache the exporters
-            root_context._task_exporters[cache_key] = {
+            root_context._task_exporters[cache_key] = {  # type: ignore  # skip mypy check
                 "with_operation": consumers_with_operation,
                 "without_operation": consumers_without_operation,
                 "page_count": 0,  # Track number of pages processed
             }
 
         # Get cached exporters
-        exporters = root_context._task_exporters[cache_key]
+        exporters = root_context._task_exporters[cache_key]  # type: ignore  # skip mypy check
         exporters["page_count"] += 1
 
         # Use cached exporters
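Two of the changes above are routine mypy appeasements: re-binding a single `result` name to both a dict and a list trips the checker's redefinition rule, hence the split into `dict_result` and `list_result`; and `re.search` returns `Optional[Match]`, hence the `None` guard before calling `.group(1)`. A minimal sketch of the latter (illustrative function, not repo code):

import re

def extract_expression(data: str) -> str | None:
    # re.search returns Optional[re.Match[str]]; guarding on None satisfies
    # mypy and avoids an AttributeError when the pattern does not match.
    match = re.search(r"^{(.*)}$", data)
    return match.group(1) if match is not None else None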