diff --git a/aesara/graph/type.py b/aesara/graph/type.py index f26d726642..f35a6d969f 100644 --- a/aesara/graph/type.py +++ b/aesara/graph/type.py @@ -180,10 +180,10 @@ def convert_variable(self, var: Variable) -> Optional[Variable]: return None - def is_valid_value(self, data: D) -> bool: + def is_valid_value(self, data: D, strict: bool = True) -> bool: """Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`.""" try: - self.filter(data, strict=True) + self.filter(data, strict=strict) return True except (TypeError, ValueError): return False diff --git a/aesara/tensor/random/type.py b/aesara/tensor/random/type.py index ff00669c4d..68373b08c2 100644 --- a/aesara/tensor/random/type.py +++ b/aesara/tensor/random/type.py @@ -1,9 +1,14 @@ +from typing import Generic, TypeVar + import numpy as np import aesara from aesara.graph.type import Type +T = TypeVar("T", np.random.RandomState, np.random.Generator) + + gen_states_keys = { "MT19937": (["state"], ["key", "pos"]), "PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]), @@ -18,22 +23,15 @@ numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"} -class RandomType(Type): +class RandomType(Type, Generic[T]): r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""" - @classmethod - def filter(cls, data, strict=False, allow_downcast=None): - if cls.is_valid_value(data, strict): - return data - else: - raise TypeError() - @staticmethod - def may_share_memory(a, b): + def may_share_memory(a: T, b: T): return a._bit_generator is b._bit_generator -class RandomStateType(RandomType): +class RandomStateType(RandomType[np.random.RandomState]): r"""A Type wrapper for `numpy.random.RandomState`. The reason this exists (and `Generic` doesn't suffice) is that @@ -49,28 +47,38 @@ class RandomStateType(RandomType): def __repr__(self): return "RandomStateType" - @staticmethod - def is_valid_value(a, strict): - if isinstance(a, np.random.RandomState): - return True + def filter(self, data, strict: bool = False, allow_downcast=None): + """ + XXX: This doesn't convert `data` to the same type of underlying RNG type + as `self`. It really only checks that `data` is of the appropriate type + to be a valid `RandomStateType`. + + In other words, it serves as a `Type.is_valid_value` implementation, + but, because the default `Type.is_valid_value` depends on + `Type.filter`, we need to have it here to avoid surprising circular + dependencies in sub-classes. + """ + if isinstance(data, np.random.RandomState): + return data - if not strict and isinstance(a, dict): + if not strict and isinstance(data, dict): gen_keys = ["bit_generator", "gauss", "has_gauss", "state"] state_keys = ["key", "pos"] for key in gen_keys: - if key not in a: - return False + if key not in data: + raise TypeError() for key in state_keys: - if key not in a["state"]: - return False + if key not in data["state"]: + raise TypeError() - state_key = a["state"]["key"] + state_key = data["state"]["key"] if state_key.shape == (624,) and state_key.dtype == np.uint32: - return True + # TODO: Add an option to convert to a `RandomState` instance? + return data - return False + raise TypeError() @staticmethod def values_eq(a, b): @@ -114,7 +122,7 @@ def __hash__(self): random_state_type = RandomStateType() -class RandomGeneratorType(RandomType): +class RandomGeneratorType(RandomType[np.random.Generator]): r"""A Type wrapper for `numpy.random.Generator`. The reason this exists (and `Generic` doesn't suffice) is that @@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType): def __repr__(self): return "RandomGeneratorType" - @staticmethod - def is_valid_value(a, strict): - if isinstance(a, np.random.Generator): - return True + def filter(self, data, strict=False, allow_downcast=None): + """ + XXX: This doesn't convert `data` to the same type of underlying RNG type + as `self`. It really only checks that `data` is of the appropriate type + to be a valid `RandomGeneratorType`. + + In other words, it serves as a `Type.is_valid_value` implementation, + but, because the default `Type.is_valid_value` depends on + `Type.filter`, we need to have it here to avoid surprising circular + dependencies in sub-classes. + """ + if isinstance(data, np.random.Generator): + return data - if not strict and isinstance(a, dict): - if "bit_generator" not in a: - return False + if not strict and isinstance(data, dict): + if "bit_generator" not in data: + raise TypeError() else: - bit_gen_key = a["bit_generator"] + bit_gen_key = data["bit_generator"] if hasattr(bit_gen_key, "_value"): bit_gen_key = int(bit_gen_key._value) @@ -148,16 +165,16 @@ def is_valid_value(a, strict): gen_keys, state_keys = gen_states_keys[bit_gen_key] for key in gen_keys: - if key not in a: - return False + if key not in data: + raise TypeError() for key in state_keys: - if key not in a["state"]: - return False + if key not in data["state"]: + raise TypeError() - return True + return data - return False + raise TypeError() @staticmethod def values_eq(a, b): diff --git a/tests/tensor/random/test_type.py b/tests/tensor/random/test_type.py index 5f45af2fe0..a34d4cbadc 100644 --- a/tests/tensor/random/test_type.py +++ b/tests/tensor/random/test_type.py @@ -56,15 +56,17 @@ def test_filter(self): with pytest.raises(TypeError): rng_type.filter(1) - rng = rng.get_state(legacy=False) - assert rng_type.is_valid_value(rng, strict=False) + rng_dict = rng.get_state(legacy=False) - rng["state"] = {} + assert rng_type.is_valid_value(rng_dict) is False + assert rng_type.is_valid_value(rng_dict, strict=False) - assert rng_type.is_valid_value(rng, strict=False) is False + rng_dict["state"] = {} - rng = {} - assert rng_type.is_valid_value(rng, strict=False) is False + assert rng_type.is_valid_value(rng_dict, strict=False) is False + + rng_dict = {} + assert rng_type.is_valid_value(rng_dict, strict=False) is False def test_values_eq(self): @@ -147,15 +149,17 @@ def test_filter(self): with pytest.raises(TypeError): rng_type.filter(1) - rng = rng.__getstate__() - assert rng_type.is_valid_value(rng, strict=False) + rng_dict = rng.__getstate__() + + assert rng_type.is_valid_value(rng_dict) is False + assert rng_type.is_valid_value(rng_dict, strict=False) - rng["state"] = {} + rng_dict["state"] = {} - assert rng_type.is_valid_value(rng, strict=False) is False + assert rng_type.is_valid_value(rng_dict, strict=False) is False - rng = {} - assert rng_type.is_valid_value(rng, strict=False) is False + rng_dict = {} + assert rng_type.is_valid_value(rng_dict, strict=False) is False def test_values_eq(self):