diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py
index 36c210012..3681719ce 100644
--- a/tests/unit/test_builder.py
+++ b/tests/unit/test_builder.py
@@ -1,4 +1,6 @@
 # pylint: disable=missing-docstring,protected-access,unused-argument,redefined-outer-name,invalid-name, line-too-long
+import os
+import tempfile
 import time
 from concurrent.futures import ProcessPoolExecutor
 from dataclasses import dataclass
@@ -327,3 +329,101 @@ def test_unpickleable_dict():
         )
     )
     assert len(recons) == 0
+
+
+def assert_file_exists(path):
+    assert os.path.exists(path)
+
+
+def write_file(path):
+    with open(path, "w", encoding="utf8") as f:
+        f.write("hello world")
+
+
+def dummy_f(arg):
+    ...
+
+
+@pytest.fixture
+def file_io_functions():
+    builder.register("assert_file_exists", versions=">=0.0.0")(assert_file_exists)
+    builder.register("write_file", versions=">=0.0.0")(write_file)
+    builder.register("dummy_f", versions=">=0.0.0")(dummy_f)
+    yield
+    builder.unregister(name="dummy_f", fn=dummy_f)
+    builder.unregister(name="write_file", fn=write_file)
+    builder.unregister(name="assert_file_exists", fn=assert_file_exists)
+
+
+def test_sequential_side_effects_x0(file_io_functions):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = os.path.join(tmp_dir, "file.txt")
+        builder.build(
+            {
+                "parts": [
+                    {"@type": "write_file", "path": file_path},
+                    {"@type": "assert_file_exists", "path": file_path},
+                    [{"@type": "assert_file_exists", "path": file_path}],
+                    [[{"@type": "assert_file_exists", "path": file_path}]],
+                ]
+            }
+        )
+
+
+def test_sequential_side_effects_x1(file_io_functions):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = os.path.join(tmp_dir, "file.txt")
+        builder.build(
+            {
+                "parts": [
+                    [[[{"@type": "write_file", "path": file_path}]]],
+                    {"@type": "assert_file_exists", "path": file_path},
+                    [{"@type": "assert_file_exists", "path": file_path}],
+                    [[{"@type": "assert_file_exists", "path": file_path}]],
+                ]
+            }
+        )
+
+
+def test_sequential_side_effects_x2(file_io_functions):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = os.path.join(tmp_dir, "file.txt")
+        builder.build(
+            {
+                "parts": [
+                    [[{"inner": [{"@type": "write_file", "path": file_path}]}]],
+                    {"@type": "assert_file_exists", "path": file_path},
+                    [{"@type": "assert_file_exists", "path": file_path}],
+                    [[{"@type": "assert_file_exists", "path": file_path}]],
+                ]
+            }
+        )
+
+
+def test_sequential_side_effects_x3(file_io_functions):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = os.path.join(tmp_dir, "file.txt")
+        builder.build(
+            {
+                "parts": [
+                    {"@type": "dummy_f", "arg": [{"@type": "write_file", "path": file_path}]},
+                    {"@type": "assert_file_exists", "path": file_path},
+                ]
+            }
+        )
+
+
+def test_sequential_side_effects_x4(file_io_functions):
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        file_path = os.path.join(tmp_dir, "file.txt")
+        builder.build(
+            {
+                "parts": [
+                    {"@type": "write_file", "path": file_path},
+                    {
+                        "@type": "dummy_f",
+                        "arg": {"@type": "assert_file_exists", "path": file_path},
+                    },
+                ]
+            }
+        )
diff --git a/zetta_utils/builder/building.py b/zetta_utils/builder/building.py
index b7e792e68..e99327e85 100644
--- a/zetta_utils/builder/building.py
+++ b/zetta_utils/builder/building.py
@@ -59,7 +59,7 @@ def build(spec: dict | list | None = None, path: str | None = None, parallel: bo
 
 
 def _build(spec: JsonSerializableValue, parallel: bool, version: str, name_prefix: str) -> Any:
-    stages = _parse_stages(spec, version=version, name_prefix=name_prefix)
+    stages = _parse_stages(spec, version=version, name_prefix=name_prefix, parallel=parallel)
     result = _execute_build_stages(stages=stages, parallel=parallel)
     return result
 
@@ -181,7 +181,9 @@ def _is_trivial(obj) -> bool:
     )
 
 
-def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -> list[Stage]:
+def _parse_stages(
+    spec: JsonSerializableValue, name_prefix: str, version: str, parallel: bool
+) -> list[Stage]:
     stage_dict: dict[int, Stage] = defaultdict(Stage)
     final_obj = _parse_stages_inner(
         spec=spec,
@@ -191,6 +193,7 @@ def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -
         parent=None,
         parent_kwarg_name=None,
         name_prefix=name_prefix,
+        parallel=parallel,
     )
     result = [stage_dict[k] for k in reversed(sorted(stage_dict.keys()))]
     if _is_trivial(final_obj):
@@ -212,8 +215,10 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
     name_prefix: str,
     parent: ObjectToBeBuilt | None,
     parent_kwarg_name: str | None,
+    parallel: bool,
 ) -> ObjectToBeBuilt | JsonSerializableValue | BuilderPartial:
     result: ObjectToBeBuilt | JsonSerializableValue | BuilderPartial
+
     if isinstance(spec, (int, float, bool, str)) or spec is None:
         result = spec
     elif isinstance(spec, (list, tuple)):
@@ -234,6 +239,7 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
                 parent_kwarg_name=str(i),
                 version=version,
                 name_prefix=f"{name_prefix}[{i}]",
+                parallel=parallel,
             )
             for i, e in enumerate(spec)
         ]
@@ -243,10 +249,10 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
         else:
            if any(not v.allow_parallel for v in args if isinstance(v, ObjectToBeBuilt)):
                 this_obj.allow_parallel = False
+            this_obj.kwargs = {str(i): v for i, v in enumerate(args) if _is_trivial(v)}
             stages_dict[level].sequential_part.append(this_obj)
             result = this_obj
 
-
     elif isinstance(spec, dict) and "@type" not in spec:
         assert "@mode" not in spec
         this_obj = ObjectToBeBuilt(
@@ -265,6 +271,7 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
                 parent_kwarg_name=k,
                 version=version,
                 name_prefix=f"{name_prefix}[{k}]",
+                parallel=parallel,
             )
             for k, v in spec.items()
         }
@@ -276,6 +283,8 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
             not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
         ):
             this_obj.allow_parallel = False
+        this_obj.allow_parallel = False
+        this_obj.kwargs = {k: v for k, v in kwargs.items() if _is_trivial(v)}
         stages_dict[level].sequential_part.append(this_obj)
         result = this_obj
 
@@ -303,15 +312,21 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
             name_prefix=name_prefix,
             allow_parallel=entry.allow_parallel,
         )
+        if parallel:
+            children_level_increment = 1
+        else:
+            children_level_increment = 0
+
        kwargs = {
             k: _parse_stages_inner(
                 v,
                 stages_dict=stages_dict,
-                level=level + 1,
+                level=level + children_level_increment,
                 parent=this_obj,
                 parent_kwarg_name=k,
                 version=version,
                 name_prefix=f"{name_prefix}.{k}",
+                parallel=parallel,
             )
             for k, v in spec.items()
             if not k.startswith("@")
@@ -321,10 +336,11 @@ def _parse_stages_inner(  # pylint: disable=too-many-branches,too-many-statement
             not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
         ):
             this_obj.allow_parallel = False
-        if entry.allow_parallel:
-            stages_dict[level + 1].parallel_part.append(this_obj)
+
+        if entry.allow_parallel and parallel:
+            stages_dict[level + children_level_increment].parallel_part.append(this_obj)
         else:
-            stages_dict[level + 1].sequential_part.append(this_obj)
+            stages_dict[level + children_level_increment].sequential_part.append(this_obj)
         result = this_obj
     else:
         raise ValueError(f"Unsupported type: {type(spec)}")