Skip to content

Commit

Permalink
Merge branch '20200229-fsaad-repeat'
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed Mar 2, 2020
2 parents d02f8ab + ba549ba commit df27b48
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 18 deletions.
33 changes: 24 additions & 9 deletions SEMANTICS
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ Domain Variables
r ∈ Reals
w ∈ [0,1]
s ∈ Strings
x ∈ Vars
x ∈ Var
y ∈ ArrVar
z ∈ Vars = x | y[n]

Transform Data Type (Semantic Domain)

Expand Down Expand Up @@ -98,10 +100,13 @@ Modeling Language (Syntactic Domain)
| w₁*d₁ + w₂*d₂ [MixtureDist]

c ∈ Program
= x ~ d [Sample]
| x ← t [Assign]
| if ϕ then c₁ else c₂ [If-Else]
= y ← array(n) [DeclareArr]
| z ~ d [Sample]
| z ← t [Assign]
| c₁ ; c₂ [Sequence]
| if ϕ then c₁ else c₂ [If-Else]
| λx. c [Lambda]
| repeat n n' f [Repeat]

Small Step Semantics

Expand All @@ -123,17 +128,27 @@ Small Step Semantics
ϕ₁ or ϕ₂ -> (ϕ₁[σ] or ϕ₂[σ])
~ϕ -> ~ϕ[σ]


[Sample] ----------------------------------
⟨(s, σ), x ~ d⟩ → (⊗ s (ℓ x d), σ)

[Assign] -----------------------------
⟨(s, σ), x ← t⟩ → (s, [x\t]σ)

⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂)
[IfElse] -----------------------------------------------------------
⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])] (s₁, σ₁))
(1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂))

⟨(s, σ), c₁ → (s₁, σ₁) ⟨(s₁, σ₁), c₂ → (s₂, σ₂)
[Sequence] ------------------------------------------------
⟨(s, σ), c₁;c₂⟩ → (s₂, σ₂)

⟨(s|ϕ[σ], σ), c₁⟩ → (s₁, σ₁) ⟨(s|~ϕ[σ], σ), c₂⟩ → (s₂, σ₂)
[IfElse] -----------------------------------------------------------
⟨(s, σ), if ϕ then c₁ else c₂⟩ → ⊕ (𝐏⟦s⟧(ϕ[σ])) (s₁, σ₁)
(1-𝐏⟦s⟧(~ϕ[σ])) (s₂, σ₂)


[Repeat-False]
-------------------------------- (where n' <= n)
⟨(s, σ), repeat n n' f⟩ -> (s, σ)

[Repeat-True]
------------------------------------------------------------ (where n < n')
⟨(s, σ), repeat n λx.c⟩ -> ⟨(s, σ), ( (λx.c)(n); repeat (n+1) n' λx.c)⟩
30 changes: 21 additions & 9 deletions src/interpret.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@

from .transforms import Identity

Start = None
Otherwise = True

class Variable(Identity):
def __rshift__(self, f):
if isinstance(f, Callable):
Expand All @@ -26,6 +23,9 @@ def __hash__(self):
x = (symbol.__class__, self.token)
return hash(x)

def VariableArray(token, n):
return [Variable('%s[%d]' % (token, i,)) for i in range(n)]

class Command():
def interpret(self, spn):
raise NotImplementedError()
Expand All @@ -43,6 +43,12 @@ def __rand__(self, x):
return self.interpret(x)
return NotImplemented

class Skip(Command):
def __init__(self):
pass
def interpret(self, spn=None):
return spn

class Sample(Command):
def __init__(self, symbol, distribution):
self.symbol = symbol
Expand Down Expand Up @@ -76,13 +82,15 @@ def interpret(self, spn=None):
# Return the overall sum.
return SumSPN(children, weights)

class Skip(Command):
def __init__(self):
pass
class Repeat(Command):
def __init__(self, n0, n1, f):
self.n0 = n0
self.n1 = n1
self.f = f
def interpret(self, spn=None):
return spn

Cond = IfElse
commands = [self.f(i) for i in range(self.n0, self.n1)]
sequence = Sequence(*commands)
return sequence.interpret(spn)

class Sequence(Command):
def __init__(self, *commands):
Expand All @@ -97,3 +105,7 @@ def __and__(self, x):
commands = self.commands + (x,)
return Sequence(*commands)
return NotImplemented

Start = None
Otherwise = True
Cond = IfElse
50 changes: 50 additions & 0 deletions tests/test_repeat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2020 MIT Probabilistic Computing Project.
# See LICENSE.txt

from math import log

from spn.distributions import Bernoulli
from spn.distributions import NominalDist
from spn.interpret import Cond
from spn.interpret import Otherwise
from spn.interpret import Repeat
from spn.interpret import Start
from spn.interpret import Variable
from spn.interpret import VariableArray
from spn.math_util import allclose

Y = Variable('Y')
X = VariableArray('X', 5)
Z = VariableArray('Z', 5)

def test_simple_model():
model = (Start
& Y >> Bernoulli(p=0.5)
& Repeat(0, 5, lambda i:
X[i] >> Bernoulli(p=1/(i+1))))

symbols = model.get_symbols()
assert len(symbols) == 6
assert Y in symbols
assert X[0] in symbols
assert X[1] in symbols
assert X[2] in symbols
assert X[3] in symbols
assert X[4] in symbols
assert model.logprob(X[0] << {1}) == log(1/1)
assert model.logprob(X[1] << {1}) == log(1/2)
assert model.logprob(X[2] << {1}) == log(1/3)
assert model.logprob(X[3] << {1}) == log(1/4)
assert model.logprob(X[4] << {1}) == log(1/5)

def test_complex_model():
# Slow for larger number of repetitions
# https://github.com/probcomp/sum-product-dsl/issues/43
model = (Start
& Y >> NominalDist({'0': .2, '1': .2, '2': .2, '3': .2, '4': .2})
& Repeat(0, 3, lambda i:
Z[i] >> Bernoulli(p=0.1)
& Cond (
Y << {str(i)} | Z[i] << {0}, X[i] >> Bernoulli(p=1/(i+1)),
Otherwise, X[i] >> Bernoulli(p=0.1))))
assert allclose(model.prob(Y << {'0'}), 0.2)

0 comments on commit df27b48

Please sign in to comment.