Skip to content

Commit

Permalink
Add some ops required for Gaussian (#297)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored and fritzo committed Jan 9, 2020
1 parent 6a60350 commit 167bd6f
Show file tree
Hide file tree
Showing 6 changed files with 373 additions and 16 deletions.
16 changes: 6 additions & 10 deletions funsor/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,18 @@


def _log_det_tri(x):
    """
    Returns the log-determinant of a (batch of) triangular matrix ``x``,
    i.e. the sum of the logs of its diagonal entries.
    """
    # Dispatch through backend-agnostic funsor.ops instead of torch methods.
    return ops.log(ops.diagonal(x, -1, -2)).sum(-1)


def _vv(vec1, vec2):
    """
    Computes the inner product ``< vec1 | vec 2 >``.
    """
    # Lift to (..., 1, n) @ (..., n, 1) then drop both singleton dims.
    return ops.matmul(ops.unsqueeze(vec1, -2), ops.unsqueeze(vec2, -1)).squeeze(-1).squeeze(-1)


def _mv(mat, vec):
    """
    Computes the matrix-vector product ``mat @ vec`` over the last dims.
    """
    # vec -> (..., n, 1), multiply, then squeeze back to (..., n).
    return ops.matmul(mat, ops.unsqueeze(vec, -1)).squeeze(-1)


def _trace_mm(x, y):
    """
    Computes the sum of elementwise products ``(x * y).sum((-1, -2))``,
    which equals ``trace(x @ y.T)`` when ``x`` and ``y`` are matrices.
    """
    # NOTE(review): .dim() is torch-specific; numpy arrays use .ndim — confirm
    # this helper is only called with torch tensors.
    assert x.dim() >= 2
    assert y.dim() >= 2
    return (x * y).sum((-1, -2))


def cholesky(u):
    """
    Like :func:`torch.cholesky` but uses sqrt for scalar matrices.
    Works around https://github.com/pytorch/pytorch/issues/24403 often.
    """
    # The 1x1 special case now lives inside the backend ops.cholesky
    # registrations; this is a thin backend-agnostic wrapper.
    return ops.cholesky(u)


def cholesky_inverse(u):
    """
    Like :func:`torch.cholesky_inverse` but supports batching and gradients.
    """
    # Batching/gradient handling is delegated to the backend registration.
    return ops.cholesky_inverse(u)


def _compute_offsets(inputs):
Expand Down
156 changes: 156 additions & 0 deletions funsor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,3 +299,159 @@ def materialize(x):
subs.append((name, arange(name, domain.dtype)))
subs = tuple(subs)
return substitute(x, subs)


################################################################################
# Register Ops
################################################################################


@ops.abs.register(np.ndarray)
def _abs(x):
    # Elementwise absolute value; builtin abs() on an ndarray delegates to
    # ndarray.__abs__, i.e. np.absolute, so this is behaviorally identical.
    return np.absolute(x)


@ops.sigmoid.register(np.ndarray)
def _sigmoid(x):
    # Prefer SciPy's expit when available; otherwise use the direct formula.
    try:
        from scipy.special import expit
    except ImportError:
        return 1 / (1 + np.exp(-x))
    return expit(x)


@ops.sqrt.register(np.ndarray)
def _sqrt(x):
    # Elementwise square root (NaN with a RuntimeWarning for negative entries).
    return np.sqrt(x)


@ops.exp.register(np.ndarray)
def _exp(x):
    # Elementwise exponential.
    return np.exp(x)


@ops.log.register(np.ndarray)
def _log(x):
    # Elementwise natural log (-inf at 0, NaN for negatives).
    return np.log(x)


@ops.log1p.register(np.ndarray)
def _log1p(x):
    # Elementwise log(1 + x), numerically accurate near x == 0.
    return np.log1p(x)


@ops.min.register(np.ndarray, np.ndarray)
def _min(x, y):
    # Elementwise minimum with broadcasting (binary op, not the reduction).
    return np.minimum(x, y)


# TODO: replace (int, float) by object
@ops.min.register((int, float), np.ndarray)
def _min(x, y):
    # min(scalar, array) == clip the array from above at the scalar.
    # np.clip requires a_min to be supplied (None disables the lower bound);
    # omitting it raises TypeError.
    return np.clip(y, a_min=None, a_max=x)


@ops.min.register(np.ndarray, (int, float))
def _min(x, y):
    # min(array, scalar) == clip the array from above at the scalar.
    # a_min=None must be passed explicitly; np.clip has no default for it.
    return np.clip(x, a_min=None, a_max=y)


@ops.max.register(np.ndarray, np.ndarray)
def _max(x, y):
    # Elementwise maximum with broadcasting (binary op, not the reduction).
    return np.maximum(x, y)


@ops.max.register((int, float), np.ndarray)
def _max(x, y):
    # max(scalar, array) == clip the array from below at the scalar.
    # a_max=None must be passed explicitly; np.clip has no default for it.
    return np.clip(y, a_min=x, a_max=None)


@ops.max.register(np.ndarray, (int, float))
def _max(x, y):
    # max(array, scalar) == clip the array from below at the scalar.
    return np.clip(x, a_min=y, a_max=None)


@ops.reciprocal.register(np.ndarray)
def _reciprocal(x):
    # Elementwise 1/x, capped at the dtype's max so 1/0 becomes a large
    # finite value instead of inf (mirrors the torch backend's clamp).
    # a_min=None must be passed explicitly; np.clip has no default for it.
    result = np.clip(np.reciprocal(x), a_min=None, a_max=np.finfo(x.dtype).max)
    return result


@ops.safesub.register((int, float), np.ndarray)
def _safesub(x, y):
    """
    Computes ``x - y`` with ``-y`` saturated at the dtype's max value,
    mirroring the torch backend's ``(-y).clamp(max=finfo.max)``.
    """
    # np.finfo raises ValueError for integer dtypes; fall back to iinfo.
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    # Bug fix: clip against finfo.max (a number), not the finfo object,
    # and pass a_min=None explicitly as np.clip requires.
    return x + np.clip(-y, a_min=None, a_max=finfo.max)


@ops.safediv.register((int, float), np.ndarray)
def _safediv(x, y):
    """
    Computes ``x / y`` with ``1/y`` saturated at the dtype's max value,
    mirroring the torch backend's ``y.reciprocal().clamp(max=finfo.max)``.
    """
    try:
        finfo = np.finfo(y.dtype)
    except ValueError:
        finfo = np.iinfo(y.dtype)
    # Bug fix: clip against finfo.max (a number), not the finfo object,
    # and pass a_min=None explicitly as np.clip requires.
    return x * np.clip(np.reciprocal(y), a_min=None, a_max=finfo.max)


@ops.cholesky.register(np.ndarray)
def _cholesky(x):
    """
    Like :func:`numpy.linalg.cholesky` but uses sqrt for scalar matrices.
    """
    # A 1x1 positive matrix factors as its elementwise square root.
    return np.sqrt(x) if x.shape[-1] == 1 else np.linalg.cholesky(x)


@ops.cholesky_inverse.register(np.ndarray)
def _cholesky_inverse(x):
    """
    Computes the inverse of a positive-definite matrix given its lower
    Cholesky factor ``x``, like :func:`torch.cholesky_inverse`.

    NOTE(review): scipy's cho_solve does not support batching; this assumes
    ``x`` is a plain 2-D matrix — confirm against callers.
    """
    from scipy.linalg import cho_solve

    # ops.cholesky produces a *lower* triangular factor, so pass lower=True;
    # lower=False would make cho_solve read the (trivial) upper triangle of
    # x and return wrong results.
    return cho_solve((x, True), np.eye(x.shape[-1]))


@ops.triangular_solve_op.register(np.ndarray, np.ndarray, bool, bool)
def _triangular_solve(x, y, upper, transpose):
    """
    Solves ``y @ result = x`` where ``y`` is triangular: ``x`` is the
    right-hand side and ``y`` the triangular matrix, matching the torch
    backend's ``x.triangular_solve(y).solution``.
    """
    from scipy.linalg import solve_triangular

    # Bug fix: the matrix/rhs arguments were swapped relative to the torch
    # backend; scipy's signature is solve_triangular(A, b).
    return solve_triangular(y, x, trans=int(transpose), lower=not upper)


@ops.diagonal.register(np.ndarray, int, int)
def _diagonal(x, dim1, dim2):
    # Extracts the diagonal between axes dim1 and dim2 (torch-style
    # dim1/dim2 names mapped onto numpy's axis1/axis2).
    return np.diagonal(x, axis1=dim1, axis2=dim2)


@ops.cat_op.register(int, [np.ndarray])
def _cat(dim, *x):
    # Variadic concatenation of arrays along axis ``dim``.
    return np.concatenate(x, axis=dim)


@ops.new_zeros.register(np.ndarray, tuple)
def _new_zeros(x, shape):
    # Fresh zeros array of the given shape, using ``x`` only as a dtype
    # prototype (mirrors torch's Tensor.new_zeros).
    return np.zeros(shape, dtype=x.dtype)


@ops.new_eye.register(np.ndarray, tuple)
def _new_eye(x, shape):
    # Identity matrix of size shape[-1] broadcast over the leading batch
    # dims, giving result shape ``shape + (shape[-1],)``.
    # Bug fix: np.broadcast_to rejects -1 dims (that is a torch.expand
    # convention), so the trailing dimension must be spelled out.
    return np.broadcast_to(np.eye(shape[-1]), shape + (shape[-1],))


@ops.unsqueeze.register(np.ndarray, int)
def _unsqueeze(x, dim):
    # Inserts a size-1 axis at position ``dim`` (torch's unsqueeze).
    return np.expand_dims(x, dim)


@ops.expand.register(np.ndarray, tuple)
def _expand(x, shape):
    # Zero-copy broadcast to ``shape``; the returned view is read-only.
    return np.broadcast_to(x, shape)


@ops.transpose.register(np.ndarray, int, int)
def _transpose(x, dim0, dim1):
    # Swaps two axes (torch.transpose semantics, not a full np.transpose).
    return np.swapaxes(x, dim0, dim1)
70 changes: 70 additions & 0 deletions funsor/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,66 @@ def reciprocal(x):
}


# Linear algebra ops

@Op
def cholesky(x):
    # Abstract op: concrete implementations are registered per backend.
    raise NotImplementedError


@Op
def cholesky_inverse(x):
    # Abstract op: inverse of a PD matrix from its Cholesky factor.
    raise NotImplementedError


@Op
def triangular_solve_op(x, y, upper, transpose):
    # Abstract op: dispatch target for triangular_solve() below; all four
    # arguments are positional so backends can register on their types.
    raise NotImplementedError


def triangular_solve(x, y, upper=False, transpose=False):
    # Convenience wrapper supplying keyword defaults; the Op itself
    # dispatches on fully positional arguments.
    return triangular_solve_op(x, y, upper, transpose)


@Op
def cat_op(dim, *x):
    # Abstract op: variadic concatenation along ``dim``; see cat() below.
    raise NotImplementedError


def cat(x, dim=0):
    # Convenience wrapper: takes a sequence and unpacks it for dispatch.
    return cat_op(dim, *x)


@Op
def new_zeros(x, shape):
    # Abstract op: zeros of ``shape`` using ``x`` as a dtype/device prototype.
    raise NotImplementedError


@Op
def new_eye(x, shape):
    # Abstract op: batched identity matrix using ``x`` as a prototype.
    raise NotImplementedError


@Op
def unsqueeze(x, dim):
    # Abstract op: insert a size-1 axis at position ``dim``.
    raise NotImplementedError


@Op
def expand(x, dim):
    # Abstract op: broadcast ``x`` to a target shape.
    # NOTE(review): backends register this with a ``shape`` tuple as the
    # second argument — the name ``dim`` looks like a misnomer; confirm.
    raise NotImplementedError


@Op
def diagonal(x, dim1, dim2):
    # Abstract op: extract the diagonal between axes dim1 and dim2.
    raise NotImplementedError


@Op
def transpose(x, dim0, dim1):
    # Abstract op: swap two axes.
    raise NotImplementedError


__all__ = [
'AddOp',
'AssociativeOp',
Expand All @@ -321,8 +381,13 @@ def reciprocal(x):
'abs',
'add',
'and_',
'cat',
'cholesky',
'cholesky_inverse',
'diagonal',
'eq',
'exp',
'expand',
'ge',
'getitem',
'gt',
Expand All @@ -337,14 +402,19 @@ def reciprocal(x):
'mul',
'ne',
'neg',
'new_eye',
'new_zeros',
'or_',
'pow',
'safediv',
'safesub',
'sigmoid',
'sqrt',
'sub',
'transpose',
'triangular_solve',
'truediv',
'unsqueeze',
'xor',
]

Expand Down
71 changes: 67 additions & 4 deletions funsor/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1034,17 +1034,80 @@ def _reciprocal(x):
@ops.safesub.register(object, torch.Tensor)
def _safesub(x, y):
    """
    Computes ``x - y`` with ``-y`` clamped at the dtype's max value.
    """
    # Diff artifact resolved: old pre-commit return statements removed.
    # torch.finfo raises TypeError for integer dtypes; fall back to iinfo.
    try:
        finfo = torch.finfo(y.dtype)
    except TypeError:
        finfo = torch.iinfo(y.dtype)
    return x + (-y).clamp(max=finfo.max)


@ops.safediv.register(object, torch.Tensor)
def _safediv(x, y):
    """
    Computes ``x / y`` with ``1/y`` clamped at the dtype's max value.
    """
    # Diff artifact resolved: old pre-commit return statements removed.
    try:
        finfo = torch.finfo(y.dtype)
    except TypeError:
        finfo = torch.iinfo(y.dtype)
    return x * y.reciprocal().clamp(max=finfo.max)


@ops.cholesky.register(torch.Tensor)
def _cholesky(x):
    """
    Like :func:`torch.cholesky` but uses sqrt for scalar matrices.
    Works around https://github.com/pytorch/pytorch/issues/24403 often.
    """
    # A 1x1 positive matrix factors as its elementwise square root.
    if x.size(-1) == 1:
        return x.sqrt()
    return x.cholesky()


@ops.cholesky_inverse.register(torch.Tensor)
def _cholesky_inverse(x):
    """
    Like :func:`torch.cholesky_inverse` but supports batching and gradients.
    """
    if x.dim() == 2:
        return x.cholesky_inverse()
    # Batched case: build the identity with x's dtype/device (plain
    # torch.eye would be float32 on CPU) and expand across batch dims,
    # matching the helper this replaces in funsor/gaussian.py — older
    # torch.cholesky_solve does not broadcast batch shapes.
    eye = torch.eye(x.size(-1), dtype=x.dtype, device=x.device)
    return eye.expand(x.size()).cholesky_solve(x)


@ops.triangular_solve_op.register(torch.Tensor, torch.Tensor, bool, bool)
def _triangular_solve(x, y, upper, transpose):
    # Solves y @ result = x: ``x`` is the right-hand side; torch's method
    # lives on the rhs tensor and returns a (solution, cloned_A) tuple.
    return x.triangular_solve(y, upper=upper, transpose=transpose).solution


@ops.diagonal.register(torch.Tensor, int, int)
def _diagonal(x, dim1, dim2):
    # Extracts the diagonal between dims dim1 and dim2 as a view.
    return x.diagonal(dim1=dim1, dim2=dim2)


@ops.cat_op.register(int, [torch.Tensor])
def _cat(dim, *x):
    # Variadic concatenation of tensors along dim ``dim``.
    return torch.cat(x, dim=dim)


@ops.new_zeros.register(torch.Tensor, tuple)
def _new_zeros(x, shape):
    # Zeros tensor of ``shape`` on x's dtype and device.
    return x.new_zeros(shape)


@ops.new_eye.register(torch.Tensor, tuple)
def _new_eye(x, shape):
    # Identity of size shape[-1] expanded over batch dims (-1 keeps the
    # trailing dim), result shape ``shape + (shape[-1],)``.
    # Respect the prototype's dtype/device (cf. _new_zeros above); plain
    # torch.eye always allocates default-dtype on CPU.
    eye = torch.eye(shape[-1], dtype=x.dtype, device=x.device)
    return eye.expand(shape + (-1,))


@ops.unsqueeze.register(torch.Tensor, int)
def _unsqueeze(x, dim):
    # Inserts a size-1 dim at position ``dim``.
    return x.unsqueeze(dim)


@ops.expand.register(torch.Tensor, tuple)
def _expand(x, shape):
    # Zero-copy broadcast to ``shape`` (views, no new memory).
    return x.expand(shape)


@ops.transpose.register(torch.Tensor, int, int)
def _transpose(x, dim0, dim1):
    # Swaps two dims (returns a view).
    return x.transpose(dim0, dim1)


REDUCE_OP_TO_TORCH = {
Expand Down
Loading

0 comments on commit 167bd6f

Please sign in to comment.