From deed90145139cd63a394cbc56b8736d5cf200766 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 1 Jun 2020 02:22:49 -0400 Subject: [PATCH] Parametrize test suite by backend (#333) --- .travis.yml | 31 ++++++++++++------- Makefile | 56 +++++++++++++++++++---------------- funsor/montecarlo.py | 1 + funsor/testing.py | 2 ++ test/test_alpha_conversion.py | 5 +++- test/test_cnf.py | 2 +- test/test_integrate.py | 6 +++- test/test_joint.py | 2 +- test/test_memoize.py | 10 +++++-- test/test_sum_product.py | 4 +++ 10 files changed, 76 insertions(+), 43 deletions(-) diff --git a/.travis.yml b/.travis.yml index 6253a2660..646d91d22 100644 --- a/.travis.yml +++ b/.travis.yml @@ -12,18 +12,11 @@ cache: install: - pip install -U pip - - pip install torch==1.4.0+cpu torchvision==0.5.0+cpu -f https://download.pytorch.org/whl/torch_stable.html - - # Keep track of Pyro dev branch - - pip install https://github.com/pyro-ppl/pyro/archive/dev.zip # Keep track of pyro-api master branch - pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip - # Keep track of NumPyro master branch - - pip install https://github.com/pyro-ppl/numpyro/archive/master.zip - - - pip install .[torch,jax,test] + - pip install .[test] - pip freeze branches: @@ -32,5 +25,23 @@ branches: jobs: include: - - python: 3.6 - script: make test + - stage: default + name: numpy + python: 3.6 + script: + - make test + - name: torch + python: 3.6 + script: + - pip install torch==1.5.0+cpu torchvision==0.6.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + # Keep track of Pyro dev branch + - pip install https://github.com/pyro-ppl/pyro/archive/dev.zip + - pip install -e .[torch] + - FUNSOR_BACKEND=torch make test + - name: jax + python: 3.6 + script: + - pip install -e .[jax] + # Keep track of NumPyro master branch + - pip install https://github.com/pyro-ppl/numpyro/archive/master.zip + - CI=1 FUNSOR_BACKEND=jax make test diff --git a/Makefile b/Makefile index 91f363e0b..46bd0306c 100644 --- a/Makefile +++ b/Makefile @@ -19,33 +19,37 @@ format: FORCE isort -y test: lint FORCE - pytest -v -n auto test/*py - FUNSOR_BACKEND=torch pytest -v -n auto test/ - FUNSOR_BACKEND=jax pytest -v test/test_tensor.py - FUNSOR_BACKEND=jax pytest -v test/test_gaussian.py - FUNSOR_BACKEND=jax pytest -v test/test_einsum.py - FUNSOR_BACKEND=jax pytest -v test/test_distribution.py - FUNSOR_BACKEND=torch FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 pytest -v test/test_terms.py - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py - FUNSOR_BACKEND=torch python examples/discrete_hmm.py -n 2 - FUNSOR_BACKEND=torch python examples/discrete_hmm.py -n 2 -t 50 --lazy - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 50 --lazy - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 500 --lazy - FUNSOR_BACKEND=torch python examples/kalman_filter.py -n 2 - FUNSOR_BACKEND=torch python examples/kalman_filter.py -n 2 -t 50 --lazy - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 50 --lazy - FUNSOR_BACKEND=torch FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 500 --lazy - FUNSOR_BACKEND=torch python examples/minipyro.py - FUNSOR_BACKEND=torch python examples/minipyro.py --jit - FUNSOR_BACKEND=torch python examples/slds.py -n 2 -t 50 - FUNSOR_BACKEND=torch python examples/pcfg.py --size 3 - FUNSOR_BACKEND=torch python examples/vae.py --smoke-test - FUNSOR_BACKEND=torch python examples/eeg_slds.py --num-steps 2 --fon --test - FUNSOR_BACKEND=torch python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --smoke - FUNSOR_BACKEND=torch python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --parallel --smoke - FUNSOR_BACKEND=torch python examples/sensor.py --seed=0 --num-frames=2 -n 1 +ifeq (${FUNSOR_BACKEND}, torch) + pytest -v -n auto test/ + FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py + FUNSOR_USE_TCO=1 pytest -v test/test_terms.py + FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py + python examples/discrete_hmm.py -n 2 + python examples/discrete_hmm.py -n 2 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/discrete_hmm.py -n 1 -t 500 --lazy + python examples/kalman_filter.py -n 2 + python examples/kalman_filter.py -n 2 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 50 --lazy + FUNSOR_USE_TCO=1 python examples/kalman_filter.py -n 1 -t 500 --lazy + python examples/minipyro.py + python examples/minipyro.py --jit + python examples/slds.py -n 2 -t 50 + python examples/pcfg.py --size 3 + python examples/vae.py --smoke-test + python examples/eeg_slds.py --num-steps 2 --fon --test + python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --smoke + python examples/mixed_hmm/experiment.py -d seal -i discrete -g discrete -zi --parallel --smoke + python examples/sensor.py --seed=0 --num-frames=2 -n 1 @echo PASS +else ifeq (${FUNSOR_BACKEND}, jax) + pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi + @echo PASS +else + # default backend + pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi + @echo PASS +endif clean: FORCE git clean -dfx -e funsor-egg.info diff --git a/funsor/montecarlo.py b/funsor/montecarlo.py index 3e6f42d5d..a26967fb6 100644 --- a/funsor/montecarlo.py +++ b/funsor/montecarlo.py @@ -44,6 +44,7 @@ def monte_carlo_interpretation(**sample_inputs): @monte_carlo.register(Integrate, Funsor, Funsor, frozenset) def monte_carlo_integrate(log_measure, integrand, reduced_vars): + # FIXME: how to pass rng_key to here? sample = log_measure.sample(reduced_vars, monte_carlo.sample_inputs) if sample is log_measure: return None # cannot progress diff --git a/funsor/testing.py b/funsor/testing.py index 5b0474bd2..c309fbd5f 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -78,6 +78,8 @@ def assert_close(actual, expected, atol=1e-6, rtol=1e-6): elif isinstance(actual, Contraction) and isinstance(actual.terms[0], Tensor) \ and is_array(actual.terms[0].data): assert isinstance(expected, Contraction) and is_array(expected.terms[0].data) + elif isinstance(actual, Gaussian) and is_array(actual.info_vec): + assert isinstance(expected, Gaussian) and is_array(expected.info_vec) else: assert type(actual) == type(expected), msg diff --git a/test/test_alpha_conversion.py b/test/test_alpha_conversion.py index 2749367eb..96bb72679 100644 --- a/test/test_alpha_conversion.py +++ b/test/test_alpha_conversion.py @@ -3,6 +3,7 @@ from collections import OrderedDict +import numpy as np import pytest import funsor.ops as ops @@ -10,13 +11,15 @@ from funsor.interpreter import gensym, interpretation, reinterpret from funsor.terms import Cat, Independent, Lambda, Number, Slice, Stack, Variable, reflect from funsor.testing import assert_close, check_funsor, random_tensor +from funsor.util import get_backend def test_sample_subs_smoke(): x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals()) with interpretation(reflect): z = x(i=1) - actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)})) + rng_key = None if get_backend() == "torch" else np.array([0, 1], dtype=np.uint32) + actual = z.sample(frozenset({"j"}), OrderedDict({"i": bint(4)}), rng_key=rng_key) check_funsor(actual, {"j": bint(2), "i": bint(4)}, reals()) diff --git a/test/test_cnf.py b/test/test_cnf.py index d5adc65c3..fc995e7fa 100644 --- a/test/test_cnf.py +++ b/test/test_cnf.py @@ -101,4 +101,4 @@ def test_eager_contract_tensor_tensor(red_op, bin_op, x_inputs, x_shape, y_input print(f"reduced_vars = {reduced_vars}") expected = xy.reduce(red_op, reduced_vars) actual = Contraction(red_op, bin_op, reduced_vars, (x, y)) - assert_close(actual, expected, atol=1e-4, rtol=5e-4 if backend == "jax" else 1e-4) + assert_close(actual, expected, atol=1e-4, rtol=1e-3 if backend == "jax" else 1e-4) diff --git a/test/test_integrate.py b/test/test_integrate.py index 513dd67b5..07f013131 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -12,10 +12,14 @@ from funsor.montecarlo import monte_carlo from funsor.terms import Variable, eager, lazy, moment_matching, normalize, reflect from funsor.testing import assert_close, random_tensor +from funsor.util import get_backend @pytest.mark.parametrize('interp', [ - reflect, lazy, normalize, eager, moment_matching, monte_carlo]) + reflect, lazy, normalize, eager, moment_matching, + pytest.param(monte_carlo, marks=pytest.mark.xfail( + get_backend() == "jax", reason="Lacking pattern to pass rng_key")) +]) def test_integrate(interp): log_measure = random_tensor(OrderedDict([('i', bint(2)), ('j', bint(3))])) integrand = random_tensor(OrderedDict([('j', bint(3)), ('k', bint(4))])) diff --git a/test/test_joint.py b/test/test_joint.py index 3e719c5d4..d506a0311 100644 --- a/test/test_joint.py +++ b/test/test_joint.py @@ -158,7 +158,7 @@ def test_reduce_logaddexp(int_inputs, real_inputs): actual = state.reduce(ops.logaddexp, frozenset(truth)) expected = t + g(**truth) - assert_close(actual, expected, atol=1e-5, rtol=1e-5) + assert_close(actual, expected, atol=1e-5, rtol=1e-4 if get_backend() == "jax" else 1e-5) def test_reduce_logaddexp_deltas_lazy(): diff --git a/test/test_memoize.py b/test/test_memoize.py index 0d6a82572..79206d534 100644 --- a/test/test_memoize.py +++ b/test/test_memoize.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +import numpy as np import pytest import funsor.ops as ops @@ -102,13 +103,16 @@ def test_memoize_sample(check_sample): else: from funsor.torch.distributions import Normal + rng_keys = (None, None, None) if get_backend() == "torch" \ + else np.array([[0, 1], [0, 2], [0, 3]], dtype=np.uint32) + with memoize(): m, s = numeric_array(0.), numeric_array(1.) j1 = Normal(m, s, 'x') j2 = Normal(m, s, 'x') - x1 = j1.sample(frozenset({'x'})) - x12 = j1.sample(frozenset({'x'})) - x2 = j2.sample(frozenset({'x'})) + x1 = j1.sample(frozenset({'x'}), rng_key=rng_keys[0]) + x12 = j1.sample(frozenset({'x'}), rng_key=rng_keys[1]) + x2 = j2.sample(frozenset({'x'}), rng_key=rng_keys[2]) # this assertion now passes assert j1 is j2 diff --git a/test/test_sum_product.py b/test/test_sum_product.py index 48ecd3d1f..4b9ab0dcc 100644 --- a/test/test_sum_product.py +++ b/test/test_sum_product.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import re +import os from collections import OrderedDict from functools import partial, reduce @@ -25,6 +26,9 @@ from funsor.tensor import Tensor, get_default_prototype from funsor.terms import Variable, eager_or_die, moment_matching, reflect from funsor.testing import assert_close, random_gaussian, random_tensor +from funsor.util import get_backend + +pytestmark = pytest.mark.skipif((get_backend() == 'jax') and ('CI' in os.environ), reason='slow tests') @pytest.mark.parametrize('inputs,dims,expected_num_components', [