Skip to content

Commit

Permalink
Simplify numpy op wrappers (#300)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored Jan 9, 2020
1 parent 167bd6f commit 6e60b31
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 41 deletions.
53 changes: 12 additions & 41 deletions funsor/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,44 +305,20 @@ def materialize(x):
# Register Ops
################################################################################


@ops.abs.register(np.ndarray)
def _abs(x):
return abs(x)


@ops.sigmoid.register(np.ndarray)
def _sigmoid(x):
try:
from scipy.special import expit
return expit(x)
except ImportError:
try:
from scipy.special import expit as _sigmoid
except ImportError:
def _sigmoid(x):
return 1 / (1 + np.exp(-x))


@ops.sqrt.register(np.ndarray)
def _sqrt(x):
return np.sqrt(x)


@ops.exp.register(np.ndarray)
def _exp(x):
return np.exp(x)


@ops.log.register(np.ndarray)
def _log(x):
return np.log(x)


@ops.log1p.register(np.ndarray)
def _log1p(x):
return np.log1p(x)


@ops.min.register(np.ndarray, np.ndarray)
def _min(x, y):
return np.minimum(x, y)
ops.abs.register(np.ndarray)(abs)
ops.sigmoid.register(np.ndarray)(_sigmoid)
ops.sqrt.register(np.ndarray)(np.sqrt)
ops.exp.register(np.ndarray)(np.exp)
ops.log.register(np.ndarray)(np.log)
ops.log1p.register(np.ndarray)(np.log1p)
ops.min.register(np.ndarray, np.ndarray)(np.minimum)
ops.max.register(np.ndarray, np.ndarray)(np.maximum)


# TODO: replace (int, float) by object
Expand All @@ -356,11 +332,6 @@ def _min(x, y):
return np.clip(x, a_max=y)


@ops.max.register(np.ndarray, np.ndarray)
def _max(x, y):
return np.maximum(x, y)


@ops.max.register((int, float), np.ndarray)
def _max(x, y):
return np.clip(y, a_min=x)
Expand Down
2 changes: 2 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ known_third_party = opt_einsum, pyro, pyroapi, torch, torchvision

[tool:pytest]
filterwarnings = error
ignore:numpy.ufunc size changed:RuntimeWarning
ignore:numpy.dtype size changed:RuntimeWarning
ignore::DeprecationWarning
once::DeprecationWarning

Expand Down

0 comments on commit 6e60b31

Please sign in to comment.