Skip to content

Commit 26009f7

Browse files
committed
Remove RandomState type
1 parent e9bf6f2 commit 26009f7

File tree

15 files changed

+76
-415
lines changed

15 files changed

+76
-415
lines changed

pytensor/link/jax/dispatch/random.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import jax
44
import numpy as np
5-
from numpy.random import Generator, RandomState
5+
from numpy.random import Generator
66
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
77
_coerce_to_uint32_array,
88
)
@@ -52,15 +52,6 @@ def assert_size_argument_jax_compatible(node):
5252
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
5353

5454

55-
@jax_typify.register(RandomState)
56-
def jax_typify_RandomState(state, **kwargs):
57-
state = state.get_state(legacy=False)
58-
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
59-
# XXX: Is this a reasonable approach?
60-
state["jax_state"] = state["state"]["key"][0:2]
61-
return state
62-
63-
6455
@jax_typify.register(Generator)
6556
def jax_typify_Generator(rng, **kwargs):
6657
state = rng.__getstate__()
@@ -184,7 +175,6 @@ def sample_fn(rng, size, dtype, *parameters):
184175
return sample_fn
185176

186177

187-
@jax_sample_fn.register(ptr.RandIntRV)
188178
@jax_sample_fn.register(ptr.IntegersRV)
189179
@jax_sample_fn.register(ptr.UniformRV)
190180
def jax_sample_fn_uniform(op):

pytensor/link/numba/dispatch/random.py

-4
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
)
2525
from pytensor.tensor import get_vector_length
2626
from pytensor.tensor.random.op import RandomVariable, RandomVariableWithCoreShape
27-
from pytensor.tensor.random.type import RandomStateType
2827
from pytensor.tensor.type_other import NoneTypeT
2928

3029

@@ -265,9 +264,6 @@ def numba_funcify_RandomVariable(op: RandomVariableWithCoreShape, node, **kwargs
265264

266265
[rv_node] = op.fgraph.apply_nodes
267266
rv_op: RandomVariable = rv_node.op
268-
rng_param = rv_op.rng_param(rv_node)
269-
if isinstance(rng_param.type, RandomStateType):
270-
raise TypeError("Numba does not support NumPy `RandomStateType`s")
271267
size = rv_op.size_param(rv_node)
272268
dist_params = rv_op.dist_params(rv_node)
273269
size_len = None if isinstance(size.type, NoneTypeT) else get_vector_length(size)

pytensor/tensor/random/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
import pytensor.tensor.random.rewriting
33
import pytensor.tensor.random.utils
44
from pytensor.tensor.random.basic import *
5-
from pytensor.tensor.random.op import RandomState, default_rng
5+
from pytensor.tensor.random.op import default_rng
66
from pytensor.tensor.random.utils import RandomStream

pytensor/tensor/random/basic.py

+25-75
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,9 @@
77
import pytensor
88
from pytensor.tensor.basic import arange, as_tensor_variable, constant
99
from pytensor.tensor.random.op import RandomVariable
10-
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
1110
from pytensor.tensor.random.utils import (
1211
broadcast_params,
1312
)
14-
from pytensor.tensor.random.var import (
15-
RandomGeneratorSharedVariable,
16-
RandomStateSharedVariable,
17-
)
1813

1914

2015
try:
@@ -605,7 +600,7 @@ def __call__(
605600
@classmethod
606601
def rng_fn_scipy(
607602
cls,
608-
rng: np.random.Generator | np.random.RandomState,
603+
rng: np.random.Generator,
609604
loc: np.ndarray | float,
610605
scale: np.ndarray | float,
611606
size: list[int] | int | None,
@@ -1548,7 +1543,7 @@ def __call__(self, n, p, size=None, **kwargs):
15481543
binomial = BinomialRV()
15491544

15501545

1551-
class NegBinomialRV(ScipyRandomVariable):
1546+
class NegBinomialRV(RandomVariable):
15521547
r"""A negative binomial discrete random variable.
15531548
15541549
The probability mass function for `nbinom` for the number :math:`k` of draws
@@ -1588,13 +1583,8 @@ def __call__(self, n, p, size=None, **kwargs):
15881583
"""
15891584
return super().__call__(n, p, size=size, **kwargs)
15901585

1591-
@classmethod
1592-
def rng_fn_scipy(cls, rng, n, p, size):
1593-
return stats.nbinom.rvs(n, p, size=size, random_state=rng)
1594-
15951586

1596-
nbinom = NegBinomialRV()
1597-
negative_binomial = NegBinomialRV()
1587+
negative_binomial = nbinom = NegBinomialRV()
15981588

15991589

16001590
class BetaBinomialRV(ScipyRandomVariable):
@@ -1842,58 +1832,6 @@ def rng_fn(cls, rng, p, size):
18421832
categorical = CategoricalRV()
18431833

18441834

1845-
class RandIntRV(RandomVariable):
1846-
r"""A discrete uniform random variable.
1847-
1848-
Only available for `RandomStateType`. Use `integers` with `RandomGeneratorType`\s.
1849-
1850-
"""
1851-
1852-
name = "randint"
1853-
signature = "(),()->()"
1854-
dtype = "int64"
1855-
_print_name = ("randint", "\\operatorname{randint}")
1856-
1857-
def __call__(self, low, high=None, size=None, **kwargs):
1858-
r"""Draw samples from a discrete uniform distribution.
1859-
1860-
Signature
1861-
---------
1862-
1863-
`() -> ()`
1864-
1865-
Parameters
1866-
----------
1867-
low
1868-
Lower boundary of the output interval. All values generated will
1869-
be greater than or equal to `low`, unless `high=None`, in which case
1870-
all values generated are greater than or equal to `0` and
1871-
smaller than `low` (exclusive).
1872-
high
1873-
Upper boundary of the output interval. All values generated
1874-
will be smaller than `high` (exclusive).
1875-
size
1876-
Sample shape. If the given size is `(m, n, k)`, then `m * n * k`
1877-
independent, identically distributed samples are
1878-
returned. Default is `None`, in which case a single
1879-
sample is returned.
1880-
1881-
"""
1882-
if high is None:
1883-
low, high = 0, low
1884-
return super().__call__(low, high, size=size, **kwargs)
1885-
1886-
def make_node(self, rng, *args, **kwargs):
1887-
if not isinstance(
1888-
getattr(rng, "type", None), RandomStateType | RandomStateSharedVariable
1889-
):
1890-
raise TypeError("`randint` is only available for `RandomStateType`s")
1891-
return super().make_node(rng, *args, **kwargs)
1892-
1893-
1894-
randint = RandIntRV()
1895-
1896-
18971835
class IntegersRV(RandomVariable):
18981836
r"""A discrete uniform random variable.
18991837
@@ -1933,14 +1871,6 @@ def __call__(self, low, high=None, size=None, **kwargs):
19331871
low, high = 0, low
19341872
return super().__call__(low, high, size=size, **kwargs)
19351873

1936-
def make_node(self, rng, *args, **kwargs):
1937-
if not isinstance(
1938-
getattr(rng, "type", None),
1939-
RandomGeneratorType | RandomGeneratorSharedVariable,
1940-
):
1941-
raise TypeError("`integers` is only available for `RandomGeneratorType`s")
1942-
return super().make_node(rng, *args, **kwargs)
1943-
19441874

19451875
integers = IntegersRV()
19461876

@@ -1974,7 +1904,28 @@ def rng_fn(self, *params):
19741904
p = None
19751905
else:
19761906
rng, a, p, replace, size = params
1977-
return rng.choice(a, size, replace, p)
1907+
1908+
batch_ndim = a.ndim - self.ndims_params[0]
1909+
1910+
if size is not None:
1911+
a = np.broadcast_to(a, size + a.shape[-self.ndims_params[0] :])
1912+
if p is not None:
1913+
p = np.broadcast_to(p, size + p.shape[-1:])
1914+
elif p is not None:
1915+
a, p = broadcast_params([a, p], self.ndims_params)
1916+
1917+
if batch_ndim:
1918+
# rng.choice does not have a concept of batch dimensionn
1919+
batch_shape = a.shape[:batch_ndim]
1920+
core_shape = a.shape[batch_ndim:-1]
1921+
out = np.empty(batch_shape + core_shape, dtype=a.dtype)
1922+
for idx in np.ndindex(batch_shape):
1923+
out[idx] = rng.choice(
1924+
a[idx], size=None, replace=replace, p=None if p is None else p[idx]
1925+
)
1926+
return out
1927+
else:
1928+
return rng.choice(a, size=size, replace=replace, p=p)
19781929

19791930

19801931
def choice(a, size=None, replace=True, p=None, rng=None):
@@ -2079,7 +2030,6 @@ def permutation(x, **kwargs):
20792030
"permutation",
20802031
"choice",
20812032
"integers",
2082-
"randint",
20832033
"categorical",
20842034
"multinomial",
20852035
"betabinom",

pytensor/tensor/random/op.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
infer_static_shape,
2020
)
2121
from pytensor.tensor.blockwise import OpWithCoreShape
22-
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
22+
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
2323
from pytensor.tensor.random.utils import (
2424
compute_batch_shape,
2525
explicit_expand_dims,
@@ -326,9 +326,8 @@ def make_node(self, rng, size, *dist_params):
326326
327327
Parameters
328328
----------
329-
rng: RandomGeneratorType or RandomStateType
330-
Existing PyTensor `Generator` or `RandomState` object to be used. Creates a
331-
new one, if `None`.
329+
rng: RandomGeneratorType
330+
Existing PyTensor `Generator` object to be used. Creates a new one, if `None`.
332331
size: int or Sequence
333332
NumPy-like size parameter.
334333
dtype: str
@@ -356,7 +355,7 @@ def make_node(self, rng, size, *dist_params):
356355
rng = pytensor.shared(np.random.default_rng())
357356
elif not isinstance(rng.type, RandomType):
358357
raise TypeError(
359-
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
358+
"The type of rng should be an instance of RandomGeneratorType "
360359
)
361360

362361
inferred_shape = self._infer_shape(size, dist_params)
@@ -435,14 +434,6 @@ def perform(self, node, inputs, output_storage):
435434
output_storage[0][0] = getattr(np.random, self.random_constructor)(seed=seed)
436435

437436

438-
class RandomStateConstructor(AbstractRNGConstructor):
439-
random_type = RandomStateType()
440-
random_constructor = "RandomState"
441-
442-
443-
RandomState = RandomStateConstructor()
444-
445-
446437
class DefaultGeneratorMakerOp(AbstractRNGConstructor):
447438
random_type = RandomGeneratorType()
448439
random_constructor = "default_rng"

pytensor/tensor/random/type.py

-91
Original file line numberDiff line numberDiff line change
@@ -31,97 +31,6 @@ def may_share_memory(a: T, b: T):
3131
return a._bit_generator is b._bit_generator # type: ignore[attr-defined]
3232

3333

34-
class RandomStateType(RandomType[np.random.RandomState]):
35-
r"""A Type wrapper for `numpy.random.RandomState`.
36-
37-
The reason this exists (and `Generic` doesn't suffice) is that
38-
`RandomState` objects that would appear to be equal do not compare equal
39-
with the ``==`` operator.
40-
41-
This `Type` also works with a ``dict`` derived from
42-
`RandomState.get_state(legacy=False)`, unless the ``strict`` argument to `Type.filter`
43-
is explicitly set to ``True``.
44-
45-
"""
46-
47-
def __repr__(self):
48-
return "RandomStateType"
49-
50-
def filter(self, data, strict: bool = False, allow_downcast=None):
51-
"""
52-
XXX: This doesn't convert `data` to the same type of underlying RNG type
53-
as `self`. It really only checks that `data` is of the appropriate type
54-
to be a valid `RandomStateType`.
55-
56-
In other words, it serves as a `Type.is_valid_value` implementation,
57-
but, because the default `Type.is_valid_value` depends on
58-
`Type.filter`, we need to have it here to avoid surprising circular
59-
dependencies in sub-classes.
60-
"""
61-
if isinstance(data, np.random.RandomState):
62-
return data
63-
64-
if not strict and isinstance(data, dict):
65-
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
66-
state_keys = ["key", "pos"]
67-
68-
for key in gen_keys:
69-
if key not in data:
70-
raise TypeError()
71-
72-
for key in state_keys:
73-
if key not in data["state"]:
74-
raise TypeError()
75-
76-
state_key = data["state"]["key"]
77-
if state_key.shape == (624,) and state_key.dtype == np.uint32:
78-
# TODO: Add an option to convert to a `RandomState` instance?
79-
return data
80-
81-
raise TypeError()
82-
83-
@staticmethod
84-
def values_eq(a, b):
85-
sa = a if isinstance(a, dict) else a.get_state(legacy=False)
86-
sb = b if isinstance(b, dict) else b.get_state(legacy=False)
87-
88-
def _eq(sa, sb):
89-
for key in sa:
90-
if isinstance(sa[key], dict):
91-
if not _eq(sa[key], sb[key]):
92-
return False
93-
elif isinstance(sa[key], np.ndarray):
94-
if not np.array_equal(sa[key], sb[key]):
95-
return False
96-
else:
97-
if sa[key] != sb[key]:
98-
return False
99-
100-
return True
101-
102-
return _eq(sa, sb)
103-
104-
def __eq__(self, other):
105-
return type(self) == type(other)
106-
107-
def __hash__(self):
108-
return hash(type(self))
109-
110-
111-
# Register `RandomStateType`'s C code for `ViewOp`.
112-
pytensor.compile.register_view_op_c_code(
113-
RandomStateType,
114-
"""
115-
Py_XDECREF(%(oname)s);
116-
%(oname)s = %(iname)s;
117-
Py_XINCREF(%(oname)s);
118-
""",
119-
1,
120-
)
121-
122-
random_state_type = RandomStateType()
123-
124-
12534
class RandomGeneratorType(RandomType[np.random.Generator]):
12635
r"""A Type wrapper for `numpy.random.Generator`.
12736

pytensor/tensor/random/utils.py

+1-9
Original file line numberDiff line numberDiff line change
@@ -209,9 +209,7 @@ def __init__(
209209
self,
210210
seed: int | None = None,
211211
namespace: ModuleType | None = None,
212-
rng_ctor: Literal[
213-
np.random.RandomState, np.random.Generator
214-
] = np.random.default_rng,
212+
rng_ctor: Literal[np.random.Generator] = np.random.default_rng,
215213
):
216214
if namespace is None:
217215
from pytensor.tensor.random import basic # pylint: disable=import-self
@@ -223,12 +221,6 @@ def __init__(
223221
self.default_instance_seed = seed
224222
self.state_updates = []
225223
self.gen_seedgen = np.random.SeedSequence(seed)
226-
227-
if isinstance(rng_ctor, type) and issubclass(rng_ctor, np.random.RandomState):
228-
# The legacy state does not accept `SeedSequence`s directly
229-
def rng_ctor(seed):
230-
return np.random.RandomState(np.random.MT19937(seed))
231-
232224
self.rng_ctor = rng_ctor
233225

234226
def __getattr__(self, obj):

0 commit comments

Comments
 (0)