Skip to content

Commit

Permalink
implemented NodeData type to simplify API
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrozum committed Mar 1, 2024
1 parent a48e9bf commit 6edc026
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 59 deletions.
6 changes: 3 additions & 3 deletions balm/_sd_algorithms/compute_attractor_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def compute_attractor_seeds(
if balm.succession_diagram.DEBUG:
print(f"[{node_id}] Start computing attractor seeds.")

node_space = sd.node_space(node_id)
node_space = sd.node_data(node_id)["space"]

if len(node_space) == sd.network.variable_count():
# This node is a fixed-point.
Expand All @@ -39,8 +39,8 @@ def compute_attractor_seeds(
# Compute the list of child spaces if the node is expanded. Otherwise
# "pretend" that there are no children.
child_spaces = []
if sd.node_is_expanded(node_id):
child_spaces = [sd.node_space(s) for s in sd.node_successors(node_id)]
if sd.node_data(node_id)["expanded"]:
child_spaces = [sd.node_data(s)["space"] for s in sd.node_successors(node_id)]

# Fix everything in the NFVS to zero, as long as
# it isn't already fixed by our `node_space`.
Expand Down
6 changes: 3 additions & 3 deletions balm/_sd_algorithms/expand_attractor_seeds.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def expand_attractor_seeds(sd: SuccessionDiagram, size_limit: int | None = None)

# Retrieve the stable motifs of children that are already expanded.
expanded_children = [
x for x in sd.node_successors(node) if sd.node_is_expanded(x)
x for x in sd.node_successors(node) if sd.node_data(x)["expanded"]
]
expanded_motifs = [
sd.edge_stable_motif(node, child) for child in expanded_children
Expand All @@ -61,15 +61,15 @@ def expand_attractor_seeds(sd: SuccessionDiagram, size_limit: int | None = None)
# and continue to the next one.
successors.pop()
continue
if sd.node_is_expanded(successors[-1]):
if sd.node_data(successors[-1])["expanded"]:
# The next node to explore is expanded (by some previous procedure)
# but not "seen" in this search yet. We need to visit this node
# regardless of other conditions
break
# Now, we need to asses if the next successor has some candidate states which
# are not covered by the already expanded children.

successor_space = sd.node_space(successors[-1])
successor_space = sd.node_data(successors[-1])["space"]
retained_set = make_retained_set(
sd.symbolic, sd.node_nfvs(node), successor_space
)
Expand Down
4 changes: 2 additions & 2 deletions balm/_sd_algorithms/expand_minimal_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def expand_minimal_spaces(sd: SuccessionDiagram, size_limit: int | None = None)
successors = sorted(successors, reverse=True) # For determinism!
# (reversed because we explore the list from the back)

node_space = sd.node_space(node)
node_space = sd.node_data(node)["space"]

# Remove all immediate successors that are already visited or those who
# do not cover any new minimal trap space.
Expand All @@ -52,7 +52,7 @@ def expand_minimal_spaces(sd: SuccessionDiagram, size_limit: int | None = None)
# of this node is already in the succession diagram.
if len(successors) == 0:
if sd.node_is_minimal(node):
minimal_traps.remove(sd.node_space(node))
minimal_traps.remove(sd.node_data(node)["space"])
continue

# At this point, we know that `s` is not visited and it contains
Expand Down
2 changes: 1 addition & 1 deletion balm/_sd_algorithms/expand_to_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def expand_to_target(

while len(current_level) > 0:
for node in current_level:
node_space = sd.node_space(node)
node_space = sd.node_data(node)["space"]

if intersect(node_space, target) is None:
# If `node_space` does not intersect with `target`, it is not relevant
Expand Down
4 changes: 2 additions & 2 deletions balm/control.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
self._control: list[ControlOverrides] = []
for c in control:
cs = sorted(map(lambda x: sorted(x.items()), c))
self._control.append(list(map(dict, cs)))
self._control.append(list(map(dict, cs))) # type: ignore

self._strategy = strategy
self._succession = succession
Expand Down Expand Up @@ -335,7 +335,7 @@ def successions_to_target(
)

for s in succession_diagram.node_ids():
fixed_vars = succession_diagram.node_space(s)
fixed_vars = succession_diagram.node_data(s)["space"]
if not is_subspace(fixed_vars, target):
continue

Expand Down
92 changes: 50 additions & 42 deletions balm/succession_diagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
)
from balm.space_utils import percolate_space, space_unique_key
from balm.trappist_core import trappist
from balm.types import BooleanSpace, SuccessionDiagramState
from balm.types import BooleanSpace, NodeData, SuccessionDiagramState

# Enables helpful "progress" messages.
DEBUG = False
Expand Down Expand Up @@ -205,7 +205,7 @@ def summary(self) -> str:
except KeyError:
continue

space = self.node_space(node)
space = self.node_data(node)["space"]

if self.node_is_minimal(node):
space_str_prefix = "minimal trap space "
Expand Down Expand Up @@ -238,7 +238,7 @@ def depth(self) -> int:
"""
d = 0
for node in cast(set[int], self.dag.nodes()):
d = max(d, self.node_depth(int(node)))
d = max(d, self.node_data(int(node))["depth"])
return d

def node_ids(self) -> Iterator[int]:
Expand All @@ -253,15 +253,15 @@ def stub_ids(self) -> Iterator[int]:
Iterator over all node IDs that are currently not expanded.
"""
for i in range(len(self)):
if not self.node_is_expanded(i):
if not self.node_data(i)["expanded"]:
yield i

def expanded_ids(self) -> Iterator[int]:
"""
Iterator over all node IDs that are currently expanded.
"""
for i in range(len(self)):
if self.node_is_expanded(i):
if self.node_data(i)["expanded"]:
yield i

def minimal_trap_spaces(self) -> list[int]:
Expand Down Expand Up @@ -307,16 +307,16 @@ def is_subgraph(self, other: SuccessionDiagram) -> bool:
# Every stub node is reachable through an expanded node and
# thus will be checked by the following code.
for i in self.expanded_ids():
other_i = other.find_node(self.node_space(i))
other_i = other.find_node(self.node_data(i)["space"])
if other_i is None:
return False
my_successors = self.node_successors(i)
other_successors = []
if other.node_is_expanded(other_i):
if other.node_data(other_i)["expanded"]:
other_successors = other.node_successors(other_i)

for my_s in my_successors:
other_s = other.find_node(self.node_space(my_s))
other_s = other.find_node(self.node_data(my_s)["space"])
if other_s not in other_successors:
return False
return True
Expand All @@ -335,27 +335,32 @@ def is_isomorphic(self, other: SuccessionDiagram) -> bool:
"""
return self.is_subgraph(other) and other.is_subgraph(self)

def node_depth(self, node_id: int) -> int:
def node_data(self, node_id: int) -> NodeData:
"""
Get the depth associated with the provided `node_id`. The depth is counted
as the longest path from the root node to the given node.
Get the data associated with the provided `node_id`.
"""
return cast(int, self.dag.nodes[node_id]["depth"])
return cast(NodeData, self.dag.nodes[node_id])

def node_space(self, node_id: int) -> BooleanSpace:
"""
Get the sub-space associated with the provided `node_id`.
# def node_depth(self, node_id: int) -> int:
# """
# Get the depth associated with the provided `node_id`. The depth is counted
# as the longest path from the root node to the given node.
# """
# return cast(int, self.dag.nodes[node_id]["depth"])
# def node_space(self, node_id: int) -> BooleanSpace:
# """
# Get the sub-space associated with the provided `node_id`.

Note that this is the space *after* percolation. Hence it can hold that
`|node_space(child)| < |node_space(parent)| + |stable_motif(parent, child)|`.
"""
return cast(BooleanSpace, self.dag.nodes[node_id]["space"])
# Note that this is the space *after* percolation. Hence it can hold that
# `|node_space(child)| < |node_space(parent)| + |stable_motif(parent, child)|`.
# """
# return cast(BooleanSpace, self.dag.nodes[node_id]["space"])

def node_is_expanded(self, node_id: int) -> bool:
"""
True if the successors of the given node are already computed.
"""
return cast(bool, self.dag.nodes[node_id]["expanded"])
# def node_is_expanded(self, node_id: int) -> bool:
# """
# True if the successors of the given node are already computed.
# """
# return cast(bool, self.dag.nodes[node_id]["expanded"])

def node_is_minimal(self, node_id: int) -> bool:
"""
Expand Down Expand Up @@ -406,9 +411,9 @@ def node_attractor_seeds(
same attractor in multiple stub nodes, if the stub nodes intersect), and
(b) this data is erased if the stub node is expanded later on.
"""
node = cast(dict[str, Any], self.dag.nodes[node_id])
node = cast(NodeData, self.dag.nodes[node_id])

attractors = cast(list[BooleanSpace] | None, node["attractors"])
attractors = node["attractors"]

if attractors is None and not compute:
raise KeyError(f"Attractor data not computed for node {node_id}.")
Expand Down Expand Up @@ -440,17 +445,17 @@ def node_nfvs(self, node_id: int) -> list[str]:

return self.nfvs

def node_restricted_petri_net(self, node_id: int) -> nx.DiGraph | None:
"""
Return the pre-computed Petri net representation restricted to the subspace
of the specified SD node.
# def node_restricted_petri_net(self, node_id: int) -> nx.DiGraph | None:
# """
# Return the pre-computed Petri net representation restricted to the subspace
# of the specified SD node.

This can return `None` if the requested node is already fully expanded, because
in such a case, there is no need to store the Petri net anymore. However,
in general you should assume that this field is optional, even on nodes that
are not expanded yet.
"""
return cast(nx.DiGraph, self.dag.nodes[node_id]["petri_net"])
# This can return `None` if the requested node is already fully expanded, because
# in such a case, there is no need to store the Petri net anymore. However,
# in general you should assume that this field is optional, even on nodes that
# are not expanded yet.
# """
# return cast(nx.DiGraph, self.dag.nodes[node_id]["petri_net"])

def edge_stable_motif(
self, parent_id: int, child_id: int, reduced: bool = False
Expand All @@ -472,7 +477,7 @@ def edge_stable_motif(
{
k: v
for k, v in self.dag.edges[parent_id, child_id]["motif"].items() # type: ignore
if k not in self.node_space(parent_id)
if k not in self.node_data(parent_id)["space"]
},
)
else:
Expand Down Expand Up @@ -615,14 +620,14 @@ def _update_node_petri_net(self, node_id: int, parent_id: int | None):
such Petri net is always empty.
"""

node_space = self.node_space(node_id)
node_space = self.node_data(node_id)["space"]

if len(node_space) == self.network.variable_count():
# If fixed point, no need to compute.
return

if parent_id is not None:
parent_pn = self.node_restricted_petri_net(parent_id)
parent_pn = self.node_data(parent_id)["petri_net"]
if parent_pn is None:
pn = self.petri_net
else:
Expand Down Expand Up @@ -669,7 +674,9 @@ def _expand_one_node(self, node_id: int):
current_space = node["space"]

if DEBUG:
print(f"[{node_id}] Expanding: {len(self.node_space(node_id))} fixed vars.")
print(
f"[{node_id}] Expanding: {len(self.node_data(node_id)['space'])} fixed vars."
)

if len(current_space) == self.network.variable_count():
# This node is a fixed-point. Trappist would just
Expand All @@ -686,7 +693,7 @@ def _expand_one_node(self, node_id: int):
source_nodes = extract_source_variables(self.petri_net)

sub_spaces: list[BooleanSpace]
pn = self.node_restricted_petri_net(node_id)
pn = self.node_data(node_id)["petri_net"]
if pn is not None:
# We have a pre-propagated PN for this sub-space, hence we can use
# that to compute the trap spaces.
Expand Down Expand Up @@ -747,9 +754,10 @@ def _ensure_node(self, parent_id: int | None, stable_motif: BooleanSpace) -> int
child_id = None
if key not in self.node_indices:
child_id = self.dag.number_of_nodes()

# Note: this must match the fields of the `NodeData` class
self.dag.add_node( # type: ignore
child_id,
id=child_id, # In case we ever need it within the "node data" dictionary.
space=fixed_vars,
depth=0,
expanded=False,
Expand Down
8 changes: 8 additions & 0 deletions balm/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,11 @@ class SuccessionDiagramState(TypedDict):
nfvs: list[str] | None
dag: nx.DiGraph
node_indices: dict[int, int]


class NodeData(TypedDict):
depth: int
attractors: list[BooleanSpace] | None
petri_net: nx.DiGraph | None
space: BooleanSpace
expanded: bool
6 changes: 3 additions & 3 deletions tests/source_SCC_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_find_scc_sd():
)

for node in scc_sd.node_ids():
print(scc_sd.node_space(node))
print(scc_sd.node_data(node)["space"])

assert scc_sd.dag.nodes[0]["space"] == {}
assert scc_sd.dag.nodes[1]["space"] == {"A": 0, "B": 0}
Expand Down Expand Up @@ -375,8 +375,8 @@ def test_isomorph():
sd_scc = SuccessionDiagram(bn)
expand_source_SCCs(sd_scc)

assert [sd_bfs.node_space(id) for id in sd_bfs.node_ids()] == [
sd_scc.node_space(id) for id in sd_scc.node_ids()
assert [sd_bfs.node_data(id)["space"] for id in sd_bfs.node_ids()] == [
sd_scc.node_data(id)["space"] for id in sd_scc.node_ids()
]

assert sd_scc.is_isomorphic(sd_bfs)
Expand Down
6 changes: 3 additions & 3 deletions tests/succession_diagram_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_succession_diagram_structure(self):
assert (
max(
[
succession_diagram.node_depth(i)
succession_diagram.node_data(i)["depth"]
for i in succession_diagram.node_ids()
]
)
Expand Down Expand Up @@ -80,7 +80,7 @@ def test_succession_diagram_structure(self):
assert (
max(
[
succession_diagram.node_depth(i)
succession_diagram.node_data(i)["depth"]
for i in succession_diagram.node_ids()
]
)
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_expansion_comparisons(network_file: str):
# This should always create a succession_diagram with exactly one minimal trap space,
# as the rest
for min_trap in sd_bfs.minimal_trap_spaces():
space = sd_bfs.node_space(min_trap)
space = sd_bfs.node_data(min_trap)["space"]

sd_target = SuccessionDiagram(bn)
assert sd_target.expand_to_target(space, size_limit=NODE_LIMIT)
Expand Down

1 comment on commit 6edc026

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
balm
   control.py1141488%102, 114, 120, 124, 129, 138–154, 472, 475, 488
   interaction_graph_utils.py54591%6–8, 47, 165–166
   motif_avoidant.py152299%25, 121
   petri_net_translation.py1481193%18–22, 58, 94, 207–208, 232–233, 242, 346
   space_utils.py129497%25–27, 252, 278
   succession_diagram.py2411295%6, 163, 177, 185, 188, 205–206, 312, 444, 632, 670, 707
   symbolic_utils.py26388%10–12, 44
   trappist_core.py1833084%11–15, 45, 47, 82, 129, 195, 197, 199, 227–230, 234–236, 256–262, 320, 322, 352, 392, 394, 425, 454
balm/_sd_algorithms
   compute_attractor_seeds.py30197%8
   expand_attractor_seeds.py51590%6, 42, 97–102
   expand_bfs.py28196%6
   expand_dfs.py30197%6
   expand_minimal_spaces.py37295%6, 31
   expand_source_SCCs.py164696%19–21, 91, 101, 143, 287
   expand_to_target.py31390%6, 38, 43
TOTAL146910093% 

Tests Skipped Failures Errors Time
363 0 💤 0 ❌ 0 🔥 53.272s ⏱️

Please sign in to comment.