diff --git a/SEMANTICS b/SEMANTICS index 640bd66..2a53799 100644 --- a/SEMANTICS +++ b/SEMANTICS @@ -7,7 +7,9 @@ Domain Variables r ∈ Reals w ∈ [0,1] s ∈ Strings - x ∈ Vars + x ∈ Var + y ∈ ArrVar + z ∈ Vars = x | y[n] Transform Data Type (Semantic Domain) @@ -98,10 +100,13 @@ Modeling Language (Syntactic Domain) | w₁*d₁ + w₂*d₂ [MixtureDist] c ∈ Program - = x ~ d [Sample] - | x ← t [Assign] - | if ϕ then c₁ else c₂ [If-Else] + = y ← array(n) [DeclareArr] + | z ~ d [Sample] + | z ← t [Assign] | c₁ ; c₂ [Sequence] + | if ϕ then c₁ else c₂ [If-Else] + | λx. c [Lambda] + | repeat n n' f [Repeat] Small Step Semantics @@ -123,17 +128,27 @@ Small Step Semantics ϕ₁ or ϕ₂ -> (ϕ₁[σ] or ϕ₂[σ]) ~ϕ -> ~ϕ[σ] + [Sample] ---------------------------------- ⟨(s, σ), x ~ d⟩ → (⊗ s (ℓ x d), σ) [Assign] ----------------------------- ⟨(s, σ), x ← t⟩ → (s, [x\t]σ) - ⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂) -[IfElse] ----------------------------------------------------------- - ⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])] (s₁, σ₁)) - (1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂)) - ⟨(s, σ), c₁ → (s₁, σ₁) ⟨(s₁, σ₁), c₂ → (s₂, σ₂) [Sequence] ------------------------------------------------ ⟨(s, σ), c₁;c₂⟩ → (s₂, σ₂) + + ⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂) +[IfElse] ----------------------------------------------------------- + ⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])) (s₁, σ₁) + (1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂) + + +[Repeat-False] + -------------------------------- (where n' <= n) + ⟨(s, σ), repeat n n' f⟩ -> (s, σ) + +[Repeat-True] + ------------------------------------------------------------ (where n < n') + ⟨(s, σ), repeat n λx.c⟩ -> ⟨(s, σ), ( (λx.c)(n); repeat (n+1) n' λx.c)⟩ diff --git a/src/interpret.py b/src/interpret.py index 285b5fc..d2a1779 100644 --- a/src/interpret.py +++ b/src/interpret.py @@ -12,9 +12,6 @@ from .transforms import Identity -Start = None -Otherwise = True - class Variable(Identity): def __rshift__(self, f): if isinstance(f, Callable): @@ -26,6 +23,9 @@ def __hash__(self): x = (symbol.__class__, self.token) return hash(x) +def VariableArray(token, n): + return [Variable('%s[%d]' % (token, i,)) for i in range(n)] + class Command(): def interpret(self, spn): raise NotImplementedError() @@ -43,6 +43,12 @@ def __rand__(self, x): return self.interpret(x) return NotImplemented +class Skip(Command): + def __init__(self): + pass + def interpret(self, spn=None): + return spn + class Sample(Command): def __init__(self, symbol, distribution): self.symbol = symbol @@ -76,13 +82,15 @@ def interpret(self, spn=None): # Return the overall sum. return SumSPN(children, weights) -class Skip(Command): - def __init__(self): - pass +class Repeat(Command): + def __init__(self, n0, n1, f): + self.n0 = n0 + self.n1 = n1 + self.f = f def interpret(self, spn=None): - return spn - -Cond = IfElse + commands = [self.f(i) for i in range(self.n0, self.n1)] + sequence = Sequence(*commands) + return sequence.interpret(spn) class Sequence(Command): def __init__(self, *commands): @@ -97,3 +105,7 @@ def __and__(self, x): commands = self.commands + (x,) return Sequence(*commands) return NotImplemented + +Start = None +Otherwise = True +Cond = IfElse diff --git a/tests/test_repeat.py b/tests/test_repeat.py new file mode 100644 index 0000000..0f26633 --- /dev/null +++ b/tests/test_repeat.py @@ -0,0 +1,50 @@ +# Copyright 2020 MIT Probabilistic Computing Project. +# See LICENSE.txt + +from math import log + +from spn.distributions import Bernoulli +from spn.distributions import NominalDist +from spn.interpret import Cond +from spn.interpret import Otherwise +from spn.interpret import Repeat +from spn.interpret import Start +from spn.interpret import Variable +from spn.interpret import VariableArray +from spn.math_util import allclose + +Y = Variable('Y') +X = VariableArray('X', 5) +Z = VariableArray('Z', 5) + +def test_simple_model(): + model = (Start + & Y >> Bernoulli(p=0.5) + & Repeat(0, 5, lambda i: + X[i] >> Bernoulli(p=1/(i+1)))) + + symbols = model.get_symbols() + assert len(symbols) == 6 + assert Y in symbols + assert X[0] in symbols + assert X[1] in symbols + assert X[2] in symbols + assert X[3] in symbols + assert X[4] in symbols + assert model.logprob(X[0] << {1}) == log(1/1) + assert model.logprob(X[1] << {1}) == log(1/2) + assert model.logprob(X[2] << {1}) == log(1/3) + assert model.logprob(X[3] << {1}) == log(1/4) + assert model.logprob(X[4] << {1}) == log(1/5) + +def test_complex_model(): + # Slow for larger number of repetitions + # https://github.com/probcomp/sum-product-dsl/issues/43 + model = (Start + & Y >> NominalDist({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2}) + & Repeat(0, 3, lambda i: + Z[i] >> Bernoulli(p=0.1) + & Cond ( + Y << {str(i)} | Z[i] << {0}, X[i] >> Bernoulli(p=1/(i+1)), + Otherwise, X[i] >> Bernoulli(p=0.1)))) + assert allclose(model.prob(Y << {'0'}), 0.2)