Builder non-sequential behavior fix #889

Closed
wants to merge 1 commit into from
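For context, the change threads a `parallel` flag through `_parse_stages` / `_parse_stages_inner` so that, when parallel building is disabled, nested entries of a spec stay on the enclosing stage level and are built in the order they appear, rather than being hoisted into earlier stages. Below is a minimal sketch of the kind of spec whose side-effect ordering this affects, assuming the `builder.register` / `builder.build` API exercised by the new tests; the import path is an assumption.

# Minimal sketch of a spec with ordering-sensitive side effects, assuming the
# builder API used by the new tests; the import path is an assumption.
import os
import tempfile

from zetta_utils import builder


@builder.register("write_file", versions=">=0.0.0")
def write_file(path):
    with open(path, "w", encoding="utf8") as f:
        f.write("hello world")


@builder.register("assert_file_exists", versions=">=0.0.0")
def assert_file_exists(path):
    assert os.path.exists(path)


with tempfile.TemporaryDirectory() as tmp_dir:
    file_path = os.path.join(tmp_dir, "file.txt")
    # The nested entry must see the file written by its earlier sibling;
    # with non-sequential building, this check could run before the write.
    builder.build(
        {
            "parts": [
                {"@type": "write_file", "path": file_path},
                [{"@type": "assert_file_exists", "path": file_path}],
            ]
        }
    )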
100 changes: 100 additions & 0 deletions tests/unit/test_builder.py
@@ -1,4 +1,6 @@
# pylint: disable=missing-docstring,protected-access,unused-argument,redefined-outer-name,invalid-name, line-too-long
import os
import tempfile
import time
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
@@ -327,3 +329,101 @@ def test_unpickleable_dict():
)
)
assert len(recons) == 0


def assert_file_exists(path):
assert os.path.exists(path)


def write_file(path):
with open(path, "w", encoding="utf8") as f:
f.write("hello world")


def dummy_f(arg):
...


@pytest.fixture
def file_io_functions():
builder.register("assert_file_exists", versions=">=0.0.0")(assert_file_exists)
builder.register("write_file", versions=">=0.0.0")(write_file)
builder.register("dummy_f", versions=">=0.0.0")(dummy_f)
yield
builder.unregister(name="dummy_f", fn=dummy_f)
builder.unregister(name="write_file", fn=write_file)
builder.unregister(name="assert_file_exists", fn=assert_file_exists)


def test_sequential_side_effects_x0(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
{
"parts": [
{"@type": "write_file", "path": file_path},
{"@type": "assert_file_exists", "path": file_path},
[{"@type": "assert_file_exists", "path": file_path}],
[[{"@type": "assert_file_exists", "path": file_path}]],
]
}
)


def test_sequential_side_effects_x1(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
{
"parts": [
[[[{"@type": "write_file", "path": file_path}]]],
{"@type": "assert_file_exists", "path": file_path},
[{"@type": "assert_file_exists", "path": file_path}],
[[{"@type": "assert_file_exists", "path": file_path}]],
]
}
)


def test_sequential_side_effects_x2(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
{
"parts": [
[[{"inner": [{"@type": "write_file", "path": file_path}]}]],
{"@type": "assert_file_exists", "path": file_path},
[{"@type": "assert_file_exists", "path": file_path}],
[[{"@type": "assert_file_exists", "path": file_path}]],
]
}
)


def test_sequential_side_effects_x3(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
{
"parts": [
{"@type": "dummy_f", "arg": [{"@type": "write_file", "path": file_path}]},
{"@type": "assert_file_exists", "path": file_path},
]
}
)


def test_sequential_side_effects_x4(file_io_functions):
with tempfile.TemporaryDirectory() as tmp_dir:
file_path = os.path.join(tmp_dir, "file.txt")
builder.build(
{
"parts": [
{"@type": "write_file", "path": file_path},
{
"@type": "dummy_f",
"arg": {"@type": "assert_file_exists", "path": file_path},
},
]
}
)
30 changes: 23 additions & 7 deletions zetta_utils/builder/building.py
@@ -59,7 +59,7 @@ def build(spec: dict | list | None = None, path: str | None = None, parallel: bo


def _build(spec: JsonSerializableValue, parallel: bool, version: str, name_prefix: str) -> Any:
stages = _parse_stages(spec, version=version, name_prefix=name_prefix)
stages = _parse_stages(spec, version=version, name_prefix=name_prefix, parallel=parallel)
result = _execute_build_stages(stages=stages, parallel=parallel)
return result

@@ -181,7 +181,9 @@ def _is_trivial(obj) -> bool:
)


def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -> list[Stage]:
def _parse_stages(
spec: JsonSerializableValue, name_prefix: str, version: str, parallel: bool
) -> list[Stage]:
stage_dict: dict[int, Stage] = defaultdict(Stage)
final_obj = _parse_stages_inner(
spec=spec,
@@ -191,6 +193,7 @@ def _parse_stages(spec: JsonSerializableValue, name_prefix: str, version: str) -
parent=None,
parent_kwarg_name=None,
name_prefix=name_prefix,
parallel=parallel,
)
result = [stage_dict[k] for k in reversed(sorted(stage_dict.keys()))]
if _is_trivial(final_obj):
@@ -212,8 +215,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix: str,
parent: ObjectToBeBuilt | None,
parent_kwarg_name: str | None,
parallel: bool,
) -> ObjectToBeBuilt | JsonSerializableValue | BuilderPartial:
result: ObjectToBeBuilt | JsonSerializableValue | BuilderPartial

if isinstance(spec, (int, float, bool, str)) or spec is None:
result = spec
elif isinstance(spec, (list, tuple)):
@@ -234,6 +239,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=str(i),
version=version,
name_prefix=f"{name_prefix}[{i}]",
parallel=parallel,
)
for i, e in enumerate(spec)
]
@@ -243,10 +249,10 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
else:
if any(not v.allow_parallel for v in args if isinstance(v, ObjectToBeBuilt)):
this_obj.allow_parallel = False

this_obj.kwargs = {str(i): v for i, v in enumerate(args) if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj

elif isinstance(spec, dict) and "@type" not in spec:
assert "@mode" not in spec
this_obj = ObjectToBeBuilt(
@@ -265,6 +271,7 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}[{k}]",
parallel=parallel,
)
for k, v in spec.items()
}
@@ -276,6 +283,8 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
this_obj.allow_parallel = False

this_obj.kwargs = {k: v for k, v in kwargs.items() if _is_trivial(v)}
stages_dict[level].sequential_part.append(this_obj)
result = this_obj
@@ -303,15 +312,21 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
name_prefix=name_prefix,
allow_parallel=entry.allow_parallel,
)
if parallel:
children_level_increment = 1
else:
children_level_increment = 0

kwargs = {
k: _parse_stages_inner(
v,
stages_dict=stages_dict,
level=level + 1,
level=level + children_level_increment,
parent=this_obj,
parent_kwarg_name=k,
version=version,
name_prefix=f"{name_prefix}.{k}",
parallel=parallel,
)
for k, v in spec.items()
if not k.startswith("@")
Expand All @@ -321,10 +336,11 @@ def _parse_stages_inner( # pylint: disable=too-many-branches,too-many-statement
not v.allow_parallel for v in kwargs.values() if isinstance(v, ObjectToBeBuilt)
):
this_obj.allow_parallel = False
if entry.allow_parallel:
stages_dict[level + 1].parallel_part.append(this_obj)

if entry.allow_parallel and parallel:
stages_dict[level + children_level_increment].parallel_part.append(this_obj)
else:
stages_dict[level + 1].sequential_part.append(this_obj)
stages_dict[level + children_level_increment].sequential_part.append(this_obj)
result = this_obj
else:
raise ValueError(f"Unsupported type: {type(spec)}")