Skip to content

Commit 1dead4c

Browse files
committed
Remove RandomState type in remaining backends
1 parent ea96384 commit 1dead4c

File tree

15 files changed

+52
-394
lines changed

15 files changed

+52
-394
lines changed

pytensor/link/jax/dispatch/random.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import jax
44
import numpy as np
5-
from numpy.random import Generator, RandomState
5+
from numpy.random import Generator
66
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
77
_coerce_to_uint32_array,
88
)
@@ -54,15 +54,6 @@ def assert_size_argument_jax_compatible(node):
5454
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
5555

5656

57-
@jax_typify.register(RandomState)
58-
def jax_typify_RandomState(state, **kwargs):
59-
state = state.get_state(legacy=False)
60-
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
61-
# XXX: Is this a reasonable approach?
62-
state["jax_state"] = state["state"]["key"][0:2]
63-
return state
64-
65-
6657
@jax_typify.register(Generator)
6758
def jax_typify_Generator(rng, **kwargs):
6859
state = rng.__getstate__()
@@ -214,7 +205,6 @@ def sample_fn(rng, size, dtype, p):
214205
return sample_fn
215206

216207

217-
@jax_sample_fn.register(ptr.RandIntRV)
218208
@jax_sample_fn.register(ptr.IntegersRV)
219209
@jax_sample_fn.register(ptr.UniformRV)
220210
def jax_sample_fn_uniform(op, node):

pytensor/link/numba/dispatch/random.py

-4
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
)
2626
from pytensor.tensor import get_vector_length
2727
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
28-
from pytensor.tensor.random.type import RandomStateType
2928
from pytensor.tensor.type_other import NoneTypeT
3029
from pytensor.tensor.utils import _parse_gufunc_signature
3130

@@ -348,9 +347,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
348347

349348
[rv_node] = op.fgraph.apply_nodes
350349
rv_op: RandomVariable = rv_node.op
351-
rng_param = rv_op.rng_param(rv_node)
352-
if isinstance(rng_param.type, RandomStateType):
353-
raise TypeError("Numba does not support NumPy `RandomStateType`s")
354350
size = rv_op.size_param(rv_node)
355351
dist_params = rv_op.dist_params(rv_node)
356352
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)

pytensor/tensor/random/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
import pytensor.tensor.random.rewriting
33
import pytensor.tensor.random.utils
44
from pytensor.tensor.random.basic import *
5-
from pytensor.tensor.random.op import RandomState, default_rng
5+
from pytensor.tensor.random.op import default_rng
66
from pytensor.tensor.random.utils import RandomStream

pytensor/tensor/random/basic.py

+1-67
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,10 @@
99
from pytensor.tensor.basic import as_tensor_variable
1010
from pytensor.tensor.math import sqrt
1111
from pytensor.tensor.random.op import RandomVariable
12-
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
1312
from pytensor.tensor.random.utils import (
1413
broadcast_params,
1514
normalize_size_param,
1615
)
17-
from pytensor.tensor.random.var import (
18-
RandomGeneratorSharedVariable,
19-
RandomStateSharedVariable,
20-
)
2116

2217

2318
try:
@@ -645,7 +640,7 @@ def __call__(
645640
@classmethod
646641
def rng_fn_scipy(
647642
cls,
648-
rng: np.random.Generator | np.random.RandomState,
643+
rng: np.random.Generator,
649644
loc: np.ndarray | float,
650645
scale: np.ndarray | float,
651646
size: list[int] | int | None,
@@ -1880,58 +1875,6 @@ def rng_fn(cls, rng, p, size):
18801875
categorical = CategoricalRV()
18811876

18821877

1883-
class RandIntRV(RandomVariable):
1884-
r"""A discrete uniform random variable.
1885-
1886-
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
1887-
1888-
"""
1889-
1890-
name = "randint"
1891-
signature = "(),()->()"
1892-
dtype = "int64"
1893-
_print_name = ("randint", "\\operatorname{randint}")
1894-
1895-
def __call__(self, low, high=None, size=None, **kwargs):
1896-
r"""Draw samples from a discrete uniform distribution.
1897-
1898-
Signature
1899-
---------
1900-
1901-
`() -> ()`
1902-
1903-
Parameters
1904-
----------
1905-
low
1906-
Lower boundary of the output interval. All values generated will
1907-
be greater than or equal to `low`, unless `high=None`, in which case
1908-
all values generated are greater than or equal to `0` and
1909-
smaller than `low` (exclusive).
1910-
high
1911-
Upper boundary of the output interval. All values generated
1912-
will be smaller than `high` (exclusive).
1913-
size
1914-
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
1915-
independent, identically distributed samples are
1916-
returned. Default is `None`, in which case a single
1917-
sample is returned.
1918-
1919-
"""
1920-
if high is None:
1921-
low, high = 0, low
1922-
return super().__call__(low, high, size=size, **kwargs)
1923-
1924-
def make_node(self, rng, *args, **kwargs):
1925-
if not isinstance(
1926-
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
1927-
):
1928-
raise TypeError("`randint` is only available for `RandomStateType`s")
1929-
return super().make_node(rng, *args, **kwargs)
1930-
1931-
1932-
randint = RandIntRV()
1933-
1934-
19351878
class IntegersRV(RandomVariable):
19361879
r"""A discrete uniform random variable.
19371880
@@ -1971,14 +1914,6 @@ def __call__(self, low, high=None, size=None, **kwargs):
19711914
low, high = 0, low
19721915
return super().__call__(low, high, size=size, **kwargs)
19731916

1974-
def make_node(self, rng, *args, **kwargs):
1975-
if not isinstance(
1976-
getattr(rng, "type", None),
1977-
RandomGeneratorType | RandomGeneratorSharedVariable,
1978-
):
1979-
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
1980-
return super().make_node(rng, *args, **kwargs)
1981-
19821917

19831918
integers = IntegersRV()
19841919

@@ -2201,7 +2136,6 @@ def permutation(x, **kwargs):
22012136
"permutation",
22022137
"choice",
22032138
"integers",
2204-
"randint",
22052139
"categorical",
22062140
"multinomial",
22072141
"betabinom",

pytensor/tensor/random/op.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
infer_static_shape,
2121
)
2222
from pytensor.tensor.blockwise import OpWithCoreShape
23-
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
23+
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
2424
from pytensor.tensor.random.utils import (
2525
compute_batch_shape,
2626
explicit_expand_dims,
@@ -324,9 +324,8 @@ def make_node(self, rng, size, *dist_params):
324324
325325
Parameters
326326
----------
327-
rng: RandomGeneratorType or RandomStateType
328-
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
329-
new one, if `None`.
327+
rng: RandomGeneratorType
328+
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
330329
size: int or Sequence
331330
NumPy-like size parameter.
332331
dtype: str
@@ -354,7 +353,7 @@ def make_node(self, rng, size, *dist_params):
354353
rng = pytensor.shared(np.random.default_rng())
355354
elif not isinstance(rng.type, RandomType):
356355
raise TypeError(
357-
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
356+
"The type of rng should be an instance of RandomGeneratorType "
358357
)
359358

360359
inferred_shape = self._infer_shape(size, dist_params)
@@ -436,14 +435,6 @@ def perform(self, node, inputs, output_storage):
436435
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)
437436

438437

439-
class RandomStateConstructor(AbstractRNGConstructor):
440-
random_type = RandomStateType()
441-
random_constructor = "RandomState"
442-
443-
444-
RandomState = RandomStateConstructor()
445-
446-
447438
class DefaultGeneratorMakerOp(AbstractRNGConstructor):
448439
random_type = RandomGeneratorType()
449440
random_constructor = "default_rng"

pytensor/tensor/random/type.py

-91
Original file line numberDiff line numberDiff line change
@@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T):
3131
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
3232

3333

34-
class RandomStateType(RandomType[np.random.RandomState]):
35-
r"""A Type wrapper for `numpy.random.RandomState`.
36-
37-
The reason this exists (and `Generic` doesn't suffice) is that
38-
`RandomState` objects that would appear to be equal do not compare equal
39-
with the ``==`` operator.
40-
41-
This `Type` also works with a ``dict`` derived from
42-
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
43-
is explicitly set to ``True``.
44-
45-
"""
46-
47-
def __repr__(self):
48-
return "RandomStateType"
49-
50-
def filter(self, data, strict: bool = False, allow_downcast=None):
51-
"""
52-
XXX: This doesn't convert `data` to the same type of underlying RNG type
53-
as `self`. It really only checks that `data` is of the appropriate type
54-
to be a valid `RandomStateType`.
55-
56-
In other words, it serves as a `Type.is_valid_value` implementation,
57-
but, because the default `Type.is_valid_value` depends on
58-
`Type.filter`, we need to have it here to avoid surprising circular
59-
dependencies in sub-classes.
60-
"""
61-
if isinstance(data, np.random.RandomState):
62-
return data
63-
64-
if not strict and isinstance(data, dict):
65-
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
66-
state_keys = ["key", "pos"]
67-
68-
for key in gen_keys:
69-
if key not in data:
70-
raise TypeError()
71-
72-
for key in state_keys:
73-
if key not in data["state"]:
74-
raise TypeError()
75-
76-
state_key = data["state"]["key"]
77-
if state_key.shape == (624,) and state_key.dtype == np.uint32:
78-
# TODO: Add an option to convert to a `RandomState` instance?
79-
return data
80-
81-
raise TypeError()
82-
83-
@staticmethod
84-
def values_eq(a, b):
85-
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
86-
sb = b if isinstance(b, dict) else b.get_state(legacy=False)
87-
88-
def _eq(sa, sb):
89-
for key in sa:
90-
if isinstance(sa[key], dict):
91-
if not _eq(sa[key], sb[key]):
92-
return False
93-
elif isinstance(sa[key], np.ndarray):
94-
if not np.array_equal(sa[key], sb[key]):
95-
return False
96-
else:
97-
if sa[key] != sb[key]:
98-
return False
99-
100-
return True
101-
102-
return _eq(sa, sb)
103-
104-
def __eq__(self, other):
105-
return type(self) == type(other)
106-
107-
def __hash__(self):
108-
return hash(type(self))
109-
110-
111-
# Register `RandomStateType`'s C code for `ViewOp`.
112-
pytensor.compile.register_view_op_c_code(
113-
RandomStateType,
114-
"""
115-
Py_XDECREF(%(oname)s);
116-
%(oname)s = %(iname)s;
117-
Py_XINCREF(%(oname)s);
118-
""",
119-
1,
120-
)
121-
122-
random_state_type = RandomStateType()
123-
124-
12534
class RandomGeneratorType(RandomType[np.random.Generator]):
12635
r"""A Type wrapper for `numpy.random.Generator`.
12736

pytensor/tensor/random/utils.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,7 @@ def __init__(
209209
self,
210210
seed: int | None = None,
211211
namespace: ModuleType | None = None,
212-
rng_ctor: Literal[
213-
np.random.RandomState, np.random.Generator
214-
] = np.random.default_rng,
212+
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
215213
):
216214
if namespace is None:
217215
from pytensor.tensor.random import basic # pylint: disable=import-self
@@ -223,12 +221,6 @@ def __init__(
223221
self.default_instance_seed = seed
224222
self.state_updates = []
225223
self.gen_seedgen = np.random.SeedSequence(seed)
226-
227-
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
228-
# The legacy state does not accept `SeedSequence`s directly
229-
def rng_ctor(seed):
230-
return np.random.RandomState(np.random.MT19937(seed))
231-
232224
self.rng_ctor = rng_ctor
233225

234226
def __getattr__(self, obj):

pytensor/tensor/random/var.py

+8-12
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,12 @@
33
import numpy as np
44

55
from pytensor.compile.sharedvalue import SharedVariable, shared_constructor
6-
from pytensor.tensor.random.type import random_generator_type, random_state_type
7-
8-
9-
class RandomStateSharedVariable(SharedVariable):
10-
def __str__(self):
11-
return self.name or f"RandomStateSharedVariable({self.container!r})"
6+
from pytensor.tensor.random.type import random_generator_type
127

138

149
class RandomGeneratorSharedVariable(SharedVariable):
1510
def __str__(self):
16-
return self.name or f"RandomGeneratorSharedVariable({self.container!r})"
11+
return self.name or f"RNG({self.container!r})"
1712

1813

1914
@shared_constructor.register(np.random.RandomState)
@@ -23,11 +18,12 @@ def randomgen_constructor(
2318
):
2419
r"""`SharedVariable` constructor for NumPy's `Generator` and/or `RandomState`."""
2520
if isinstance(value, np.random.RandomState):
26-
rng_sv_type = RandomStateSharedVariable
27-
rng_type = random_state_type
28-
elif isinstance(value, np.random.Generator):
29-
rng_sv_type = RandomGeneratorSharedVariable
30-
rng_type = random_generator_type
21+
raise TypeError(
22+
"`np.RandomState` is no longer supported in PyTensor. Use `np.random.Generator` instead."
23+
)
24+
25+
rng_sv_type = RandomGeneratorSharedVariable
26+
rng_type = random_generator_type
3127

3228
if not borrow:
3329
value = copy.deepcopy(value)

0 commit comments

Comments
 (0)