Skip to content

Commit

Permalink
Merge pull request #480 from jw-96/splitters
Browse files Browse the repository at this point in the history
Support for splitters containing a single item list
  • Loading branch information
djarecka authored Jun 17, 2021
2 parents 7fb3d17 + eb17a4d commit fbb39d6
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 70 deletions.
150 changes: 80 additions & 70 deletions pydra/engine/helpers_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,81 +56,91 @@ def _ordering(
if type(el) is tuple:
# checking if the splitter dont contain splitter from previous nodes
# i.e. has str "_NA", etc.
if type(el[0]) is str and el[0].startswith("_"):
node_nm = el[0][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
if len(el) == 1:
# treats .split(("x",)) like .split("x")
el = el[0]
_ordering(el, i, output_splitter, current_sign, other_states, state_fields)
else:
if type(el[0]) is str and el[0].startswith("_"):
node_nm = el[0][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
)
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el = (splitter_mod, el[1])
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
if type(el[1]) is str and el[1].startswith("_"):
node_nm = el[1][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el = (el[0], splitter_mod)
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
_iterate_list(
el,
".",
other_states,
output_splitter=output_splitter,
state_fields=state_fields,
)
el = (splitter_mod, el[1])
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
if type(el[1]) is str and el[1].startswith("_"):
node_nm = el[1][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el = (el[0], splitter_mod)
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
_iterate_list(
el,
".",
other_states,
output_splitter=output_splitter,
state_fields=state_fields,
)
elif type(el) is list:
if type(el[0]) is str and el[0].startswith("_"):
node_nm = el[0][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
if len(el) == 1:
# treats .split(["x"]) like .split("x")
el = el[0]
_ordering(el, i, output_splitter, current_sign, other_states, state_fields)
else:
if type(el[0]) is str and el[0].startswith("_"):
node_nm = el[0][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
)
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el[0] = splitter_mod
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
if type(el[1]) is str and el[1].startswith("_"):
node_nm = el[1][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el[1] = splitter_mod
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
_iterate_list(
el,
"*",
other_states,
output_splitter=output_splitter,
state_fields=state_fields,
)
el[0] = splitter_mod
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
if type(el[1]) is str and el[1].startswith("_"):
node_nm = el[1][1:]
if node_nm not in other_states and state_fields:
raise PydraStateError(
"can't ask for splitter from {}, other nodes that are connected: {}".format(
node_nm, other_states.keys()
)
)
elif state_fields:
splitter_mod = add_name_splitter(
splitter=other_states[node_nm][0].splitter_final, name=node_nm
)
el[1] = splitter_mod
if other_states[node_nm][0].other_states:
other_states.update(other_states[node_nm][0].other_states)
_iterate_list(
el,
"*",
other_states,
output_splitter=output_splitter,
state_fields=state_fields,
)
elif type(el) is str:
if el.startswith("_"):
node_nm = el[1:]
Expand Down
12 changes: 12 additions & 0 deletions pydra/engine/tests/test_helpers_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,12 @@ def __init__(
"splitter, keys_exp, groups_exp, grstack_exp",
[
("a", ["a"], {"a": 0}, [[0]]),
(["a"], ["a"], {"a": 0}, [[0]]),
(("a",), ["a"], {"a": 0}, [[0]]),
(("a", "b"), ["a", "b"], {"a": 0, "b": 0}, [[0]]),
(["a", "b"], ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
([["a", "b"]], ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
((["a", "b"],), ["a", "b"], {"a": 0, "b": 1}, [[0, 1]]),
((["a", "b"], "c"), ["a", "b", "c"], {"a": 0, "b": 1, "c": [0, 1]}, [[0, 1]]),
([("a", "b"), "c"], ["a", "b", "c"], {"a": 0, "b": 0, "c": 1}, [[0, 1]]),
([["a", "b"], "c"], ["a", "b", "c"], {"a": 0, "b": 1, "c": 2}, [[0, 1, 2]]),
Expand All @@ -58,6 +62,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
"keys_final_exp, groups_final_exp, grstack_final_exp",
[
("a", ["a"], ["a"], [], {}, []),
(["a"], ["a"], ["a"], [], {}, []),
(("a",), ["a"], ["a"], [], {}, []),
(("a", "b"), ["a"], ["a", "b"], [], {}, [[]]),
(("a", "b"), ["b"], ["a", "b"], [], {}, [[]]),
(["a", "b"], ["b"], ["b"], ["a"], {"a": 0}, [[0]]),
Expand All @@ -69,6 +75,8 @@ def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp):
([("a", "b"), "c"], ["a"], ["a", "b"], ["c"], {"c": 0}, [[0]]),
([("a", "b"), "c"], ["b"], ["a", "b"], ["c"], {"c": 0}, [[0]]),
([("a", "b"), "c"], ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
([[("a", "b"), "c"]], ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
(([("a", "b"), "c"],), ["c"], ["c"], ["a", "b"], {"a": 0, "b": 0}, [[0]]),
],
)
def test_splits_groups_comb(
Expand All @@ -94,6 +102,8 @@ def test_splits_groups_comb(
"splitter, cont_dim, values, keys, splits",
[
("a", None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
(["a"], None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
(("a",), None, [(0,), (1,)], ["a"], [{"a": 1}, {"a": 2}]),
(
("a", "v"),
None,
Expand Down Expand Up @@ -468,6 +478,8 @@ def test_splits_2(splitter_rpn, inner_inputs, values, keys, splits):
(["a", ("b", ["c", "d"])], ["a", "b", "c", "d", "*", ".", "*"]),
((["a", "b"], "c"), ["a", "b", "*", "c", "."]),
((["a", "b"], ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
(([["a", "b"]], ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
(((["a", "b"],), ["c", "d"]), ["a", "b", "*", "c", "d", "*", "."]),
([("a", "b"), ("c", "d")], ["a", "b", ".", "c", "d", ".", "*"]),
],
)
Expand Down

0 comments on commit fbb39d6

Please sign in to comment.