diff --git a/src/spn.py b/src/spn.py index 78e7d9e..1a0e877 100644 --- a/src/spn.py +++ b/src/spn.py @@ -63,8 +63,6 @@ def logprob(self, event): def prob(self, event): lp = self.logprob(event) return exp(lp) - def logpdf(self, x): - raise NotImplementedError() def condition(self, event): raise NotImplementedError() def mutual_information(self, A, B): @@ -164,10 +162,6 @@ def logprob(self, event): logps = [spn.logprob(event_dnf) for spn in self.children] return logsumexp([p + w for (p, w) in zip(logps, self.weights)]) - def logpdf(self, x): - logps = [spn.logpdf(x) for spn in self.children] - return logsumexp([p + w for (p, w) in zip(logps, self.weights)]) - def condition(self, event): logps_condt = [spn.logprob(event) for spn in self.children] indexes = [i for i, lp in enumerate(logps_condt) if not isinf_neg(lp)] @@ -287,12 +281,6 @@ def sample_func(self, func, N, rng): samples = self.sample_subset(symbols, N, rng) return func_evaluate(self, func, samples) - def logpdf(self, x): - assert len(x) == len(self.children) - logps = [spn.logpdf(v) for (spn, v) in zip(self.children, x) - if x is not None] - return logsumexp(logps) - def logprob(self, event): event_dnf = dnf_normalize(event) if event_dnf is None: @@ -390,6 +378,8 @@ def sample_subset(self, symbols, N, rng): def sample_func(self, func, N, rng): samples = self.sample(N, rng) return func_evaluate(self, func, samples) + def logpdf(self, x): + raise NotImplementedError() # ============================================================================== # RealDistribution base class. @@ -443,9 +433,6 @@ def logcdf(self, x): p = logdiffexp(self.dist.logcdf(x), self.logFl) return p - self.logZ - def logpdf(self, x): - raise NotImplementedError() - def logprob_values(self, values): if values is EmptySet: return -inf