diff --git a/funsor/gaussian.py b/funsor/gaussian.py index fcd69640c..119373bac 100644 --- a/funsor/gaussian.py +++ b/funsor/gaussian.py @@ -16,18 +16,18 @@ def _log_det_tri(x): - return x.diagonal(dim1=-1, dim2=-2).log().sum(-1) + return ops.log(ops.diagonal(x, -1, -2)).sum(-1) def _vv(vec1, vec2): """ Computes the inner product ``< vec1 | vec 2 >``. """ - return vec1.unsqueeze(-2).matmul(vec2.unsqueeze(-1)).squeeze(-1).squeeze(-1) + return ops.matmul(ops.unsqueeze(vec1, -2), ops.unsqueeze(vec2, -1)).squeeze(-1).squeeze(-1) def _mv(mat, vec): - return torch.matmul(mat, vec.unsqueeze(-1)).squeeze(-1) + return ops.matmul(mat, ops.unsqueeze(vec, -1)).squeeze(-1) def _trace_mm(x, y): @@ -36,7 +36,7 @@ def _trace_mm(x, y): """ assert x.dim() >= 2 assert y.dim() >= 2 - return (x * y).sum([-1, -2]) + return (x * y).sum((-1, -2)) def cholesky(u): @@ -44,18 +44,14 @@ def cholesky(u): Like :func:`torch.cholesky` but uses sqrt for scalar matrices. Works around https://github.com/pytorch/pytorch/issues/24403 often. """ - if u.size(-1) == 1: - return u.sqrt() - return u.cholesky() + return ops.cholesky(u) def cholesky_inverse(u): """ Like :func:`torch.cholesky_inverse` but supports batching and gradients. 
""" - if u.dim() == 2: - return u.cholesky_inverse() - return torch.eye(u.size(-1)).expand(u.size()).cholesky_solve(u) + return ops.cholesky_inverse(u) def _compute_offsets(inputs): diff --git a/funsor/numpy.py b/funsor/numpy.py index d5c1f0e87..513fb1849 100644 --- a/funsor/numpy.py +++ b/funsor/numpy.py @@ -299,3 +299,159 @@ def materialize(x): subs.append((name, arange(name, domain.dtype))) subs = tuple(subs) return substitute(x, subs) + + +################################################################################ +# Register Ops +################################################################################ + + +@ops.abs.register(np.ndarray) +def _abs(x): + return abs(x) + + +@ops.sigmoid.register(np.ndarray) +def _sigmoid(x): + try: + from scipy.special import expit + return expit(x) + except ImportError: + return 1 / (1 + np.exp(-x)) + + +@ops.sqrt.register(np.ndarray) +def _sqrt(x): + return np.sqrt(x) + + +@ops.exp.register(np.ndarray) +def _exp(x): + return np.exp(x) + + +@ops.log.register(np.ndarray) +def _log(x): + return np.log(x) + + +@ops.log1p.register(np.ndarray) +def _log1p(x): + return np.log1p(x) + + +@ops.min.register(np.ndarray, np.ndarray) +def _min(x, y): + return np.minimum(x, y) + + +# TODO: replace (int, float) by object +@ops.min.register((int, float), np.ndarray) +def _min(x, y): + return np.clip(y, a_max=x) + + +@ops.min.register(np.ndarray, (int, float)) +def _min(x, y): + return np.clip(x, a_max=y) + + +@ops.max.register(np.ndarray, np.ndarray) +def _max(x, y): + return np.maximum(x, y) + + +@ops.max.register((int, float), np.ndarray) +def _max(x, y): + return np.clip(y, a_min=x) + + +@ops.max.register(np.ndarray, (int, float)) +def _max(x, y): + return np.clip(x, a_min=y) + + +@ops.reciprocal.register(np.ndarray) +def _reciprocal(x): + result = np.clip(np.reciprocal(x), a_max=np.finfo(x.dtype).max) + return result + + +@ops.safesub.register((int, float), np.ndarray) +def _safesub(x, y): + try: + finfo = np.finfo(y.dtype) + 
except ValueError: + finfo = np.iinfo(y.dtype) + return x + np.clip(-y, a_min=None, a_max=finfo.max) + + +@ops.safediv.register((int, float), np.ndarray) +def _safediv(x, y): + try: + finfo = np.finfo(y.dtype) + except ValueError: + finfo = np.iinfo(y.dtype) + return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max) + + +@ops.cholesky.register(np.ndarray) +def _cholesky(x): + """ + Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices. + """ + if x.shape[-1] == 1: + return np.sqrt(x) + return np.linalg.cholesky(x) + + +@ops.cholesky_inverse.register(np.ndarray) +def _cholesky_inverse(x): + """ + Like :func:`torch.cholesky_inverse` but supports batching and gradients. + """ + from scipy.linalg import cho_solve + + return cho_solve((x, True), np.eye(x.shape[-1])) + + +@ops.triangular_solve_op.register(np.ndarray, np.ndarray, bool, bool) +def _triangular_solve(x, y, upper, transpose): + from scipy.linalg import solve_triangular + + return solve_triangular(y, x, trans=int(transpose), lower=not upper) + + +@ops.diagonal.register(np.ndarray, int, int) +def _diagonal(x, dim1, dim2): + return np.diagonal(x, axis1=dim1, axis2=dim2) + + +@ops.cat_op.register(int, [np.ndarray]) +def _cat(dim, *x): + return np.concatenate(x, axis=dim) + + +@ops.new_zeros.register(np.ndarray, tuple) +def _new_zeros(x, shape): + return np.zeros(shape, dtype=x.dtype) + + +@ops.new_eye.register(np.ndarray, tuple) +def _new_eye(x, shape): + return np.broadcast_to(np.eye(shape[-1]), shape + (shape[-1],)) + + +@ops.unsqueeze.register(np.ndarray, int) +def _unsqueeze(x, dim): + return np.expand_dims(x, dim) + + +@ops.expand.register(np.ndarray, tuple) +def _expand(x, shape): + return np.broadcast_to(x, shape) + + +@ops.transpose.register(np.ndarray, int, int) +def _transpose(x, dim0, dim1): + return np.swapaxes(x, dim0, dim1) diff --git a/funsor/ops.py b/funsor/ops.py index ef82f8083..9b2d7a7d7 100644 --- a/funsor/ops.py +++ b/funsor/ops.py @@ -303,6 +303,66 @@ def reciprocal(x): } + +# Linear algebra ops + +@Op +def 
cholesky(x): + raise NotImplementedError + + +@Op +def cholesky_inverse(x): + raise NotImplementedError + + +@Op +def triangular_solve_op(x, y, upper, transpose): + raise NotImplementedError + + +def triangular_solve(x, y, upper=False, transpose=False): + return triangular_solve_op(x, y, upper, transpose) + + +@Op +def cat_op(dim, *x): + raise NotImplementedError + + +def cat(x, dim=0): + return cat_op(dim, *x) + + +@Op +def new_zeros(x, shape): + raise NotImplementedError + + +@Op +def new_eye(x, shape): + raise NotImplementedError + + +@Op +def unsqueeze(x, dim): + raise NotImplementedError + + +@Op +def expand(x, dim): + raise NotImplementedError + + +@Op +def diagonal(x, dim1, dim2): + raise NotImplementedError + + +@Op +def transpose(x, dim0, dim1): + raise NotImplementedError + + __all__ = [ 'AddOp', 'AssociativeOp', @@ -321,8 +381,13 @@ def reciprocal(x): 'abs', 'add', 'and_', + 'cat', + 'cholesky', + 'cholesky_inverse', + 'diagonal', 'eq', 'exp', + 'expand', 'ge', 'getitem', 'gt', @@ -337,6 +402,8 @@ def reciprocal(x): 'mul', 'ne', 'neg', + 'new_eye', + 'new_zeros', 'or_', 'pow', 'safediv', @@ -344,7 +411,10 @@ def reciprocal(x): 'sigmoid', 'sqrt', 'sub', + 'transpose', + 'triangular_solve', 'truediv', + 'unsqueeze', 'xor', ] diff --git a/funsor/torch.py b/funsor/torch.py index f135f8933..0b14ee437 100644 --- a/funsor/torch.py +++ b/funsor/torch.py @@ -1034,17 +1034,80 @@ def _reciprocal(x): @ops.safesub.register(object, torch.Tensor) def _safesub(x, y): try: - return x + -y.clamp(max=torch.finfo(y.dtype).max) + finfo = torch.finfo(y.dtype) except TypeError: - return x + -y.clamp(max=torch.iinfo(y.dtype).max) + finfo = torch.iinfo(y.dtype) + return x + (-y).clamp(max=finfo.max) @ops.safediv.register(object, torch.Tensor) def _safediv(x, y): try: - return x * y.reciprocal().clamp(max=torch.finfo(y.dtype).max) + finfo = torch.finfo(y.dtype) except TypeError: - return x * y.reciprocal().clamp(max=torch.iinfo(y.dtype).max) + finfo = torch.iinfo(y.dtype) + 
return x * y.reciprocal().clamp(max=finfo.max) + + +@ops.cholesky.register(torch.Tensor) +def _cholesky(x): + """ + Like :func:`torch.cholesky` but uses sqrt for scalar matrices. + Works around https://github.com/pytorch/pytorch/issues/24403 often. + """ + if x.size(-1) == 1: + return x.sqrt() + return x.cholesky() + + +@ops.cholesky_inverse.register(torch.Tensor) +def _cholesky_inverse(x): + """ + Like :func:`torch.cholesky_inverse` but supports batching and gradients. + """ + if x.dim() == 2: + return x.cholesky_inverse() + return torch.eye(x.size(-1)).cholesky_solve(x) + + +@ops.triangular_solve_op.register(torch.Tensor, torch.Tensor, bool, bool) +def _triangular_solve(x, y, upper, transpose): + return x.triangular_solve(y, upper=upper, transpose=transpose).solution + + +@ops.diagonal.register(torch.Tensor, int, int) +def _diagonal(x, dim1, dim2): + return x.diagonal(dim1=dim1, dim2=dim2) + + +@ops.cat_op.register(int, [torch.Tensor]) +def _cat(dim, *x): + return torch.cat(x, dim=dim) + + +@ops.new_zeros.register(torch.Tensor, tuple) +def _new_zeros(x, shape): + return x.new_zeros(shape) + + +@ops.new_eye.register(torch.Tensor, tuple) +def _new_eye(x, shape): + return torch.eye(shape[-1]).expand(shape + (-1,)) + + +@ops.unsqueeze.register(torch.Tensor, int) +def _unsqueeze(x, dim): + return x.unsqueeze(dim) + + +@ops.expand.register(torch.Tensor, tuple) +def _expand(x, shape): + return x.expand(shape) + + +@ops.transpose.register(torch.Tensor, int, int) +def _transpose(x, dim0, dim1): + return x.transpose(dim0, dim1) REDUCE_OP_TO_TORCH = { diff --git a/test/test_gaussian.py b/test/test_gaussian.py index 02871935e..0f1b1988d 100644 --- a/test/test_gaussian.py +++ b/test/test_gaussian.py @@ -336,7 +336,7 @@ def test_eager_subs_affine(subs, g_ints, subs_ints): assert issubclass(type(g_subs), GaussianMixture) actual = g_subs(**grounding_subs) expected = g(**ground_subs)(**grounding_subs) - assert_close(actual, expected, atol=1e-3, rtol=1e-4) + assert_close(actual, 
expected, atol=1e-3, rtol=2e-4) @pytest.mark.parametrize('int_inputs', [ diff --git a/test/test_numpy.py b/test/test_numpy.py index eb54abddc..061600d10 100644 --- a/test/test_numpy.py +++ b/test/test_numpy.py @@ -5,8 +5,9 @@ import funsor from funsor import Number, Variable, bint, reals +from funsor.domains import Domain from funsor.interpreter import _USE_TCO, interpretation -from funsor.numpy import Array +from funsor.numpy import Array, align_arrays from funsor.terms import lazy from funsor.testing import assert_equiv, check_funsor, random_array @@ -219,3 +220,74 @@ def test_align(): for j in range(3): for k in range(4): assert x(i=i, j=j, k=k) == y(i=i, j=j, k=k) + + +def unary_eval(symbol, x): + if symbol in ['~', '-']: + return eval('{} x'.format(symbol)) + if isinstance(x, np.ndarray): + return getattr(funsor.ops, symbol)(x) + return getattr(x, symbol)() + + +@pytest.mark.parametrize('dims', [(), ('a',), ('a', 'b')]) +@pytest.mark.parametrize('symbol', [ + '~', '-', 'abs', 'sqrt', 'exp', 'log', 'log1p', 'sigmoid', +]) +@pytest.mark.xfail(reason='Unary ops do not work with Array input yet.' 
+ ' Issue https://github.com/pyro-ppl/funsor/issues/207') +def test_unary(symbol, dims): + sizes = {'a': 3, 'b': 4} + shape = tuple(sizes[d] for d in dims) + inputs = OrderedDict((d, bint(sizes[d])) for d in dims) + dtype = 'real' + data = np.array(np.random.rand(*shape) + 0.5) + if symbol == '~': + data = data.astype(bool) + dtype = 2 + expected_data = unary_eval(symbol, data) + + x = Array(data, inputs, dtype) + actual = unary_eval(symbol, x) + # FIXME: this raises AttributeErorr: 'Unary' object has no attribute 'data' + check_funsor(actual, inputs, funsor.Domain((), dtype), expected_data) + + +BINARY_OPS = [ + '+', '-', '*', '/', '**', '==', '!=', '<', '<=', '>', '>=', + 'min', 'max', +] +BOOLEAN_OPS = ['&', '|', '^'] + + +def binary_eval(symbol, x, y): + if symbol == 'min': + return funsor.ops.min(x, y) + if symbol == 'max': + return funsor.ops.max(x, y) + return eval('x {} y'.format(symbol)) + + +@pytest.mark.parametrize('dims2', [(), ('a',), ('b', 'a'), ('b', 'c', 'a')]) +@pytest.mark.parametrize('dims1', [(), ('a',), ('a', 'b'), ('b', 'a', 'c')]) +@pytest.mark.parametrize('symbol', BINARY_OPS + BOOLEAN_OPS) +def test_binary_funsor_funsor(symbol, dims1, dims2): + sizes = {'a': 3, 'b': 4, 'c': 5} + shape1 = tuple(sizes[d] for d in dims1) + shape2 = tuple(sizes[d] for d in dims2) + inputs1 = OrderedDict((d, bint(sizes[d])) for d in dims1) + inputs2 = OrderedDict((d, bint(sizes[d])) for d in dims2) + data1 = np.array(np.random.rand(*shape1)) + 0.5 + data2 = np.array(np.random.rand(*shape2)) + 0.5 + dtype = 'real' + if symbol in BOOLEAN_OPS: + dtype = 2 + data1 = data1.astype(bool) + data2 = data2.astype(bool) + x1 = Array(data1, inputs1, dtype) + x2 = Array(data2, inputs2, dtype) + inputs, aligned = align_arrays(x1, x2) + expected_data = binary_eval(symbol, aligned[0], aligned[1]) + + actual = binary_eval(symbol, x1, x2) + check_funsor(actual, inputs, Domain((), dtype), expected_data)