1
1
from typing import TypeVar
2
2
3
3
import numpy as np
4
+ from numpy .random import Generator
4
5
5
6
import pytensor
6
7
from pytensor .graph .type import Type
7
8
8
9
9
- T = TypeVar ("T" , np . random . RandomState , np . random . Generator )
10
+ T = TypeVar ("T" )
10
11
11
12
12
13
gen_states_keys = {
24
25
25
26
26
27
class RandomType (Type [T ]):
27
- r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
28
-
29
- @staticmethod
30
- def may_share_memory (a : T , b : T ):
31
- return a ._bit_generator is b ._bit_generator # type: ignore[attr-defined]
28
+ r"""A Type wrapper for `numpy.random.Generator."""
32
29
33
30
34
- class RandomGeneratorType (RandomType [np . random . Generator ]):
31
+ class RandomGeneratorType (RandomType [Generator ]):
35
32
r"""A Type wrapper for `numpy.random.Generator`.
36
33
37
34
The reason this exists (and `Generic` doesn't suffice) is that
@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
47
44
def __repr__ (self ):
48
45
return "RandomGeneratorType"
49
46
47
+ @staticmethod
48
+ def may_share_memory (a : Generator , b : Generator ):
49
+ return a ._bit_generator is b ._bit_generator # type: ignore[attr-defined]
50
+
50
51
def filter (self , data , strict = False , allow_downcast = None ):
51
52
"""
52
53
XXX: This doesn't convert `data` to the same type of underlying RNG type
@@ -58,7 +59,7 @@ def filter(self, data, strict=False, allow_downcast=None):
58
59
`Type.filter`, we need to have it here to avoid surprising circular
59
60
dependencies in sub-classes.
60
61
"""
61
- if isinstance (data , np . random . Generator ):
62
+ if isinstance (data , Generator ):
62
63
return data
63
64
64
65
if not strict and isinstance (data , dict ):
0 commit comments