Skip to content

Commit

Permalink
fix: fix builder sequential bug
Browse files: browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 6, 2025
1 parent a8c49bc commit 2bec508
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 21 deletions.
12 changes: 6 additions & 6 deletions tests/unit/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def dummy_f(arg):


@pytest.fixture
def file_io_funcitons():
def file_io_functions():
builder.register("assert_file_exists", versions=">=0.0.0")(assert_file_exists)
builder.register("write_file", versions=">=0.0.0")(write_file)
builder.register("dummy_f", versions=">=0.0.0")(dummy_f)
Expand All @@ -355,7 +355,7 @@ def file_io_funcitons():
builder.unregister(name="assert_file_exists", fn=assert_file_exists)


def test_sequential_side_effects_x0(file_io_funcitons):
def test_sequential_side_effects_x0(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
Expand All @@ -370,7 +370,7 @@ def test_sequential_side_effects_x0(file_io_funcitons):
)


def test_sequential_side_effects_x1(file_io_funcitons):
def test_sequential_side_effects_x1(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
Expand All @@ -385,7 +385,7 @@ def test_sequential_side_effects_x1(file_io_funcitons):
)


def test_sequential_side_effects_x2(file_io_funcitons):
def test_sequential_side_effects_x2(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
Expand All @@ -400,7 +400,7 @@ def test_sequential_side_effects_x2(file_io_funcitons):
)


def test_sequential_side_effects_x3(file_io_funcitons):
def test_sequential_side_effects_x3(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
Expand All @@ -413,7 +413,7 @@ def test_sequential_side_effects_x3(file_io_funcitons):
)


def test_sequential_side_effects_x4(file_io_funcitons):
def test_sequential_side_effects_x4(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
Expand Down
39 changes: 24 additions & 15 deletions zetta_utils/builder/building.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def build(spec: dict | list | None = None, path: str | None = None, parallel: bo


def _build(spec: JsonSerializableValue, parallel: bool, version: str, name_prefix: str) -> Any:
stages = _parse_stages(spec, version=version, name_prefix=name_prefix)
stages = _parse_stages(spec, version=version, name_prefix=name_prefix, parallel=parallel)
result = _execute_build_stages(stages=stages, parallel=parallel)
return result

Expand Down Expand Up @@ -151,14 +151,7 @@ def _process_result(obj: ObjectToBeBuilt, result: Any):
if len(results_parallel) > 0:
obj_result = results_parallel[-1]
else:
for obj in stage.parallel_part:
obj_result = _build_object(
fn=obj.fn,
kwargs=obj.kwargs,
spec=obj.spec,
name_prefix=obj.name_prefix,
)
_process_result(obj, obj_result)
assert len(stage.parallel_part) == 0

return obj_result

Expand All @@ -181,7 +174,9 @@ def _is_trivial(obj) -> bool:
)


def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -> list[Stage]:
def _parse_stages(
spec: JsonSerializableValue, name_prefix: str, version: str, parallel: bool
) -> list[Stage]:
stage_dict: dict[int, Stage] = defaultdict(Stage)
final_obj = _parse_stages_inner(
spec=spec,
Expand All @@ -191,6 +186,7 @@ def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -
parent=None,
parent_kwarg_name=None,
name_prefix=name_prefix,
parallel=parallel,
)
result = [stage_dict[k] for k in reversed(sorted(stage_dict.keys()))]
if _is_trivial(final_obj):
Expand All @@ -212,8 +208,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix: str,
parent: ObjectToBeBuilt | None,
parent_kwarg_name: str | None,
parallel: bool,
) -> ObjectToBeBuilt | JsonSerializableValue | BuilderPartial:
result: ObjectToBeBuilt | JsonSerializableValue | BuilderPartial

if isinstance(spec, (int, float, bool, str)) or spec is None:
result = spec
elif isinstance(spec, (list, tuple)):
Expand All @@ -234,6 +232,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=str(i),
version=version,
name_prefix=f"{name_prefix}[{i}]",
parallel=parallel,
)
for i, e in enumerate(spec)
]
Expand All @@ -243,10 +242,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
else:
if any(not v.allow_parallel for v in args if isinstance(v, ObjectToBeBuilt)):
this_obj.allow_parallel = False

this_obj.kwargs = {str(i): v for i, v in enumerate(args) if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj

elif isinstance(spec, dict) and "@type" not in spec:
assert "@mode" not in spec
this_obj = ObjectToBeBuilt(
Expand All @@ -265,6 +264,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}[{k}]",
parallel=parallel,
)
for k, v in spec.items()
}
Expand All @@ -276,6 +276,8 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
this_obj.allow_parallel = False

this_obj.kwargs = {k: v for k, v in kwargs.items() if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj
Expand Down Expand Up @@ -303,15 +305,21 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix=name_prefix,
allow_parallel=entry.allow_parallel,
)
if parallel:
children_level_increment = 1
else:
children_level_increment = 0

kwargs = {
k: _parse_stages_inner(
v,
stages_dict=stages_dict,
level=level + 1,
level=level + children_level_increment,
parent=this_obj,
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}.{k}",
parallel=parallel,
)
for k, v in spec.items()
if not k.startswith("@")
Expand All @@ -321,10 +329,11 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
if entry.allow_parallel:
stages_dict[level + 1].parallel_part.append(this_obj)

if entry.allow_parallel and parallel:
stages_dict[level + children_level_increment].parallel_part.append(this_obj)
else:
stages_dict[level + 1].sequential_part.append(this_obj)
stages_dict[level + children_level_increment].sequential_part.append(this_obj)
result = this_obj
else:
raise ValueError(f"Unsupported type: {type(spec)}")
Expand Down

0 comments on commit 2bec508

Please sign in to comment.