Skip to content

Commit

Permalink
Separate failing conditions in test_jax_IncSubtensor
Browse files Browse the repository at this point in the history
Also avoid using `at.arange` in these tests as it always yields `ConcretizationTypeError`s in more recent versions of JAX
  • Loading branch information
Ricardo authored and brandonwillard committed Jun 15, 2022
1 parent 2e35e6c commit 2ccd9cc
Showing 1 changed file with 41 additions and 18 deletions.
59 changes: 41 additions & 18 deletions tests/link/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import pytest
from jax._src.errors import NonConcreteBooleanIndexError
from packaging.version import parse as version_parse

import aesara.scalar.basic as aes
Expand Down Expand Up @@ -674,15 +675,11 @@ def test_jax_Subtensors_omni():
compare_jax_and_py(out_fg, [])


@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_IncSubtensor():
rng = np.random.default_rng(213234)

x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))

# "Set" basic indices
st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
Expand All @@ -707,7 +704,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

Expand All @@ -717,14 +714,8 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

# "Set" boolean indices
mask_at = at.as_tensor_variable(x_np) > 0
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
Expand Down Expand Up @@ -753,7 +744,7 @@ def test_jax_IncSubtensor():
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

Expand All @@ -763,18 +754,50 @@ def test_jax_IncSubtensor():
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
# "Increment" boolean indices
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])

# "Increment" boolean indices

def test_jax_IncSubtensors_unsupported():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))

mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])

mask_at = at.as_tensor_variable(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])

st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])

st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])


def test_jax_ifelse():
Expand Down

0 comments on commit 2ccd9cc

Please sign in to comment.