Skip to content

Commit

Permalink
load all rules at once
Browse files Browse the repository at this point in the history
  • Loading branch information
pauleve committed Jun 11, 2024
1 parent b0f9b2c commit f481591
Showing 1 changed file with 50 additions and 33 deletions.
83 changes: 50 additions & 33 deletions mpbn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,22 @@ def encode_dnf(f):
facts.append(circuitasp_of_boolfunc(f, n, self.ba))
return "".join(facts)

def load_eval(self, solver):
def _file_eval(self):
if self.encoding == "circuit":
solver.load(aspf("eval_circuit.asp"))
f = aspf("eval_circuit.asp")
elif self.encoding == "mixed-dnf-bdd":
solver.load(aspf("eval_mixed.asp"))
f = aspf("eval_mixed.asp")
else:
solver.load(aspf("mp_eval.asp"))
f = aspf("mp_eval.asp")
return f

def rules_eval(self):
f = self._file_eval()
with open(f, "r") as fp:
return fp.read()
def load_eval(self, solver):
f = self._file_eval()
solver.load(f)

def assert_pc_encoding(self):
assert self.encoding not in self.nonpc_encodings, "Unsupported encoding"
Expand All @@ -337,8 +346,9 @@ def asp_of_cfg(self, e, t, c):
facts = ["timepoint({},{}).".format(e,t)]
facts += [" mp_state({},{},\"{}\",{}).".format(e,t,n,s2v(s))
for (n,s) in c.items()]
facts += [" 1 {{mp_state({},{},\"{}\",(-1;1))}} 1.".format(e,t,n)
for n in self if n not in c]
facts += [f"1 {{mp_state({e},{t},N,(-1;1))}} 1 :- node(N)."]
#facts += [" 1 {{mp_state({},{},\"{}\",(-1;1))}} 1 :- node(N).".format(e,t,n)
# for n in self if n not in c]
return "".join(facts)

def reachability(self, x, y):
Expand Down Expand Up @@ -368,24 +378,30 @@ def reachability(self, x, y):
res = s.solve()
return res.satisfiable

def _ground_rules(self, ctl, rules):
rules = "\n".join(rules)
ctl.add("base", [], rules)
ctl.ground([("base",[])])

def _fixedpoints(self, reachable_from=None, constraints={}, limit=0):
project = reachable_from and set(self.keys()).difference(reachable_from)
s = clingo_enum(limit=limit, project=project)
s.add("base", [], self.asp_of_bn())
e = "fp"
t2 = "fp"
self.load_eval(s)
s.add("base", [], self.asp_of_cfg(e, t2, constraints))
s.add("base", [], f"mp_reach({e},{t2},N,V) :- mp_state({e},{t2},N,V).")
s.add("base", [], f":- mp_state({e},{t2},N,V), mp_eval({e},{t2},N,-V).")
rules = [self.asp_of_cfg(e, t2, constraints)]
rules.append(f"mp_reach({e},{t2},N,V) :- mp_state({e},{t2},N,V).")
rules.append(f":- mp_state({e},{t2},N,V), mp_eval({e},{t2},N,-V).")
rules.append(self.asp_of_bn())
if reachable_from:
self.assert_pc_encoding()
t1 = "0"
s.load(aspf("mp_positivereach-np.asp"))
s.add("base", [], self.asp_of_cfg(e,t1,reachable_from))
s.add("base", [], "is_reachable({},{},{}).".format(e,t1,t2))
s.add("base", [], f"#show. #show fixpoint(N,V) : mp_state({e},{t2},N,V).")
s.ground([("base",[])])
rules.append(open(aspf("mp_positivereach-np.asp")).read())
rules.append(self.asp_of_cfg(e,t1,reachable_from))
rules.append("is_reachable({},{},{}).".format(e,t1,t2))
rules.append(f"#show. #show fixpoint(N,V) : mp_state({e},{t2},N,V).")
rules.append(open(aspf("mp_eval.asp")).read())

project = reachable_from and set(self.keys()).difference(reachable_from)
s = clingo_enum(limit=limit, project=project)
self._ground_rules(s, rules)
return s

def fixedpoints(self, reachable_from=None, constraints={}, limit=0):
Expand Down Expand Up @@ -434,35 +450,36 @@ def count_fixedpoints(self, reachable_from=None, constraints={}, limit=0):

def _trapspaces(self, reachable_from=None, subcube={}, limit=0,
mode="min", exclude_full=False):
self.assert_pc_encoding()

project = reachable_from and set(self.keys()).difference(reachable_from)
rules = []
rules.append(self.asp_of_bn())
rules.append(self.rules_eval())
rules.append(open(aspf("mp_attractor.asp")).read())
rules.append("#show attractor/2.")

self.assert_pc_encoding()
solver = clingo_subsets if mode == "min" else clingo_supsets
s = solver(limit=limit, project=project)
self.load_eval(s)
s.load(aspf("mp_attractor.asp"))
s.add("base", [], self.asp_of_bn())
e = "__a"
t2 = "final"
if exclude_full and not subcube:
s.add("base", [], f"{{ mp_reach({e},{t2},N,(-1;1)): node(N) }} {len(self)*2-1}.")
rules.append(f"{{ mp_reach({e},{t2},N,(-1;1)): node(N) }} {len(self)*2-1}.")
if reachable_from:
t1 = "0"
s.load(aspf("mp_positivereach-np.asp"))
s.add("base", [], self.asp_of_cfg(e,t1,reachable_from))
s.add("base", [], "is_reachable({},{},{}).".format(e,t1,t2))
s.add("base", [], "mp_state({},{},N,V) :- attractor(N,V).".format(e,t2))
rules.append(open(aspf("mp_positivereach-np.asp")).read())
rules.append(self.asp_of_cfg(e,t1,reachable_from))
rules.append("is_reachable({},{},{}).".format(e,t1,t2))
rules.append("mp_state({},{},N,V) :- attractor(N,V).".format(e,t2))

for n, b in subcube.items():
if isinstance(b, str):
b = int(b)
if b not in [0,1]:
continue
s.add("base", [], ":- mp_reach({},{},\"{}\",{}).".format(e,t2,n,s2v(1-b)))
rules.append(":- mp_reach({},{},\"{}\",{}).".format(e,t2,n,s2v(1-b)))

s.add("base", [], "#show attractor/2.")
s.ground([("base",[])])
project = reachable_from and set(self.keys()).difference(reachable_from)
solver = clingo_subsets if mode == "min" else clingo_supsets
s = solver(limit=limit, project=project)
self._ground_rules(s, rules)
return s

def _yield_trapspaces(self, *args, star="*", **kwargs):
Expand Down

0 comments on commit f481591

Please sign in to comment.