Skip to content

Commit

Permalink
Add optional strict to Type.is_valid_value
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 14, 2022
1 parent 174117f commit 08c97f3
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 51 deletions.
4 changes: 2 additions & 2 deletions aesara/graph/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ def convert_variable(self, var: Variable) -> Optional[Variable]:

return None

def is_valid_value(self, data: D) -> bool:
def is_valid_value(self, data: D, strict: bool = True) -> bool:
"""Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`."""
try:
self.filter(data, strict=True)
self.filter(data, strict=strict)
return True
except (TypeError, ValueError):
return False
Expand Down
91 changes: 54 additions & 37 deletions aesara/tensor/random/type.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
from typing import Generic, TypeVar

import numpy as np

import aesara
from aesara.graph.type import Type


T = TypeVar("T", np.random.RandomState, np.random.Generator)


gen_states_keys = {
"MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
Expand All @@ -18,22 +23,15 @@
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}


class RandomType(Type):
class RandomType(Type, Generic[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""

@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data, strict):
return data
else:
raise TypeError()

@staticmethod
def may_share_memory(a, b):
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator


class RandomStateType(RandomType):
class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
Expand All @@ -49,28 +47,38 @@ class RandomStateType(RandomType):
def __repr__(self):
return "RandomStateType"

@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState):
return True
def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data

if not strict and isinstance(a, dict):
if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]

for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()

for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()

state_key = a["state"]["key"]
state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True
# TODO: Add an option to convert to a `RandomState` instance?
return data

return False
raise TypeError()

@staticmethod
def values_eq(a, b):
Expand Down Expand Up @@ -114,7 +122,7 @@ def __hash__(self):
random_state_type = RandomStateType()


class RandomGeneratorType(RandomType):
class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
Expand All @@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType):
def __repr__(self):
return "RandomGeneratorType"

@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.Generator):
return True
def filter(self, data, strict=False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomGeneratorType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.Generator):
return data

if not strict and isinstance(a, dict):
if "bit_generator" not in a:
return False
if not strict and isinstance(data, dict):
if "bit_generator" not in data:
raise TypeError()
else:
bit_gen_key = a["bit_generator"]
bit_gen_key = data["bit_generator"]

if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value)
Expand All @@ -148,16 +165,16 @@ def is_valid_value(a, strict):
gen_keys, state_keys = gen_states_keys[bit_gen_key]

for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()

for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()

return True
return data

return False
raise TypeError()

@staticmethod
def values_eq(a, b):
Expand Down
28 changes: 16 additions & 12 deletions tests/tensor/random/test_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,17 @@ def test_filter(self):
with pytest.raises(TypeError):
rng_type.filter(1)

rng = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.get_state(legacy=False)

rng["state"] = {}
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)

assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict["state"] = {}

rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False

rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False

def test_values_eq(self):

Expand Down Expand Up @@ -147,15 +149,17 @@ def test_filter(self):
with pytest.raises(TypeError):
rng_type.filter(1)

rng = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.__getstate__()

assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)

rng["state"] = {}
rng_dict["state"] = {}

assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False

rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False

def test_values_eq(self):

Expand Down

0 comments on commit 08c97f3

Please sign in to comment.