From 6eb7d1c06f7da6b103f1cf6918a059c38bb9da58 Mon Sep 17 00:00:00 2001 From: Sergiy Date: Thu, 6 Feb 2025 05:12:00 +0000 Subject: [PATCH] fix: fix builder sequential bug --- tests/unit/test_builder.py | 12 +++++----- zetta_utils/builder/building.py | 39 ++++++++++++++++++++------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py index 3681719ce..c62ffcedf 100644 --- a/tests/unit/test_builder.py +++ b/tests/unit/test_builder.py @@ -345,7 +345,7 @@ def dummy_f(arg): @pytest.fixture -def file_io_funcitons(): +def file_io_functions(): builder.register("assert_file_exists", versions=">=0.0.0")(assert_file_exists) builder.register("write_file", versions=">=0.0.0")(write_file) builder.register("dummy_f", versions=">=0.0.0")(dummy_f) @@ -355,7 +355,7 @@ def file_io_funcitons(): builder.unregister(name="assert_file_exists", fn=assert_file_exists) -def test_sequential_side_effects_x0(file_io_funcitons): +def test_sequential_side_effects_x0(file_io_functions): with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, "file.txt") builder.build( @@ -370,7 +370,7 @@ def test_sequential_side_effects_x0(file_io_funcitons): ) -def test_sequential_side_effects_x1(file_io_funcitons): +def test_sequential_side_effects_x1(file_io_functions): with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, "file.txt") builder.build( @@ -385,7 +385,7 @@ def test_sequential_side_effects_x1(file_io_funcitons): ) -def test_sequential_side_effects_x2(file_io_funcitons): +def test_sequential_side_effects_x2(file_io_functions): with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, "file.txt") builder.build( @@ -400,7 +400,7 @@ def test_sequential_side_effects_x2(file_io_funcitons): ) -def test_sequential_side_effects_x3(file_io_funcitons): +def test_sequential_side_effects_x3(file_io_functions): with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, "file.txt") builder.build( @@ -413,7 +413,7 @@ def test_sequential_side_effects_x3(file_io_funcitons): ) -def test_sequential_side_effects_x4(file_io_funcitons): +def test_sequential_side_effects_x4(file_io_functions): with tempfile.TemporaryDirectory() as tmp_dir: file_path = os.path.join(tmp_dir, "file.txt") builder.build( diff --git a/zetta_utils/builder/building.py b/zetta_utils/builder/building.py index b7e792e68..ce94b0631 100644 --- a/zetta_utils/builder/building.py +++ b/zetta_utils/builder/building.py @@ -59,7 +59,7 @@ def build(spec: dict | list | None = None, path: str | None = None, parallel: bo def _build(spec: JsonSerializableValue, parallel: bool, version: str, name_prefix: str) -> Any: - stages = _parse_stages(spec, version=version, name_prefix=name_prefix) + stages = _parse_stages(spec, version=version, name_prefix=name_prefix, parallel=parallel) result = _execute_build_stages(stages=stages, parallel=parallel) return result @@ -151,14 +151,7 @@ def _process_result(obj: ObjectToBeBuilt, result: Any): if len(results_parallel) > 0: obj_result = results_parallel[-1] else: - for obj in stage.parallel_part: - obj_result = _build_object( - fn=obj.fn, - kwargs=obj.kwargs, - spec=obj.spec, - name_prefix=obj.name_prefix, - ) - _process_result(obj, obj_result) + assert len(stage.parallel_part) == 0 return obj_result @@ -181,7 +174,9 @@ def _is_trivial(obj) -> bool: ) -def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -> list[Stage]: +def _parse_stages( + spec: JsonSerializableValue, name_prefix: str, version: str, parallel: bool +) -> list[Stage]: stage_dict: dict[int, Stage] = defaultdict(Stage) final_obj = _parse_stages_inner( spec=spec, @@ -191,6 +186,7 @@ def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) - parent=None, parent_kwarg_name=None, name_prefix=name_prefix, + parallel=parallel, ) result = [stage_dict[k] for k in reversed(sorted(stage_dict.keys()))] if _is_trivial(final_obj): @@ -212,8 +208,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement name_prefix: str, parent: ObjectToBeBuilt | None, parent_kwarg_name: str | None, + parallel: bool, ) -> ObjectToBeBuilt | JsonSerializableValue | BuilderPartial: result: ObjectToBeBuilt | JsonSerializableValue | BuilderPartial + if isinstance(spec, (int, float, bool, str)) or spec is None: result = spec elif isinstance(spec, (list, tuple)): @@ -234,6 +232,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement parent_kwarg_name=str(i), version=version, name_prefix=f"{name_prefix}[{i}]", + parallel=parallel, ) for i, e in enumerate(spec) ] @@ -243,10 +242,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement else: if any(not v.allow_parallel for v in args if isinstance(v, ObjectToBeBuilt)): this_obj.allow_parallel = False + this_obj.kwargs = {str(i): v for i, v in enumerate(args) if _is_trivial(v)} stages_dict[level].sequential_part.append(this_obj) result = this_obj - elif isinstance(spec, dict) and "@type" not in spec: assert "@mode" not in spec this_obj = ObjectToBeBuilt( @@ -265,6 +264,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement parent_kwarg_name=k, version=version, name_prefix=f"{name_prefix}[{k}]", + parallel=parallel, ) for k, v in spec.items() } @@ -276,6 +276,8 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt) ): this_obj.allow_parallel = False + this_obj.allow_parallel = False + this_obj.kwargs = {k: v for k, v in kwargs.items() if _is_trivial(v)} stages_dict[level].sequential_part.append(this_obj) result = this_obj @@ -303,15 +305,21 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement name_prefix=name_prefix, allow_parallel=entry.allow_parallel, ) + if parallel: + children_level_increment = 1 + else: + children_level_increment = 0 + kwargs = { k: _parse_stages_inner( v, stages_dict=stages_dict, - level=level + 1, + level=level + children_level_increment, parent=this_obj, parent_kwarg_name=k, version=version, name_prefix=f"{name_prefix}.{k}", + parallel=parallel, ) for k, v in spec.items() if not k.startswith("@") @@ -321,10 +329,11 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt) ): this_obj.allow_parallel = False - if entry.allow_parallel: - stages_dict[level + 1].parallel_part.append(this_obj) + + if entry.allow_parallel and parallel: + stages_dict[level + children_level_increment].parallel_part.append(this_obj) else: - stages_dict[level + 1].sequential_part.append(this_obj) + stages_dict[level + children_level_increment].sequential_part.append(this_obj) result = this_obj else: raise ValueError(f"Unsupported type: {type(spec)}")