Skip to content

Commit b7bed61

Browse files
committed
Remove unnecessary handling of no longer supported RandomState
1 parent 4378d48 commit b7bed61

File tree

10 files changed

+22
-43
lines changed

10 files changed

+22
-43
lines changed

doc/extending/extending_pytensor_solution_1.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def setup_method(self):
118118
self.op_class = SumDiffOp
119119

120120
def test_perform(self):
121-
rng = np.random.RandomState(43)
121+
rng = np.random.default_rng(43)
122122
x = matrix()
123123
y = matrix()
124124
f = pytensor.function([x, y], self.op_class()(x, y))
@@ -128,7 +128,7 @@ def test_perform(self):
128128
assert np.allclose([x_val + y_val, x_val - y_val], out)
129129

130130
def test_gradient(self):
131-
rng = np.random.RandomState(43)
131+
rng = np.random.default_rng(43)
132132

133133
def output_0(x, y):
134134
return self.op_class()(x, y)[0]
@@ -150,7 +150,7 @@ def output_1(x, y):
150150
)
151151

152152
def test_infer_shape(self):
153-
rng = np.random.RandomState(43)
153+
rng = np.random.default_rng(43)
154154

155155
x = dmatrix()
156156
y = dmatrix()

doc/library/d3viz/index.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@
9595
"noutputs = 10\n",
9696
"nhiddens = 50\n",
9797
"\n",
98-
"rng = np.random.RandomState(0)\n",
98+
"rng = np.random.default_rng(0)\n",
9999
"x = pt.dmatrix('x')\n",
100100
"wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)\n",
101101
"bh = pytensor.shared(np.zeros(nhiddens), borrow=True)\n",

doc/library/d3viz/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ hidden layer and a softmax output layer.
5858
noutputs = 10
5959
nhiddens = 50
6060
61-
rng = np.random.RandomState(0)
61+
rng = np.random.default_rng(0)
6262
x = pt.dmatrix('x')
6363
wh = pytensor.shared(rng.normal(0, 1, (nfeatures, nhiddens)), borrow=True)
6464
bh = pytensor.shared(np.zeros(nhiddens), borrow=True)

doc/optimizations.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,7 @@ Optimization o4 o3 o2
239239
See :func:`insert_inplace_optimizer`
240240

241241
inplace_random
242-
Typically when a graph uses random numbers, the RandomState is stored
242+
Typically when a graph uses random numbers, the random Generator is stored
243243
in a shared variable, used once per call and, updated after each function
244244
call. In this common case, it makes sense to update the random number generator in-place.
245245

pytensor/compile/monitormode.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,7 @@ def detect_nan(fgraph, i, node, fn):
104104
from pytensor.printing import debugprint
105105

106106
for output in fn.outputs:
107-
if (
108-
not isinstance(output[0], np.random.RandomState | np.random.Generator)
109-
and np.isnan(output[0]).any()
110-
):
107+
if not isinstance(output[0], np.random.Generator) and np.isnan(output[0]).any():
111108
print("*** NaN detected ***") # noqa: T201
112109
debugprint(node)
113110
print(f"Inputs : {[input[0] for input in fn.inputs]}") # noqa: T201

pytensor/compile/nanguardmode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def _is_numeric_value(arr, var):
3434

3535
if isinstance(arr, _cdata_type):
3636
return False
37-
elif isinstance(arr, np.random.mtrand.RandomState | np.random.Generator):
37+
elif isinstance(arr, np.random.Generator):
3838
return False
3939
elif var is not None and isinstance(var.type, RandomType):
4040
return False

pytensor/link/jax/linker.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import warnings
22

3-
from numpy.random import Generator, RandomState
3+
from numpy.random import Generator
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared
66
from pytensor.link.basic import JITLinker
@@ -21,7 +21,7 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
2121

2222
# Replace any shared RNG inputs so that their values can be updated in place
2323
# without affecting the original RNG container. This is necessary because
24-
# JAX does not accept RandomState/Generators as inputs, and they will have to
24+
# JAX does not accept Generators as inputs, and they will have to
2525
# be typified
2626
if shared_rng_inputs:
2727
warnings.warn(
@@ -79,7 +79,7 @@ def create_thunk_inputs(self, storage_map):
7979
thunk_inputs = []
8080
for n in self.fgraph.inputs:
8181
sinput = storage_map[n]
82-
if isinstance(sinput[0], RandomState | Generator):
82+
if isinstance(sinput[0], Generator):
8383
new_value = jax_typify(
8484
sinput[0], dtype=getattr(sinput[0], "dtype", None)
8585
)

pytensor/link/numba/linker.py

+1-19
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,4 @@ def jit_compile(self, fn):
1616
return jitted_fn
1717

1818
def create_thunk_inputs(self, storage_map):
19-
from numpy.random import RandomState
20-
21-
from pytensor.link.numba.dispatch import numba_typify
22-
23-
thunk_inputs = []
24-
for n in self.fgraph.inputs:
25-
sinput = storage_map[n]
26-
if isinstance(sinput[0], RandomState):
27-
new_value = numba_typify(
28-
sinput[0], dtype=getattr(sinput[0], "dtype", None)
29-
)
30-
# We need to remove the reference-based connection to the
31-
# original `RandomState`/shared variable's storage, because
32-
# subsequent attempts to use the same shared variable within
33-
# other non-Numba-fied graphs will have problems.
34-
sinput = [new_value]
35-
thunk_inputs.append(sinput)
36-
37-
return thunk_inputs
19+
return [storage_map[n] for n in self.fgraph.inputs]

pytensor/tensor/random/type.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from typing import TypeVar
22

33
import numpy as np
4+
from numpy.random import Generator
45

56
import pytensor
67
from pytensor.graph.type import Type
78

89

9-
T = TypeVar("T", np.random.RandomState, np.random.Generator)
10+
T = TypeVar("T")
1011

1112

1213
gen_states_keys = {
@@ -24,14 +25,10 @@
2425

2526

2627
class RandomType(Type[T]):
27-
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
28-
29-
@staticmethod
30-
def may_share_memory(a: T, b: T):
31-
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
28+
r"""A Type wrapper for `numpy.random.Generator."""
3229

3330

34-
class RandomGeneratorType(RandomType[np.random.Generator]):
31+
class RandomGeneratorType(RandomType[Generator]):
3532
r"""A Type wrapper for `numpy.random.Generator`.
3633
3734
The reason this exists (and `Generic` doesn't suffice) is that
@@ -47,6 +44,10 @@ class RandomGeneratorType(RandomType[np.random.Generator]):
4744
def __repr__(self):
4845
return "RandomGeneratorType"
4946

47+
@staticmethod
48+
def may_share_memory(a: Generator, b: Generator):
49+
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
50+
5051
def filter(self, data, strict=False, allow_downcast=None):
5152
"""
5253
XXX: This doesn't convert `data` to the same type of underlying RNG type
@@ -58,7 +59,7 @@ def filter(self, data, strict=False, allow_downcast=None):
5859
`Type.filter`, we need to have it here to avoid surprising circular
5960
dependencies in sub-classes.
6061
"""
61-
if isinstance(data, np.random.Generator):
62+
if isinstance(data, Generator):
6263
return data
6364

6465
if not strict and isinstance(data, dict):

tests/unittest_tools.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ def fetch_seed(pseed=None):
2727
If config.unittest.rseed is set to "random", it will seed the rng with
2828
None, which is equivalent to seeding with a random seed.
2929
30-
Useful for seeding RandomState or Generator objects.
31-
>>> rng = np.random.RandomState(fetch_seed())
30+
Useful for seeding Generator objects.
3231
>>> rng = np.random.default_rng(fetch_seed())
3332
"""
3433

0 commit comments

Comments
 (0)