Skip to content

Commit

Permalink
Implement repeat with minimal test (see notes).
Browse files Browse the repository at this point in the history
1. SEMANTICS need to be better to account for \lambda now.
2. The optimization in
    #43
   is especially important now.
  • Loading branch information
Feras A. Saad committed Mar 2, 2020
1 parent f488a3c commit ba549ba
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 9 deletions.
33 changes: 24 additions & 9 deletions SEMANTICS
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ Domain Variables
r ∈ Reals
w ∈ [0,1]
s ∈ Strings
x ∈ Vars
x ∈ Var
y ∈ ArrVar
z ∈ Vars = x | y[n]

Transform Data Type (Semantic Domain)

Expand Down Expand Up @@ -98,10 +100,13 @@ Modeling Language (Syntactic Domain)
| w₁*d₁ + w₂*d₂ [MixtureDist]

c ∈ Program
= x ~ d [Sample]
| x ← t [Assign]
| if ϕ then c₁ else c₂ [If-Else]
= y ← array(n) [DeclareArr]
| z ~ d [Sample]
| z ← t [Assign]
| c₁ ; c₂ [Sequence]
| if ϕ then c₁ else c₂ [If-Else]
| λx. c [Lambda]
| repeat n n' f [Repeat]

Small Step Semantics

Expand All @@ -123,17 +128,27 @@ Small Step Semantics
ϕ₁ or ϕ₂ -> (ϕ₁[σ] or ϕ₂[σ])
~ϕ -> ~ϕ[σ]


[Sample] ----------------------------------
⟨(s, σ), x ~ d⟩ → (⊗ s (ℓ x d), σ)

[Assign] -----------------------------
⟨(s, σ), x ← t⟩ → (s, [x\t]σ)

⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂)
[IfElse] -----------------------------------------------------------
⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])] (s₁, σ₁))
(1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂))

⟨(s, σ), c₁ → (s₁, σ₁) ⟨(s₁, σ₁), c₂ → (s₂, σ₂)
[Sequence] ------------------------------------------------
⟨(s, σ), c₁;c₂⟩ → (s₂, σ₂)

⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂)
[IfElse] -----------------------------------------------------------
⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])) (s₁, σ₁)
(1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂)


[Repeat-False]
-------------------------------- (where n' <= n)
⟨(s, σ), repeat n n' f⟩ -> (s, σ)

[Repeat-True]
------------------------------------------------------------ (where n < n')
⟨(s, σ), repeat n λx.c⟩ -> ⟨(s, σ), ( (λx.c)(n); repeat (n+1) n' λx.c)⟩
13 changes: 13 additions & 0 deletions src/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def __hash__(self):
x = (symbol.__class__, self.token)
return hash(x)

def VariableArray(token, n):
return [Variable('%s[%d]' % (token, i,)) for i in range(n)]

class Command():
def interpret(self, spn):
raise NotImplementedError()
Expand Down Expand Up @@ -79,6 +82,16 @@ def interpret(self, spn=None):
# Return the overall sum.
return SumSPN(children, weights)

class Repeat(Command):
def __init__(self, n0, n1, f):
self.n0 = n0
self.n1 = n1
self.f = f
def interpret(self, spn=None):
commands = [self.f(i) for i in range(self.n0, self.n1)]
sequence = Sequence(*commands)
return sequence.interpret(spn)

class Sequence(Command):
def __init__(self, *commands):
self.commands = commands
Expand Down
50 changes: 50 additions & 0 deletions tests/test_repeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from math import log

from spn.distributions import Bernoulli
from spn.distributions import NominalDist
from spn.interpret import Cond
from spn.interpret import Otherwise
from spn.interpret import Repeat
from spn.interpret import Start
from spn.interpret import Variable
from spn.interpret import VariableArray
from spn.math_util import allclose

Y = Variable('Y')
X = VariableArray('X', 5)
Z = VariableArray('Z', 5)

def test_simple_model():
model = (Start
& Y >> Bernoulli(p=0.5)
& Repeat(0, 5, lambda i:
X[i] >> Bernoulli(p=1/(i+1))))

symbols = model.get_symbols()
assert len(symbols) == 6
assert Y in symbols
assert X[0] in symbols
assert X[1] in symbols
assert X[2] in symbols
assert X[3] in symbols
assert X[4] in symbols
assert model.logprob(X[0] << {1}) == log(1/1)
assert model.logprob(X[1] << {1}) == log(1/2)
assert model.logprob(X[2] << {1}) == log(1/3)
assert model.logprob(X[3] << {1}) == log(1/4)
assert model.logprob(X[4] << {1}) == log(1/5)

def test_complex_model():
# Slow for larger number of repetitions
# https://github.com/probcomp/sum-product-dsl/issues/43
model = (Start
& Y >> NominalDist({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})
& Repeat(0, 3, lambda i:
Z[i] >> Bernoulli(p=0.1)
& Cond (
Y << {str(i)} | Z[i] << {0}, X[i] >> Bernoulli(p=1/(i+1)),
Otherwise, X[i] >> Bernoulli(p=0.1))))
assert allclose(model.prob(Y << {'0'}), 0.2)

0 comments on commit ba549ba

Please sign in to comment.