Skip to content

Commit

Permalink
Fix jax.tree_util future warnings (pyro-ppl#1464)
Browse files Browse the repository at this point in the history
* Fix jax.tree_util future warnings

* fix docker link

* Filter out warnings in contrib modules

* fix lint

* bypass isort in cvae

* Fix lint
  • Loading branch information
fehiepsi authored Aug 6, 2022
1 parent 47f56c9 commit fe01f02
Show file tree
Hide file tree
Showing 9 changed files with 21 additions and 10 deletions.
5 changes: 2 additions & 3 deletions docker/release/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04

# declare the image name
# note that this image uses Python 3.8
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
JAXLIB_CUDA=111
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04

# install python3 and pip on top of the base Ubuntu image
RUN apt update && \
Expand All @@ -21,4 +20,4 @@ ENV PATH=/root/.local/bin:$PATH
RUN pip3 install --user \
# we pull wheels from google's api as per https://github.com/google/jax#installation
# the pre-compiled wheels that google provides work for now. This may change in the future (and necessitate building from source)
numpyro[cuda${JAXLIB_CUDA}] -f https://storage.googleapis.com/jax-releases/jax_releases.html
numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
3 changes: 2 additions & 1 deletion examples/cvae-flax/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@

from data import load_dataset
import matplotlib.pyplot as plt
from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model
from train_baseline import train_baseline
from train_cvae import train_cvae

from numpyro.examples.datasets import MNIST

from models import BaselineNet, Decoder, Encoder, cvae_guide, cvae_model # isort:skip


def main(args):
train_init, train_fetch = load_dataset(
Expand Down
4 changes: 2 additions & 2 deletions examples/cvae-flax/train_baseline.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from models import cross_entropy_loss

from flax.training.train_state import TrainState
import jax
from jax import lax, numpy as jnp, random
import optax

from models import cross_entropy_loss # isort:skip


def create_train_state(model, x, learning_rate_fn):
params = model.init(random.PRNGKey(0), x)
Expand Down
3 changes: 2 additions & 1 deletion numpyro/contrib/control_flow/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from collections import OrderedDict
from functools import partial

from jax import device_put, lax, random, tree_flatten, tree_map, tree_unflatten
from jax import device_put, lax, random
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_map, tree_unflatten

from numpyro import handlers
from numpyro.ops.pytree import PytreeTrace
Expand Down
3 changes: 2 additions & 1 deletion numpyro/contrib/tfp/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from collections import namedtuple
import inspect

from jax import random, tree_map, vmap
from jax import random, vmap
from jax.flatten_util import ravel_pytree
import jax.numpy as jnp
from jax.tree_util import tree_map
import tensorflow_probability.substrates.jax as tfp

from numpyro.infer import init_to_uniform
Expand Down
4 changes: 2 additions & 2 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import numpy as np

import jax
from jax import device_get, jacfwd, lax, random, tree_flatten, value_and_grad
from jax import device_get, jacfwd, lax, random, value_and_grad
from jax.flatten_util import ravel_pytree
from jax.lax import broadcast_shapes
import jax.numpy as jnp
from jax.tree_util import tree_map
from jax.tree_util import tree_flatten, tree_map

import numpyro
from numpyro.distributions import constraints
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ skip=docs
filterwarnings = error
ignore:numpy.ufunc size changed,:RuntimeWarning
ignore:Using a non-tuple sequence:FutureWarning
ignore:jax.tree_structure is deprecated:FutureWarning
ignore:numpy.linalg support is experimental:UserWarning
ignore:scipy.linalg support is experimental:UserWarning
once:No GPU:UserWarning
Expand Down
4 changes: 4 additions & 0 deletions test/contrib/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
)


def haiku_model_by_shape(x, y):
import haiku as hk
Expand Down
4 changes: 4 additions & 0 deletions test/contrib/test_nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import numpyro.distributions as dist
from numpyro.distributions.transforms import AffineTransform, ExpTransform

pytestmark = pytest.mark.filterwarnings(
"ignore:jax.tree_.+ is deprecated:FutureWarning"
)


# Test helper to extract a few central moments from samples.
def get_moments(x):
Expand Down

0 comments on commit fe01f02

Please sign in to comment.