Skip to content

Commit

Permalink
Parametrize test suite by backend (#333)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Jun 1, 2020
1 parent 8b456dd commit deed901
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 43 deletions.
31 changes: 21 additions & 10 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,11 @@ cache:

install:
- pip install -U pip
- pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html

# Keep track of Pyro dev branch
- pip install https://github.com/pyro-ppl/pyro/archive/dev.zip

# Keep track of pyro-api master branch
- pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip

# Keep track of NumPyro master branch
- pip install https://github.com/pyro-ppl/numpyro/archive/master.zip

- pip install .[torch,jax,test]
- pip install .[test]
- pip freeze

branches:
Expand All @@ -32,5 +25,23 @@ branches:

jobs:
include:
- python: 3.6
script: make test
- stage: default
name: numpy
python: 3.6
script:
- make test
- name: torch
python: 3.6
script:
- pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
# Keep track of Pyro dev branch
- pip install https://github.com/pyro-ppl/pyro/archive/dev.zip
- pip install -e .[torch]
- FUNSOR_BACKEND=torch make test
- name: jax
python: 3.6
script:
- pip install -e .[jax]
# Keep track of NumPyro master branch
- pip install https://github.com/pyro-ppl/numpyro/archive/master.zip
- CI=1 FUNSOR_BACKEND=jax make test
56 changes: 30 additions & 26 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,37 @@ format: FORCE
isort -y

test: lint FORCE
pytest -v -n auto test/*py
FUNSOR_BACKEND=torch pytest -v -n auto test/
FUNSOR_BACKEND=jax pytest -v test/test_tensor.py
FUNSOR_BACKEND=jax pytest -v test/test_gaussian.py
FUNSOR_BACKEND=jax pytest -v test/test_einsum.py
FUNSOR_BACKEND=jax pytest -v test/test_distribution.py
FUNSOR_BACKEND=torch FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 pytest -v test/test_terms.py
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py
FUNSOR_BACKEND=torch python examples/discrete_hmm.py -n 2
FUNSOR_BACKEND=torch python examples/discrete_hmm.py -n 2 -t 50 --lazy
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 50 --lazy
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 500 --lazy
FUNSOR_BACKEND=torch python examples/kalman_filter.py -n 2
FUNSOR_BACKEND=torch python examples/kalman_filter.py -n 2 -t 50 --lazy
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 50 --lazy
FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 500 --lazy
FUNSOR_BACKEND=torch python examples/minipyro.py
FUNSOR_BACKEND=torch python examples/minipyro.py --jit
FUNSOR_BACKEND=torch python examples/slds.py -n 2 -t 50
FUNSOR_BACKEND=torch python examples/pcfg.py --size 3
FUNSOR_BACKEND=torch python examples/vae.py --smoke-test
FUNSOR_BACKEND=torch python examples/eeg_slds.py --num-steps 2 --fon --test
FUNSOR_BACKEND=torch python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --smoke
FUNSOR_BACKEND=torch python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --parallel --smoke
FUNSOR_BACKEND=torch python examples/sensor.py --seed=0 --num-frames=2 -n 1
ifeq (${FUNSOR_BACKEND}, torch)
pytest -v -n auto test/
FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
FUNSOR_USE_TCO=1 pytest -v test/test_terms.py
FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py
python examples/discrete_hmm.py -n 2
python examples/discrete_hmm.py -n 2 -t 50 --lazy
FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 50 --lazy
FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 500 --lazy
python examples/kalman_filter.py -n 2
python examples/kalman_filter.py -n 2 -t 50 --lazy
FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 50 --lazy
FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 500 --lazy
python examples/minipyro.py
python examples/minipyro.py --jit
python examples/slds.py -n 2 -t 50
python examples/pcfg.py --size 3
python examples/vae.py --smoke-test
python examples/eeg_slds.py --num-steps 2 --fon --test
python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --smoke
python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --parallel --smoke
python examples/sensor.py --seed=0 --num-frames=2 -n 1
@echo PASS
else ifeq (${FUNSOR_BACKEND}, jax)
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
@echo PASS
else
# default backend
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
@echo PASS
endif

clean: FORCE
git clean -dfx -e funsor-egg.info
Expand Down
1 change: 1 addition & 0 deletions funsor/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def monte_carlo_interpretation(**sample_inputs):

@monte_carlo.register(Integrate, Funsor, Funsor, frozenset)
def monte_carlo_integrate(log_measure, integrand, reduced_vars):
# FIXME: how to pass rng_key to here?
sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs)
if sample is log_measure:
return None # cannot progress
Expand Down
2 changes: 2 additions & 0 deletions funsor/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6):
elif isinstance(actual, Contraction) and isinstance(actual.terms[0], Tensor) \
and is_array(actual.terms[0].data):
assert isinstance(expected, Contraction) and is_array(expected.terms[0].data)
elif isinstance(actual, Gaussian) and is_array(actual.info_vec):
assert isinstance(expected, Gaussian) and is_array(expected.info_vec)
else:
assert type(actual) == type(expected), msg

Expand Down
5 changes: 4 additions & 1 deletion test/test_alpha_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@

from collections import OrderedDict

import numpy as np
import pytest

import funsor.ops as ops
from funsor.domains import bint, reals
from funsor.interpreter import gensym, interpretation, reinterpret
from funsor.terms import Cat, Independent, Lambda, Number, Slice, Stack, Variable, reflect
from funsor.testing import assert_close, check_funsor, random_tensor
from funsor.util import get_backend


def test_sample_subs_smoke():
x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
with interpretation(reflect):
z = x(i=1)
actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)}))
rng_key = None if get_backend() == "torch" else np.array([0, 1], dtype=np.uint32)
actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)}), rng_key=rng_key)
check_funsor(actual, {"j": bint(2), "i": bint(4)}, reals())


Expand Down
2 changes: 1 addition & 1 deletion test/test_cnf.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,4 +101,4 @@ def test_eager_contract_tensor_tensor(red_op, bin_op, x_inputs, x_shape, y_input
print(f"reduced_vars = {reduced_vars}")
expected = xy.reduce(red_op, reduced_vars)
actual = Contraction(red_op, bin_op, reduced_vars, (x, y))
assert_close(actual, expected, atol=1e-4, rtol=5e-4 if backend == "jax" else 1e-4)
assert_close(actual, expected, atol=1e-4, rtol=1e-3 if backend == "jax" else 1e-4)
6 changes: 5 additions & 1 deletion test/test_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,14 @@
from funsor.montecarlo import monte_carlo
from funsor.terms import Variable, eager, lazy, moment_matching, normalize, reflect
from funsor.testing import assert_close, random_tensor
from funsor.util import get_backend


@pytest.mark.parametrize('interp', [
reflect, lazy, normalize, eager, moment_matching, monte_carlo])
reflect, lazy, normalize, eager, moment_matching,
pytest.param(monte_carlo, marks=pytest.mark.xfail(
get_backend() == "jax", reason="Lacking pattern to pass rng_key"))
])
def test_integrate(interp):
log_measure = random_tensor(OrderedDict([('i', bint(2)), ('j', bint(3))]))
integrand = random_tensor(OrderedDict([('j', bint(3)), ('k', bint(4))]))
Expand Down
2 changes: 1 addition & 1 deletion test/test_joint.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def test_reduce_logaddexp(int_inputs, real_inputs):
actual = state.reduce(ops.logaddexp, frozenset(truth))

expected = t + g(**truth)
assert_close(actual, expected, atol=1e-5, rtol=1e-5)
assert_close(actual, expected, atol=1e-5, rtol=1e-4 if get_backend() == "jax" else 1e-5)


def test_reduce_logaddexp_deltas_lazy():
Expand Down
10 changes: 7 additions & 3 deletions test/test_memoize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest

import funsor.ops as ops
Expand Down Expand Up @@ -102,13 +103,16 @@ def test_memoize_sample(check_sample):
else:
from funsor.torch.distributions import Normal

rng_keys = (None, None, None) if get_backend() == "torch" \
else np.array([[0, 1], [0, 2], [0, 3]], dtype=np.uint32)

with memoize():
m, s = numeric_array(0.), numeric_array(1.)
j1 = Normal(m, s, 'x')
j2 = Normal(m, s, 'x')
x1 = j1.sample(frozenset({'x'}))
x12 = j1.sample(frozenset({'x'}))
x2 = j2.sample(frozenset({'x'}))
x1 = j1.sample(frozenset({'x'}), rng_key=rng_keys[0])
x12 = j1.sample(frozenset({'x'}), rng_key=rng_keys[1])
x2 = j2.sample(frozenset({'x'}), rng_key=rng_keys[2])

# this assertion now passes
assert j1 is j2
Expand Down
4 changes: 4 additions & 0 deletions test/test_sum_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import re
import os
from collections import OrderedDict
from functools import partial, reduce

Expand All @@ -25,6 +26,9 @@
from funsor.tensor import Tensor, get_default_prototype
from funsor.terms import Variable, eager_or_die, moment_matching, reflect
from funsor.testing import assert_close, random_gaussian, random_tensor
from funsor.util import get_backend

pytestmark = pytest.mark.skipif((get_backend() == 'jax') and ('CI' in os.environ), reason='slow tests')


@pytest.mark.parametrize('inputs,dims,expected_num_components', [
Expand Down

0 comments on commit deed901

Please sign in to comment.