Skip to content

Commit

Permalink
Remove logpdf from internal nodes (closes #44).
Browse files Browse the repository at this point in the history
  • Loading branch information
Feras A. Saad committed Mar 2, 2020
1 parent df27b48 commit 929a27c
Showing 1 changed file with 2 additions and 15 deletions.
17 changes: 2 additions & 15 deletions src/spn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def logprob(self, event):
def prob(self, event):
lp = self.logprob(event)
return exp(lp)
def logpdf(self, x):
raise NotImplementedError()
def condition(self, event):
raise NotImplementedError()
def mutual_information(self, A, B):
Expand Down Expand Up @@ -164,10 +162,6 @@ def logprob(self, event):
logps = [spn.logprob(event_dnf) for spn in self.children]
return logsumexp([p + w for (p, w) in zip(logps, self.weights)])

def logpdf(self, x):
logps = [spn.logpdf(x) for spn in self.children]
return logsumexp([p + w for (p, w) in zip(logps, self.weights)])

def condition(self, event):
logps_condt = [spn.logprob(event) for spn in self.children]
indexes = [i for i, lp in enumerate(logps_condt) if not isinf_neg(lp)]
Expand Down Expand Up @@ -287,12 +281,6 @@ def sample_func(self, func, N, rng):
samples = self.sample_subset(symbols, N, rng)
return func_evaluate(self, func, samples)

def logpdf(self, x):
assert len(x) == len(self.children)
logps = [spn.logpdf(v) for (spn, v) in zip(self.children, x)
if x is not None]
return logsumexp(logps)

def logprob(self, event):
event_dnf = dnf_normalize(event)
if event_dnf is None:
Expand Down Expand Up @@ -390,6 +378,8 @@ def sample_subset(self, symbols, N, rng):
def sample_func(self, func, N, rng):
samples = self.sample(N, rng)
return func_evaluate(self, func, samples)
def logpdf(self, x):
raise NotImplementedError()

# ==============================================================================
# RealDistribution base class.
Expand Down Expand Up @@ -443,9 +433,6 @@ def logcdf(self, x):
p = logdiffexp(self.dist.logcdf(x), self.logFl)
return p - self.logZ

def logpdf(self, x):
raise NotImplementedError()

def logprob_values(self, values):
if values is EmptySet:
return -inf
Expand Down

0 comments on commit 929a27c

Please sign in to comment.