diff --git a/tests/link/test_jax.py b/tests/link/test_jax.py index 5c1576c7d7..b5303978f4 100644 --- a/tests/link/test_jax.py +++ b/tests/link/test_jax.py @@ -3,6 +3,7 @@ import numpy as np import pytest +from jax._src.errors import NonConcreteBooleanIndexError from packaging.version import parse as version_parse import aesara.scalar.basic as aes @@ -674,15 +675,11 @@ def test_jax_Subtensors_omni(): compare_jax_and_py(out_fg, []) -@pytest.mark.xfail( - version_parse(jax.__version__) >= version_parse("0.2.12"), - reason="Omnistaging cannot be disabled", -) def test_jax_IncSubtensor(): rng = np.random.default_rng(213234) x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) - x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX) + x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) # "Set" basic indices st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX)) @@ -707,7 +704,7 @@ def test_jax_IncSubtensor(): rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) ) out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, []) @@ -717,14 +714,8 @@ def test_jax_IncSubtensor(): out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, []) - st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) - out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) - out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) - # "Set" boolean indices - mask_at = at.as_tensor_variable(x_np) > 0 + mask_at = at.constant(x_np > 0) out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_at]) @@ -753,7 +744,7 @@ def test_jax_IncSubtensor(): rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX) ) out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at) - assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor1) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, []) @@ -763,18 +754,50 @@ def test_jax_IncSubtensor(): out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, []) - st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) - out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) + # "Increment" boolean indices + mask_at = at.constant(x_np > 0) + out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_at]) compare_jax_and_py(out_fg, []) - # "Increment" boolean indices + +def test_jax_IncSubtensors_unsupported(): + rng = np.random.default_rng(213234) + x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX) + x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX)) + + mask_at = at.as_tensor(x_np) > 0 + out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises( + NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" + ): + compare_jax_and_py(out_fg, []) + mask_at = at.as_tensor_variable(x_np) > 0 out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0) assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) out_fg = FunctionGraph([], [out_at]) - compare_jax_and_py(out_fg, []) + with pytest.raises( + NonConcreteBooleanIndexError, match="Array boolean indices must be concrete" + ): + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises(IndexError, match="Array slice indices must have static"): + compare_jax_and_py(out_fg, []) + + st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3]) + out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at) + assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor) + out_fg = FunctionGraph([], [out_at]) + with pytest.raises(IndexError, match="Array slice indices must have static"): + compare_jax_and_py(out_fg, []) def test_jax_ifelse():