Allow .reduce(op, vars_not_in_inputs) (#397)
fritzo authored Jan 8, 2021
1 parent a6fa083 commit d3d2ce7
Showing 16 changed files with 187 additions and 117 deletions.
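
In effect, Funsor.reduce may now be passed variables that are not among a funsor's inputs; they are dropped rather than tripping an assertion, and reduced variable sets are carried as Variable terms (name plus domain) instead of bare strings. A minimal usage sketch, assuming the torch backend; the names x, y, and f are hypothetical, not taken from the diff:

from collections import OrderedDict

import torch

import funsor
import funsor.ops as ops
from funsor import Bint, Tensor, Variable

funsor.set_backend("torch")

x = Variable("x", Bint[3])
y = Variable("y", Bint[2])  # deliberately not an input of f below

f = Tensor(torch.randn(3), OrderedDict(x=Bint[3]))  # inputs: {"x": Bint[3]}

# Previously, reducing over a variable outside f.inputs failed an assertion;
# now the extra variable y is ignored and only x is summed out.
g = f.reduce(ops.add, {x, y})
assert not g.inputs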
16 changes: 6 additions & 10 deletions funsor/adjoint.py
@@ -105,11 +105,8 @@ def adjoint_tensor(adj_redop, adj_binop, out_adj, data, inputs, dtype):
 def adjoint_binary(adj_redop, adj_binop, out_adj, op, lhs, rhs):
     assert (adj_redop, op) in ops.DISTRIBUTIVE_OPS
 
-    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
-    lhs_adj = op(out_adj, rhs).reduce(adj_redop, lhs_reduced_vars)
-
-    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
-    rhs_adj = op(out_adj, lhs).reduce(adj_redop, rhs_reduced_vars)
+    lhs_adj = op(out_adj, rhs).reduce(adj_redop, rhs.input_vars - lhs.input_vars)
+    rhs_adj = op(out_adj, lhs).reduce(adj_redop, lhs.input_vars - rhs.input_vars)
 
     return {lhs: lhs_adj, rhs: rhs_adj}
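
As a concrete check of the identity implemented by adjoint_binary, here is a hedged NumPy sketch (shapes and names hypothetical), taking (adj_redop, op) = (add, mul): the adjoint for each side multiplies the incoming adjoint by the other side and sums over the variables appearing only on that other side, i.e. rhs.input_vars - lhs.input_vars and vice versa:

import numpy as np

# lhs depends on i; rhs depends on i and j; out[i] = sum_j lhs[i] * rhs[i, j]
lhs = np.random.rand(3)
rhs = np.random.rand(3, 4)
out_adj = np.ones(3)  # seed adjoint, d(loss)/d(out) = 1

# op(out_adj, rhs) reduced over {j}, since j appears in rhs but not in lhs
lhs_adj = (out_adj[:, None] * rhs).sum(axis=1)
# op(out_adj, lhs) with nothing to reduce, since every var of lhs is in rhs
rhs_adj = out_adj * lhs

# sanity checks against the analytic gradients of out
assert np.allclose(lhs_adj, rhs.sum(axis=1))  # d out[i] / d lhs[i] = sum_j rhs[i, j]
assert np.allclose(rhs_adj, lhs)              # d out[i] / d rhs[i, j] = lhs[i], broadcast over j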

@@ -144,11 +141,10 @@ def adjoint_contract_generic(adj_redop, adj_binop, out_adj, sum_op, prod_op, red
 def adjoint_contract(adj_redop, adj_binop, out_adj, sum_op, prod_op, reduced_vars, lhs, rhs):
     assert sum_op is nullop or (sum_op, prod_op) in ops.DISTRIBUTIVE_OPS
 
-    lhs_reduced_vars = frozenset(rhs.inputs) - frozenset(lhs.inputs)
-    lhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop, prod_op, lhs_reduced_vars, out_adj, rhs)
-
-    rhs_reduced_vars = frozenset(lhs.inputs) - frozenset(rhs.inputs)
-    rhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop, prod_op, rhs_reduced_vars, out_adj, lhs)
+    lhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
+                          prod_op, rhs.input_vars - lhs.input_vars, out_adj, rhs)
+    rhs_adj = Contraction(sum_op if sum_op is not nullop else adj_redop,
+                          prod_op, lhs.input_vars - rhs.input_vars, out_adj, lhs)
 
     return {lhs: lhs_adj, rhs: rhs_adj}

57 changes: 30 additions & 27 deletions funsor/cnf.py
@@ -3,7 +3,7 @@
 
 import functools
 import itertools
-from collections import OrderedDict, defaultdict
+from collections import Counter, OrderedDict, defaultdict
 from functools import reduce
 from typing import Tuple, Union
 
@@ -51,7 +51,7 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
         assert isinstance(bin_op, AssociativeOp)
         assert all(isinstance(v, Funsor) for v in terms)
         assert isinstance(reduced_vars, frozenset)
-        assert all(isinstance(v, str) for v in reduced_vars)
+        assert all(isinstance(v, Variable) for v in reduced_vars)
         assert isinstance(terms, tuple) and len(terms) > 0
 
         assert not (isinstance(red_op, NullOp) and isinstance(bin_op, NullOp))
@@ -63,17 +63,17 @@ def __init__(self, red_op, bin_op, reduced_vars, terms):
             assert reduced_vars and len(terms) > 1
             assert (red_op, bin_op) in DISTRIBUTIVE_OPS
 
+        fresh = frozenset()
+        bound = {v.name: v.output for v in reduced_vars}
         inputs = OrderedDict()
         for v in terms:
-            inputs.update((k, d) for k, d in v.inputs.items() if k not in reduced_vars)
+            inputs.update((k, d) for k, d in v.inputs.items() if k not in bound)
 
         if bin_op is nullop:
             output = terms[0].output
         else:
             output = reduce(lambda lhs, rhs: find_domain(bin_op, lhs, rhs),
                             [v.output for v in reversed(terms)])
-        fresh = frozenset()
-        bound = reduced_vars
         super(Contraction, self).__init__(inputs, output, fresh, bound)
         self.red_op = red_op
         self.bin_op = bin_op
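
With reduced variables carried as Variable terms, the bound-variable bookkeeping becomes a simple name-to-domain mapping. A toy illustration (names and domains hypothetical):

from funsor import Bint, Variable

reduced_vars = frozenset({Variable("i", Bint[2]), Variable("j", Bint[3])})
bound = {v.name: v.output for v in reduced_vars}
assert bound == {"i": Bint[2], "j": Bint[3]}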
@@ -153,11 +153,9 @@ def align(self, names):
         return result
 
     def _alpha_convert(self, alpha_subs):
-        reduced_vars = frozenset(alpha_subs.get(k, k) for k in self.reduced_vars)
-        bound_types = {}
-        for term in self.terms:
-            bound_types.update({k: term.inputs[k] for k in self.bound.intersection(term.inputs)})
-        alpha_subs = {k: to_funsor(v, bound_types[k]) for k, v in alpha_subs.items()}
+        reduced_vars = frozenset(to_funsor(alpha_subs.get(var.name, var), var.output)
+                                 for var in self.reduced_vars)
+        alpha_subs = {k: to_funsor(v, self.bound[k]) for k, v in alpha_subs.items()}
         red_op, bin_op, _, terms = super()._alpha_convert(alpha_subs)
         return red_op, bin_op, reduced_vars, terms

@@ -191,30 +189,35 @@ def eager_contraction_generic_to_tuple(red_op, bin_op, reduced_vars, *terms):
 
 @eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, tuple)
 def eager_contraction_generic_recursive(red_op, bin_op, reduced_vars, terms):
-    # push down leaf reductions
-    terms, reduced_vars, leaf_reduced = list(terms), frozenset(reduced_vars), False
-    for i, v in enumerate(terms):
-        unique_vars = reduced_vars.intersection(v.inputs) - \
-            frozenset().union(*(reduced_vars.intersection(vv.inputs) for vv in terms if vv is not v))
-        if unique_vars:
-            result = v.reduce(red_op, unique_vars)
-            if result is not normalize(Contraction, red_op, nullop, unique_vars, (v,)):
-                terms[i] = result
-                reduced_vars -= unique_vars
-                leaf_reduced = True
+    # Count the number of terms in which each variable is reduced.
+    counts = Counter()
+    for term in terms:
+        counts.update(reduced_vars & term.input_vars)
+
+    # push down leaf reductions
+    terms = list(terms)
+    leaf_reduced = False
+    reduced_once = frozenset(v for v, count in counts.items() if count == 1)
+    if reduced_once:
+        for i, term in enumerate(terms):
+            unique_vars = reduced_once & term.input_vars
+            if unique_vars:
+                result = term.reduce(red_op, unique_vars)
+                if result is not normalize(Contraction, red_op, nullop, unique_vars, (term,)):
+                    terms[i] = result
+                    reduced_vars -= unique_vars
+                    leaf_reduced = True
     if leaf_reduced:
         return Contraction(red_op, bin_op, reduced_vars, *terms)
 
     # exploit associativity to recursively evaluate this contraction
     # a bit expensive, but handles interpreter-imposed directionality constraints
     terms = tuple(terms)
+    reduced_twice = frozenset(v for v, count in counts.items() if count == 2)
     for i, lhs in enumerate(terms[0:-1]):
         for j_, rhs in enumerate(terms[i+1:]):
             j = i + j_ + 1
-            unique_vars = reduced_vars.intersection(lhs.inputs, rhs.inputs) - \
-                frozenset().union(*(reduced_vars.intersection(vv.inputs)
-                                    for vv in terms[:i] + terms[i+1:j] + terms[j+1:]))
+            unique_vars = reduced_twice.intersection(lhs.input_vars, rhs.input_vars)
             result = Contraction(red_op, bin_op, unique_vars, lhs, rhs)
             if result is not normalize(Contraction, red_op, bin_op, unique_vars, (lhs, rhs)):  # did we make progress?
                 # pick the first evaluable pair
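
The Counter-based strategy above replaces the old quadratic scan over all other terms: a reduced variable counted in exactly one term can be summed out locally inside that term, and one counted in exactly two terms marks a pair eligible for a binary contraction. A small sketch of the bookkeeping with stand-in string variables (hypothetical, not funsor terms):

from collections import Counter

reduced_vars = {"i", "j"}                    # stand-in for reduced_vars
term_vars = [{"i", "k"}, {"j", "k"}, {"j"}]  # stand-ins for each term.input_vars

counts = Counter()
for vars_ in term_vars:
    counts.update(reduced_vars & vars_)

reduced_once = frozenset(v for v, c in counts.items() if c == 1)   # {"i"}
reduced_twice = frozenset(v for v, c in counts.items() if c == 2)  # {"j"}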
@@ -234,7 +237,7 @@ def eager_contraction_to_reduce(red_op, bin_op, reduced_vars, term):
 @eager.register(Contraction, AssociativeOp, AssociativeOp, frozenset, Funsor, Funsor)
 def eager_contraction_to_binary(red_op, bin_op, reduced_vars, lhs, rhs):
 
-    if reduced_vars - (reduced_vars.intersection(lhs.inputs, rhs.inputs)):
+    if not reduced_vars.issubset(lhs.input_vars & rhs.input_vars):
         args = red_op, bin_op, reduced_vars, (lhs, rhs)
         result = eager.dispatch(Contraction, *args)(*args)
         if result is not None:
@@ -286,8 +289,8 @@ def _eager_contract_tensors(reduced_vars, terms, backend):
         data = data.reshape(batch_shape + event_shape)
         operands.append(data)
 
-    for k in reduced_vars:
-        del inputs[k]
+    for var in reduced_vars:
+        inputs.pop(var.name, None)
     batch_shape = tuple(v.size for v in inputs.values())
     event_shape = broadcast_shape(*(term.shape for term in terms))
     einsum_output = ("".join(symbols[k] for k in inputs) +
15 changes: 10 additions & 5 deletions funsor/delta.py
@@ -98,7 +98,7 @@ def __init__(self, terms):
 
         output = Real
         fresh = frozenset(name for name, term in terms)
-        bound = frozenset()
+        bound = {}
         super(Delta, self).__init__(inputs, output, fresh, bound)
         self.terms = terms

@@ -137,12 +137,17 @@ def eager_subs(self, subs):
         return Delta(tuple(new_terms.items())) + log_density if new_terms else log_density
 
     def eager_reduce(self, op, reduced_vars):
+        assert reduced_vars.issubset(self.inputs)
         if op is ops.logaddexp:
-            if reduced_vars - self.fresh and self.fresh - reduced_vars:
-                result = self.eager_reduce(op, reduced_vars & self.fresh) if reduced_vars & self.fresh else self
-                if result is not self:
-                    result = result.eager_reduce(op, reduced_vars - self.fresh) if reduced_vars - self.fresh else self
-                return result if result is not self else None
+            result = self
+            if not reduced_vars.isdisjoint(self.fresh):
+                result = result.eager_reduce(op, reduced_vars & self.fresh)
+            if result is not self:
+                if not reduced_vars.issubset(self.fresh):
+                    result = result.eager_reduce(op, reduced_vars - self.fresh)
+                if result is not self:
+                    return result
+            return None
 
         result_terms = [(name, (point, log_density)) for name, (point, log_density) in self.terms
17 changes: 9 additions & 8 deletions funsor/distribution.py
@@ -120,6 +120,7 @@ def __repr__(self):
                                ', '.join('{}={}'.format(*kv) for kv in self.params.items()))
 
     def eager_reduce(self, op, reduced_vars):
+        assert reduced_vars.issubset(self.inputs)
         if op is ops.logaddexp and isinstance(self.value, Variable) and self.value.name in reduced_vars:
             return Number(0.)  # distributions are normalized
         return super(Distribution, self).eager_reduce(op, reduced_vars)
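
The Number(0.) shortcut encodes normalization: integrating a distribution's density over its value variable yields 1, i.e. a log-density of 0. A hedged sketch assuming the torch backend (Normal and the name x are illustrative, not from the diff):

import funsor
from funsor import Real, Variable

funsor.set_backend("torch")
from funsor.torch.distributions import Normal

x = Variable("x", Real)
d = Normal(0.0, 1.0, value=x)  # lazy log-density funsor with input x

# funsor terms are cons-hashed, so identical terms are identical objects
assert d.reduce(funsor.ops.logaddexp, x) is funsor.Number(0.0)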
@@ -686,7 +687,7 @@ def eager_beta_bernoulli(red_op, bin_op, reduced_vars, x, y):
 
 
 def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y):
-    dirichlet_reduction = frozenset(x.inputs).intersection(reduced_vars)
+    dirichlet_reduction = x.input_vars & reduced_vars
     if dirichlet_reduction:
         backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
         identity = Tensor(ops.new_eye(funsor.tensor.get_default_prototype(), x.concentration.shape))
@@ -698,7 +699,7 @@ def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y):
 
 
 def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
-    dirichlet_reduction = frozenset(x.inputs).intersection(reduced_vars)
+    dirichlet_reduction = x.input_vars & reduced_vars
     if dirichlet_reduction:
         backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
         return backend_dist.DirichletMultinomial(concentration=x.concentration,
@@ -709,18 +710,18 @@ def eager_dirichlet_multinomial(red_op, bin_op, reduced_vars, x, y):
 
 
 def eager_plate_multinomial(op, x, reduced_vars):
-    if not reduced_vars.isdisjoint(x.probs.inputs):
+    if not reduced_vars.isdisjoint(x.probs.input_vars):
         return None
-    if not reduced_vars.issubset(x.value.inputs):
+    if not reduced_vars.issubset(x.value.input_vars):
         return None
 
     backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
     total_count = x.total_count
     for v in reduced_vars:
-        if v in total_count.inputs:
+        if v.name in total_count.inputs:
             total_count = total_count.reduce(ops.add, v)
         else:
-            total_count = total_count * x.inputs[v].size
+            total_count = total_count * v.output.size
     return backend_dist.Multinomial(total_count=total_count,
                                     probs=x.probs,
                                     value=x.value.reduce(ops.add, reduced_vars))
@@ -731,7 +732,7 @@ def _log_beta(x, y):
 
 
 def eager_gamma_gamma(red_op, bin_op, reduced_vars, x, y):
-    gamma_reduction = frozenset(x.inputs).intersection(reduced_vars)
+    gamma_reduction = x.input_vars & reduced_vars
     if gamma_reduction:
         unnormalized = (y.concentration - 1) * ops.log(y.value) \
             - (y.concentration + x.concentration) * ops.log(y.value + x.rate)
@@ -742,7 +743,7 @@ def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y):
 
 
 def eager_gamma_poisson(red_op, bin_op, reduced_vars, x, y):
-    gamma_reduction = frozenset(x.inputs).intersection(reduced_vars)
+    gamma_reduction = x.input_vars & reduced_vars
     if gamma_reduction:
         backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
         return backend_dist.GammaPoisson(concentration=x.concentration,
6 changes: 4 additions & 2 deletions funsor/einsum/__init__.py
@@ -8,7 +8,7 @@
 from funsor.interpreter import interpretation
 from funsor.optimizer import apply_optimizer
 from funsor.sum_product import sum_product
-from funsor.terms import Funsor, lazy
+from funsor.terms import Funsor, Variable, lazy
 
 # TODO: add numpy einsum here
 BACKEND_OPS = {
@@ -49,7 +49,9 @@ def naive_contract_einsum(eqn, *terms, **kwargs):
     assert len(output.split(',')) == 1
     input_dims = frozenset(d for inp in inputs for d in inp)
     output_dims = frozenset(d for d in output)
-    reduced_vars = input_dims - output_dims
+    all_inputs = {k: v for term in terms for k, v in term.inputs.items()}
+    reduced_vars = frozenset(Variable(k, all_inputs[k])
+                             for k in input_dims - output_dims)
     return Contraction(sum_op, prod_op, reduced_vars, *terms)
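
Because Contraction's reduced_vars now holds Variable terms rather than dimension names, naive_contract_einsum must recover each summed dimension's domain from the terms' inputs before building the contraction. A small sketch of that mapping (names and domains hypothetical):

from funsor import Bint, Variable

all_inputs = {"i": Bint[2], "j": Bint[3]}  # gathered from the terms, as above
input_dims, output_dims = frozenset("ij"), frozenset("i")

reduced_vars = frozenset(Variable(k, all_inputs[k])
                         for k in input_dims - output_dims)
# reduced_vars contains the single term Variable("j", Bint[3])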


3 changes: 2 additions & 1 deletion funsor/gaussian.py
@@ -275,7 +275,7 @@ def __init__(self, info_vec, precision, inputs):
 
         output = Real
         fresh = frozenset(inputs.keys())
-        bound = frozenset()
+        bound = {}
         super(Gaussian, self).__init__(inputs, output, fresh, bound)
         self.info_vec = info_vec
         self.precision = precision
@@ -519,6 +519,7 @@ def _eager_subs_affine(self, subs, remaining_subs):
         return Subs(result, remaining_subs) if remaining_subs else result
 
     def eager_reduce(self, op, reduced_vars):
+        assert reduced_vars.issubset(self.inputs)
         if op is ops.logaddexp:
             # Marginalize out real variables, but keep mixtures lazy.
             assert all(v in self.inputs for v in reduced_vars)