From 16890822b68308cdfc83688861d864eb8ef8b3cb Mon Sep 17 00:00:00 2001 From: Samuel Pastva Date: Wed, 4 Sep 2024 22:04:39 +0200 Subject: [PATCH] Alternative two-way reachability method. --- biobalm/_sd_attractors/attractor_symbolic.py | 154 +++++++++++++++---- 1 file changed, 125 insertions(+), 29 deletions(-) diff --git a/biobalm/_sd_attractors/attractor_symbolic.py b/biobalm/_sd_attractors/attractor_symbolic.py index a8754d2..a8d1e2c 100644 --- a/biobalm/_sd_attractors/attractor_symbolic.py +++ b/biobalm/_sd_attractors/attractor_symbolic.py @@ -16,6 +16,7 @@ from biobalm.symbolic_utils import state_list_to_bdd from biobalm.types import BooleanSpace import biodivine_aeon +import copy def symbolic_attractor_fallback( @@ -262,6 +263,7 @@ def symbolic_attractor_test( if not incompatible.is_false(): conflict_vars.append(var) conflict_vars = sort_variable_list(conflict_vars) + all_conflict_vars = copy.copy(conflict_vars) # Remaining network variables that are still relevant, but may not # be necessary to reach `avoid`. @@ -275,11 +277,29 @@ def symbolic_attractor_test( f"[{node_id}] > Start symbolic reachability with {len(conflict_vars)} conflict variables and {len(other_vars)} other variables." ) + # This algorithm has two modes of operation: If `avoid` is `None`, + # then the only thing that we can do is compute forward reachability. + # However, if we have any avoid states, we can run two reachability + # operations that interleave each other: One forward from the + # pivot vertex, the other backward from the `avoid` set. If these two + # sets ever intersect, we know that there is a path from pivot into the + # `avoid` set. To further speed up the computation, we first prioritize + # the variables in which the pivot differs from the avoid set (so called + # conflict variables). If all such variables are covered, we then consider + # those that are close to the conflict variables in terms of regulations + # (i.e. if we can't update a conflict variable, we ideally want to update + # one that regulates it). + + # True if the main cycle completes with all saturation procedures fully + # completed and no unprocessed variables remaining. all_done = False + while not all_done: all_done = True - # Saturate reach set with currently selected variables. + # Saturate reach_set with currently selected variables, but only if + # it's symbolic size is smaller than that of the avoid set (reach set + # tends to grow quite large and we'd like to avoid that). saturation_done = False while not saturation_done: if avoid is not None and not avoid.intersect(reach_set).is_empty(): @@ -291,41 +311,117 @@ def symbolic_attractor_test( for var in saturated_vars: successors = graph.var_post_out(var, reach_set) if not successors.is_empty(): - reach_set = reach_set.union(successors) - saturation_done = False - if reach_set.symbolic_size() > 100_000 and sd.config["debug"]: - print( - f"[{node_id}] > Saturation({len(saturated_vars)}) Expanded reach_set: {reach_set}" - ) - break + all_done = False # The main loop should continue. + updated = reach_set.union(successors) + no_avoid = avoid is None + avoid_is_larger = ( + avoid is not None + ) and avoid.symbolic_size() >= updated.symbolic_size() + all_variables_done = ( + len(conflict_vars) == 0 and len(other_vars) == 0 + ) + if no_avoid or avoid_is_larger or all_variables_done: + reach_set = updated + saturation_done = False + if reach_set.symbolic_size() > 100_000 and sd.config["debug"]: + print( + f"[{node_id}] > Saturation({len(saturated_vars)}) Incremented forward reach set: {reach_set}" + ) + break + + if avoid is not None: + # If `avoid` is not `None`, we also want to expand it backwards. + saturation_done = False + while not saturation_done: + if not avoid.intersect(reach_set).is_empty(): + if sd.config["debug"]: + print(f"[{node_id}] > Discovered avoid state. Done.") + return None + + saturation_done = True + for var in saturated_vars: + predecessors = graph.var_pre_out(var, avoid) + if not predecessors.is_empty(): + all_done = False + avoid = avoid.union(predecessors) + saturation_done = False + if avoid.symbolic_size() > 100_000 and sd.config["debug"]: + print( + f"[{node_id}] > Saturation({len(saturated_vars)}) Incremented backward avoid set: {avoid}" + ) + break + + if sd.config["debug"] and ( + reach_set.symbolic_size() > 100_000 + or (avoid is not None and avoid.symbolic_size() > 100_000) + ): + print( + f"[{node_id}] > Saturation({len(saturated_vars)}) Finished with avoid set {avoid} and reach set {reach_set}." + ) # Once saturated, try to expand the saturated # collection with either a conflict variable or # other variable. - # First try conflict vars, then other vars. - for var in conflict_vars + other_vars: - successors = graph.var_post_out(var, reach_set) - if not successors.is_empty(): - reach_set = reach_set.union(successors) - all_done = False - - # This is a bit wasteful but at this point it - # should be irrelevant for performance. - if var in conflict_vars: - conflict_vars.remove(var) - - if var in other_vars: - other_vars.remove(var) + # First, we sort the non-conflict variables based on their + # distance towards the conflict variables. + network = sd.node_percolated_network(node_id) + distances: dict[VariableId, int] = { + v: network.variable_count() for v in network.variables() + } + visited: set[VariableId] + if len(conflict_vars) > 0: + visited = set(conflict_vars) + current_level = set(conflict_vars) + else: + visited = set(all_conflict_vars) + current_level = set(all_conflict_vars) + + next_level: set[VariableId] = set() + + distance = 0 + while len(current_level) > 0: + for var in current_level: + distances[var] = min(distance, distances[var]) + for s in network.predecessors(var): + if s not in visited: + visited.add(s) + next_level.add(s) + current_level = next_level + next_level = set() + distance += 1 + + other_sorted = sorted(other_vars, key=lambda x: distances[x]) + for var in conflict_vars + other_sorted: + can_go_fwd = graph.var_post_out(var, reach_set) + if avoid is not None: + can_go_bwd = graph.var_pre_out(var, avoid) + else: + can_go_bwd = graph.mk_empty_colored_vertices() + + if can_go_fwd.is_empty() and can_go_bwd.is_empty(): + continue + + all_done = False + + reach_set = reach_set.union(can_go_fwd) + if avoid is not None: + avoid = avoid.union(can_go_bwd) + + if var in conflict_vars: + conflict_vars.remove(var) + if var in other_vars: + other_vars.remove(var) + + saturated_vars.append(var) + saturated_vars = sort_variable_list(saturated_vars) - saturated_vars.append(var) - saturated_vars = sort_variable_list(saturated_vars) + if sd.config["debug"]: + print( + f"[{node_id}] > Saturation({len(saturated_vars)}) Added saturation variable. {len(conflict_vars)} conflict and {len(other_vars)} other variables remaining." + ) - if sd.config["debug"]: - print( - f"[{node_id}] > Saturation({len(saturated_vars)}) Added saturation variable. {len(conflict_vars)} conflict and {len(other_vars)} other variables remaining." - ) - break + break if sd.config["debug"]: print(f"[{node_id}] > Reachability completed with {reach_set}.")