Skip to content

Commit

Permalink
fix: fix builder sequential bug
Browse files Browse the repository at this point in the history
  • Loading branch information
supersergiy committed Feb 6, 2025
1 parent 84e723c commit 935d281
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions zetta_utils/builder/building.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def build(spec: dict | list | None = None, path: str | None = None, parallel: bo


def _build(spec: JsonSerializableValue, parallel: bool, version: str, name_prefix: str) -> Any:
stages = _parse_stages(spec, version=version, name_prefix=name_prefix)
stages = _parse_stages(spec, version=version, name_prefix=name_prefix, parallel=parallel)
result = _execute_build_stages(stages=stages, parallel=parallel)
return result

Expand Down Expand Up @@ -181,7 +181,9 @@ def _is_trivial(obj) -> bool:
)


def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -> list[Stage]:
def _parse_stages(
spec: JsonSerializableValue, name_prefix: str, version: str, parallel: bool
) -> list[Stage]:
stage_dict: dict[int, Stage] = defaultdict(Stage)
final_obj = _parse_stages_inner(
spec=spec,
Expand All @@ -191,6 +193,7 @@ def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -
parent=None,
parent_kwarg_name=None,
name_prefix=name_prefix,
parallel=parallel,
)
result = [stage_dict[k] for k in reversed(sorted(stage_dict.keys()))]
if _is_trivial(final_obj):
Expand All @@ -212,8 +215,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix: str,
parent: ObjectToBeBuilt | None,
parent_kwarg_name: str | None,
parallel: bool,
) -> ObjectToBeBuilt | JsonSerializableValue | BuilderPartial:
result: ObjectToBeBuilt | JsonSerializableValue | BuilderPartial

if isinstance(spec, (int, float, bool, str)) or spec is None:
result = spec
elif isinstance(spec, (list, tuple)):
Expand All @@ -234,6 +239,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=str(i),
version=version,
name_prefix=f"{name_prefix}[{i}]",
parallel=parallel,
)
for i, e in enumerate(spec)
]
Expand All @@ -243,10 +249,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
else:
if any(not v.allow_parallel for v in args if isinstance(v, ObjectToBeBuilt)):
this_obj.allow_parallel = False

this_obj.kwargs = {str(i): v for i, v in enumerate(args) if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj

elif isinstance(spec, dict) and "@type" not in spec:
assert "@mode" not in spec
this_obj = ObjectToBeBuilt(
Expand All @@ -265,6 +271,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}[{k}]",
parallel=parallel,
)
for k, v in spec.items()
}
Expand All @@ -276,6 +283,8 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
this_obj.allow_parallel = False

this_obj.kwargs = {k: v for k, v in kwargs.items() if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj
Expand Down Expand Up @@ -303,15 +312,21 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix=name_prefix,
allow_parallel=entry.allow_parallel,
)
if parallel:
children_level_increment = 1
else:
children_level_increment = 0

kwargs = {
k: _parse_stages_inner(
v,
stages_dict=stages_dict,
level=level + 1,
level=level + children_level_increment,
parent=this_obj,
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}.{k}",
parallel=parallel,
)
for k, v in spec.items()
if not k.startswith("@")
Expand All @@ -321,10 +336,11 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
if entry.allow_parallel:
stages_dict[level + 1].parallel_part.append(this_obj)

if entry.allow_parallel and parallel:
stages_dict[level + children_level_increment].parallel_part.append(this_obj)
else:
stages_dict[level + 1].sequential_part.append(this_obj)
stages_dict[level + children_level_increment].sequential_part.append(this_obj)
result = this_obj
else:
raise ValueError(f"Unsupported type: {type(spec)}")
Expand Down

0 comments on commit 935d281

Please sign in to comment.