diff --git a/SEMANTICS b/SEMANTICS index 640bd66..2a53799 100644 --- a/SEMANTICS +++ b/SEMANTICS @@ -7,7 +7,9 @@ Domain Variables r ∈ Reals w ∈ [0,1] s ∈ Strings - x ∈ Vars + x ∈ Var + y ∈ ArrVar + z ∈ Vars = x | y[n] Transform Data Type (Semantic Domain) @@ -98,10 +100,13 @@ Modeling Language (Syntactic Domain) | w₁*d₁ + w₂*d₂ [MixtureDist] c ∈ Program - = x ~ d [Sample] - | x ← t [Assign] - | if ϕ then c₁ else c₂ [If-Else] + = y ← array(n) [DeclareArr] + | z ~ d [Sample] + | z ← t [Assign] | c₁ ; c₂ [Sequence] + | if ϕ then c₁ else c₂ [If-Else] + | λx. c [Lambda] + | repeat n n' f [Repeat] Small Step Semantics @@ -123,17 +128,27 @@ Small Step Semantics ϕ₁ or ϕ₂ -> (ϕ₁[σ] or ϕ₂[σ]) ~ϕ -> ~ϕ[σ] + [Sample] ---------------------------------- ⟨(s, σ), x ~ d⟩ → (⊗ s (ℓ x d), σ) [Assign] ----------------------------- ⟨(s, σ), x ← t⟩ → (s, [x\t]σ) - ⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂) -[IfElse] ----------------------------------------------------------- - ⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])] (s₁, σ₁)) - (1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂)) - ⟨(s, σ), c₁ → (s₁, σ₁) ⟨(s₁, σ₁), c₂ → (s₂, σ₂) [Sequence] ------------------------------------------------ ⟨(s, σ), c₁;c₂⟩ → (s₂, σ₂) + + ⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂) +[IfElse] ----------------------------------------------------------- + ⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])) (s₁, σ₁) + (1-𝐏⟦s⟧(ϕ[σ])) (s₂, σ₂) + + +[Repeat-False] + -------------------------------- (where n' <= n) + ⟨(s, σ), repeat n n' f⟩ -> (s, σ) + +[Repeat-True] + ------------------------------------------------------------ (where n < n') + ⟨(s, σ), repeat n n' λx.c⟩ -> ⟨(s, σ), ( (λx.c)(n); repeat (n+1) n' λx.c)⟩ diff --git a/src/interpret.py b/src/interpret.py index b5d135c..d2a1779 100644 --- a/src/interpret.py +++ b/src/interpret.py @@ -23,6 +23,9 @@ def __hash__(self): x = (symbol.__class__, self.token) return hash(x) +def VariableArray(token, n): + return [Variable('%s[%d]' % (token, i,)) for i in 
range(n)] + class Command(): def interpret(self, spn): raise NotImplementedError() @@ -79,6 +82,16 @@ def interpret(self, spn=None): # Return the overall sum. return SumSPN(children, weights) +class Repeat(Command): + def __init__(self, n0, n1, f): + self.n0 = n0 + self.n1 = n1 + self.f = f + def interpret(self, spn=None): + commands = [self.f(i) for i in range(self.n0, self.n1)] + sequence = Sequence(*commands) + return sequence.interpret(spn) + class Sequence(Command): def __init__(self, *commands): self.commands = commands diff --git a/tests/test_repeat.py b/tests/test_repeat.py new file mode 100644 index 0000000..0f26633 --- /dev/null +++ b/tests/test_repeat.py @@ -0,0 +1,50 @@ +# Copyright 2020 MIT Probabilistic Computing Project. +# See LICENSE.txt + +from math import log + +from spn.distributions import Bernoulli +from spn.distributions import NominalDist +from spn.interpret import Cond +from spn.interpret import Otherwise +from spn.interpret import Repeat +from spn.interpret import Start +from spn.interpret import Variable +from spn.interpret import VariableArray +from spn.math_util import allclose + +Y = Variable('Y') +X = VariableArray('X', 5) +Z = VariableArray('Z', 5) + +def test_simple_model(): + model = (Start + & Y >> Bernoulli(p=0.5) + & Repeat(0, 5, lambda i: + X[i] >> Bernoulli(p=1/(i+1)))) + + symbols = model.get_symbols() + assert len(symbols) == 6 + assert Y in symbols + assert X[0] in symbols + assert X[1] in symbols + assert X[2] in symbols + assert X[3] in symbols + assert X[4] in symbols + assert model.logprob(X[0] << {1}) == log(1/1) + assert model.logprob(X[1] << {1}) == log(1/2) + assert model.logprob(X[2] << {1}) == log(1/3) + assert model.logprob(X[3] << {1}) == log(1/4) + assert model.logprob(X[4] << {1}) == log(1/5) + +def test_complex_model(): + # Slow for larger number of repetitions + # https://github.com/probcomp/sum-product-dsl/issues/43 + model = (Start + & Y >> NominalDist({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2}) + & 
Repeat(0, 3, lambda i: + Z[i] >> Bernoulli(p=0.1) + & Cond ( + Y << {str(i)} | Z[i] << {0}, X[i] >> Bernoulli(p=1/(i+1)), + Otherwise, X[i] >> Bernoulli(p=0.1)))) + assert allclose(model.prob(Y << {'0'}), 0.2)