Skip to content

Commit

Permalink
Support copy, deepcopy, pickle, detach (#350)
Browse files Browse the repository at this point in the history
* Support copy, deepcopy, pickle, detach

* Simplify

* Simplify

* Fix GetitemOp

* Fix ReshapeOp

* address review comment
  • Loading branch information
fritzo authored Aug 16, 2020
1 parent 769b04a commit ad8c709
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 13 deletions.
9 changes: 9 additions & 0 deletions funsor/domains.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,15 @@ def __new__(cls, shape, dtype):
raise ValueError(repr(dtype))
return super(Domain, cls).__new__(cls, shape, dtype)

def __copy__(self):
return self

def __deepcopy__(self, memo):
return self

def __reduce__(self):
return Domain, (self.shape, self.dtype)

def __repr__(self):
shape = tuple(self.shape)
if isinstance(self.dtype, int):
Expand Down
25 changes: 21 additions & 4 deletions funsor/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,24 @@


class Op(Dispatcher):
def __init__(self, fn):
super(Op, self).__init__(fn.__name__)
def __init__(self, fn, *, name=None):
if name is None:
name = fn.__name__
super(Op, self).__init__(name)
# register as default operation
for nargs in (1, 2):
default_signature = (object,) * nargs
self.add(default_signature, fn)

def __copy__(self):
return self

def __deepcopy__(self, memo):
return self

def __reduce__(self):
return self.__name__

def __repr__(self):
return "ops." + self.__name__

Expand Down Expand Up @@ -124,6 +135,9 @@ def __init__(self, shape):
self.shape = shape
super().__init__(self._default)

def __reduce__(self):
return ReshapeOp, (self.shape,)

def _default(self, x):
return x.reshape(self.shape)

Expand Down Expand Up @@ -152,6 +166,9 @@ def __init__(self, offset):
super(GetitemOp, self).__init__(self._default)
self.__name__ = 'GetitemOp({})'.format(offset)

def __reduce__(self):
return GetitemOp, (self.offset,)

def _default(self, x, y):
return x[self._prefix + (y,)] if self.offset else x[y]

Expand Down Expand Up @@ -264,8 +281,8 @@ def _logaddexp(x, y):
return log(exp(x - shift) + exp(y - shift)) + shift


logaddexp = LogAddExpOp(_logaddexp)
sample = SampleOp(_logaddexp)
logaddexp = LogAddExpOp(_logaddexp, name="logaddexp")
sample = SampleOp(_logaddexp, name="sample")


@SubOp
Expand Down
14 changes: 10 additions & 4 deletions funsor/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,10 +182,10 @@ class FunsorMeta(type):
defaults and do type conversion, thereby simplifying logic of
interpretations.
2. Ensure each Funsor class has an attribute ``._ast_fields`` describing
its input args and each Funsor instance has an attribute ``._ast_args``
with values corresponding to its input args. This allows the instance
to be reflectively reconstructed under a different interpretation, and
is used by :func:`funsor.interpreter.reinterpret`.
its input args and each Funsor instance has an attribute
``._ast_values`` with values corresponding to its input args. This
allows the instance to be reflectively reconstructed under a different
interpretation, and is used by :func:`funsor.interpreter.reinterpret`.
3. Cons-hash construction, so that repeatedly calling the constructor
with identical args will product the same object. This enables cheap
syntactic equality testing using the ``is`` operator, which is
Expand Down Expand Up @@ -355,6 +355,12 @@ def dtype(self):
def shape(self):
return self.output.shape

def __copy__(self):
return self

def __reduce__(self):
return type(self).__origin__, self._ast_values

def __hash__(self):
return id(self)

Expand Down
3 changes: 1 addition & 2 deletions scripts/update_headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
lineno += 1
while lines[lineno].startswith(comment.format("Copyright")):
lineno += 1
changed = True

# Ensure next line is an SPDX short identifier.
if not lines[lineno].startswith(comment.format("SPDX-License-Identifier")):
Expand All @@ -49,7 +48,7 @@
lineno += 1

# Ensure next line is blank.
if not lines[lineno].isspace():
if lineno < len(lines) and not lines[lineno].isspace():
lines.insert(lineno, "\n")
changed = True

Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ per-file-ignores =
[isort]
line_length = 120
multi_line_output=3
not_skip = __init__.py
known_first_party = funsor, test
known_third_party = opt_einsum, pyro, pyroapi, torch, torchvision

Expand Down
24 changes: 24 additions & 0 deletions test/test_domains.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import io
import pickle

import pytest

from funsor.domains import bint, reals # noqa F401


@pytest.mark.parametrize('expr', [
"bint(2)",
"reals()",
"reals(4)",
"reals(3, 2)",
])
def test_pickle(expr):
x = eval(expr)
f = io.BytesIO()
pickle.dump(x, f)
f.seek(0)
y = pickle.load(f)
assert y == x # TODO promote to "y is x"
58 changes: 58 additions & 0 deletions test/test_tensor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import copy
import itertools
import io
import pickle
from collections import OrderedDict

import numpy as np
Expand Down Expand Up @@ -61,6 +64,26 @@ def test_cons_hash():
assert Tensor(x) is Tensor(x)


def test_copy():
data = randn(3, 2)
x = Tensor(data)
assert copy.copy(x) is x


def test_deepcopy():
data = randn(3, 2)
x = Tensor(data)

y = copy.deepcopy(x)
assert_close(x, y)
assert y is not x
assert y.data is not x.data

memo = {id(data): data}
z = copy.deepcopy(x, memo)
assert z is x


def test_indexing():
data = randn((4, 5))
inputs = OrderedDict([('i', bint(4)),
Expand Down Expand Up @@ -953,3 +976,38 @@ def test_log_correct_dtype():
assert (x == x).all().log().data.dtype is x.data.dtype
finally:
torch.set_default_dtype(old_dtype)


@pytest.mark.skipif(get_backend() != "numpy", reason="backend-specific")
def test_pickle():
x = Tensor(randn(2, 3))
f = io.BytesIO()
pickle.dump(x, f)
f.seek(0)
y = pickle.load(f)
assert_close(x, y)


@pytest.mark.skipif(get_backend() != "torch", reason="backend-specific")
def test_torch_save():
import torch
x = Tensor(randn(2, 3))
f = io.BytesIO()
torch.save(x, f)
f.seek(0)
y = torch.load(f)
assert_close(x, y)


@pytest.mark.skipif(get_backend() != "torch", reason="backend-specific")
def test_detach():
import torch
try:
from pyro.distributions.util import detach
except ImportError:
pytest.skip("detach() is not available")
x = Tensor(torch.randn(2, 3, requires_grad=True))
y = detach(x)
assert_close(x, y)
assert x.data.requires_grad
assert not y.data.requires_grad
38 changes: 36 additions & 2 deletions test/test_terms.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import copy
import itertools
import io
import pickle
import typing
from collections import OrderedDict
from functools import reduce
Expand Down Expand Up @@ -113,15 +116,46 @@ def test_quote(interp):
check_quote(Lambda(Variable('i', bint(2)), z))


@pytest.mark.parametrize('expr', [
EXPR_STRINGS = [
"Variable('x', bint(3))",
"Variable('x', reals())",
"Number(0.)",
"Number(1, dtype=10)",
"-Variable('x', reals())",
"Variable('x', reals(3))[Variable('i', bint(3))]",
"Variable('x', reals(2, 2)).reshape((4,))",
"Variable('x', reals()) + Variable('y', reals())",
"Variable('x', reals())(x=Number(0.))",
])
"Number(1) / Variable('x', reals())",
"Stack('i', (Number(0), Variable('z', reals())))",
"Cat('i', (Stack('i', (Number(0),)), Stack('i', (Number(1), Number(2)))))",
"Stack('t', (Number(1), Variable('x', reals()))).reduce(ops.logaddexp, 't')",
]


@pytest.mark.parametrize('expr', EXPR_STRINGS)
def test_copy_immutable(expr):
x = eval(expr)
assert copy.copy(x) is x


@pytest.mark.parametrize('expr', EXPR_STRINGS)
def test_deepcopy_immutable(expr):
x = eval(expr)
assert copy.deepcopy(x) is x


@pytest.mark.parametrize('expr', EXPR_STRINGS)
def test_pickle(expr):
x = eval(expr)
f = io.BytesIO()
pickle.dump(x, f)
f.seek(0)
y = pickle.load(f)
assert y is x


@pytest.mark.parametrize('expr', EXPR_STRINGS)
def test_reinterpret(expr):
x = eval(expr)
assert funsor.reinterpret(x) is x
Expand Down

0 comments on commit ad8c709

Please sign in to comment.