Commit eea4186

Remove RandomVariable dtype input
1 parent d45f026 commit eea4186

7 files changed, +74 -61 lines changed
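
In short, this commit drops the symbolic `dtype` input from `RandomVariable` nodes: node inputs go from `(rng, size, dtype, *dist_params)` to `(rng, size, *dist_params)`, and the output dtype becomes a static property of the Op. The short sketch below is not part of the commit; it is a hedged illustration, assuming a pytensor build that includes this change, of the new input layout and of the `dist_params` helper that replaces positional `node.inputs[3:]` indexing throughout the diff.

# Hedged sketch, assuming a pytensor build that includes this commit.
import pytensor.tensor as pt

x = pt.random.normal(0, 1, size=(3,))
node = x.owner

# Node inputs no longer carry a dtype entry.
rng, size, *dist_params = node.inputs
assert dist_params[0] is node.op.dist_params(node)[0]  # helper now points at inputs[2:]

# The output dtype is fixed by the Op itself rather than by a symbolic input.
print(node.op.dtype, x.dtype)  # e.g. "float64" "float64", depending on config.floatX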

pytensor/link/jax/dispatch/random.py

+2 -2

@@ -102,12 +102,12 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     if None in out_size:
         assert_size_argument_jax_compatible(node)
 
-        def sample_fn(rng, size, dtype, *parameters):
+        def sample_fn(rng, size, *parameters):
             return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
 
     else:
 
-        def sample_fn(rng, size, dtype, *parameters):
+        def sample_fn(rng, size, *parameters):
             return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
 
     return sample_fn

pytensor/link/numba/dispatch/random.py

+15 -13

@@ -96,11 +96,13 @@ def make_numba_random_fn(node, np_random_func):
     The functions generated here add parameter broadcasting and the ``size``
     argument to the Numba-supported scalar ``np.random`` functions.
     """
-    if not isinstance(node.inputs[0].type, RandomStateType):
+    rng_param = node.op.rng_param(node)
+    if not isinstance(rng_param.type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")
 
-    tuple_size = int(get_vector_length(node.inputs[1]))
-    size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
+    tuple_size = int(get_vector_length(node.op.size_param(node)))
+    dist_params = node.op.dist_params(node)
+    size_dims = tuple_size - max(i.ndim for i in dist_params)
 
     # Make a broadcast-capable version of the Numba supported scalar sampling
     # function
@@ -120,13 +122,12 @@ def make_numba_random_fn(node, np_random_func):
             "size_dims",
             "rng",
             "size",
-            "dtype",
         ],
         suffix_sep="_",
     )
 
     bcast_fn_input_names = ", ".join(
-        [unique_names(i, force_unique=True) for i in node.inputs[3:]]
+        [unique_names(i, force_unique=True) for i in dist_params]
     )
     bcast_fn_global_env = {
         "np_random_func": np_random_func,
@@ -143,7 +144,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
     )
 
     random_fn_input_names = ", ".join(
-        ["rng", "size", "dtype"] + [unique_names(i) for i in node.inputs[3:]]
+        ["rng", "size"] + [unique_names(i) for i in dist_params]
     )
 
     # Now, create a Numba JITable function that implements the `size` parameter
@@ -241,11 +242,12 @@ def create_numba_random_fn(
     np_global_env["numba_vectorize"] = numba_basic.numba_vectorize
 
     unique_names = unique_name_generator(
-        [np_random_fn_name, *np_global_env.keys(), "rng", "size", "dtype"],
+        [np_random_fn_name, *np_global_env.keys(), "rng", "size"],
         suffix_sep="_",
     )
 
-    np_names = [unique_names(i, force_unique=True) for i in node.inputs[3:]]
+    dist_params = node.op.dist_params(node)
+    np_names = [unique_names(i, force_unique=True) for i in dist_params]
     np_input_names = ", ".join(np_names)
     np_random_fn_src = f"""
 @numba_vectorize
@@ -307,7 +309,7 @@ def numba_funcify_CategoricalRV(op, node, **kwargs):
     p_ndim = node.inputs[-1].ndim
 
     @numba_basic.numba_njit
-    def categorical_rv(rng, size, dtype, p):
+    def categorical_rv(rng, size, p):
         if not size_len:
             size_tpl = p.shape[:-1]
         else:
@@ -332,14 +334,14 @@ def categorical_rv(rng, size, dtype, p):
 @numba_funcify.register(ptr.DirichletRV)
 def numba_funcify_DirichletRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    alphas_ndim = node.inputs[3].type.ndim
+    alphas_ndim = node.op.dist_params(node)[0].type.ndim
     neg_ind_shape_len = -alphas_ndim + 1
-    size_len = int(get_vector_length(node.inputs[1]))
+    size_len = int(get_vector_length(node.op.size_param(node)))
 
     if alphas_ndim > 1:
 
         @numba_basic.numba_njit
-        def dirichlet_rv(rng, size, dtype, alphas):
+        def dirichlet_rv(rng, size, alphas):
             if size_len > 0:
                 size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                 if (
@@ -362,7 +364,7 @@ def dirichlet_rv(rng, size, dtype, alphas):
     else:
 
         @numba_basic.numba_njit
-        def dirichlet_rv(rng, size, dtype, alphas):
+        def dirichlet_rv(rng, size, alphas):
             size = numba_ndarray.to_fixed_tuple(size, size_len)
             return (rng, np.random.dirichlet(alphas, size))
 

pytensor/tensor/random/op.py

+33 -24

@@ -26,7 +26,7 @@
     normalize_size_param,
 )
 from pytensor.tensor.shape import shape_tuple
-from pytensor.tensor.type import TensorType, all_dtypes
+from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import NoneConst
 from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable
@@ -64,7 +64,7 @@ def __init__(
         signature: str
             Numpy-like vectorized signature of the random variable.
         dtype: str (optional)
-            The dtype of the sampled output. If the value ``"floatX"`` is
+            The default dtype of the sampled output. If the value ``"floatX"`` is
             given, then ``dtype`` is set to ``pytensor.config.floatX``. If
             ``None`` (the default), the `dtype` keyword must be set when
             `RandomVariable.make_node` is called.
@@ -289,8 +289,8 @@ def extract_batch_shape(p, ps, n):
         return shape
 
     def infer_shape(self, fgraph, node, input_shapes):
-        _, size, _, *dist_params = node.inputs
-        _, size_shape, _, *param_shapes = input_shapes
+        _, size, *dist_params = node.inputs
+        _, size_shape, *param_shapes = input_shapes
 
         try:
             size_len = get_vector_length(size)
@@ -304,14 +304,34 @@ def infer_shape(self, fgraph, node, input_shapes):
         return [None, list(shape)]
 
     def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
-        res = super().__call__(rng, size, dtype, *args, **kwargs)
+        if dtype is None:
+            dtype = self.dtype
+        if dtype == "floatX":
+            dtype = config.floatX
+
+        # We need to recreate the Op with the right dtype
+        if dtype != self.dtype:
+            # Check we are not switching from float to int
+            if self.dtype is not None:
+                if dtype.startswith("float") != self.dtype.startswith("float"):
+                    raise ValueError(
+                        f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
+                    )
+            props = self._props_dict()
+            props["dtype"] = dtype
+            new_op = type(self)(**props)
+            return new_op.__call__(
+                *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
+            )
+
+        res = super().__call__(rng, size, *args, **kwargs)
 
         if name is not None:
             res.name = name
 
         return res
 
-    def make_node(self, rng, size, dtype, *dist_params):
+    def make_node(self, rng, size, *dist_params):
        """Create a random variable node.
 
         Parameters
@@ -351,22 +371,11 @@ def make_node(self, rng, size, dtype, *dist_params):
 
         shape = self._infer_shape(size, dist_params)
         _, static_shape = infer_static_shape(shape)
-        dtype = self.dtype or dtype
 
-        if dtype == "floatX":
-            dtype = config.floatX
-        elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes):
-            raise TypeError("dtype is unspecified")
-
-        if isinstance(dtype, str):
-            dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
-        else:
-            dtype_idx = constant(dtype, dtype="int64")
-
-        dtype = all_dtypes[dtype_idx.data]
-
-        inputs = (rng, size, dtype_idx, *dist_params)
+        dtype = self.dtype
         out_var = TensorType(dtype=dtype, shape=static_shape)()
+
+        inputs = (rng, size, *dist_params)
         outputs = (rng.type(), out_var)
 
         return Apply(self, inputs, outputs)
@@ -381,12 +390,12 @@ def size_param(self, node) -> Variable:
 
     def dist_params(self, node) -> Sequence[Variable]:
         """Return the node inpust corresponding to dist params"""
-        return node.inputs[3:]
+        return node.inputs[2:]
 
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
 
-        rng, size, dtype, *args = inputs
+        rng, size, *args = inputs
 
         out_var = node.outputs[1]
 
@@ -462,7 +471,7 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):
 
 @_vectorize_node.register(RandomVariable)
 def vectorize_random_variable(
-    op: RandomVariable, node: Apply, rng, size, dtype, *new_dist_params
+    op: RandomVariable, node: Apply, rng, size, *new_dist_params
 ) -> Apply:
     # If size was provided originally and a new size hasn't been provided,
     # We extend it to accommodate the new input batch dimensions.
@@ -494,4 +503,4 @@ def vectorize_random_variable(
         new_size_dims = new_size[:new_ndim]
         size = concatenate([new_size_dims, size])
 
-    return op.make_node(rng, size, dtype, *new_dist_params)
+    return op.make_node(rng, size, *new_dist_params)
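
Since `make_node` no longer accepts a `dtype` argument, a user-requested dtype is now resolved in `__call__`: if it differs from the Op's own `dtype`, the Op is rebuilt via `_props_dict()`, and switching between float and integer kinds is rejected. The following is a hedged usage sketch of that behaviour, not part of the commit, assuming a pytensor build that includes it.

# Sketch of the post-commit dtype override path (assumption: pytensor with this commit).
import pytensor.tensor as pt

x = pt.random.normal(0, 1, dtype="float32")  # same "float" kind: the Op is recreated with float32
assert x.dtype == "float32"

try:
    pt.random.normal(0, 1, dtype="int32")    # float -> int switch is refused in __call__
except ValueError as err:
    print(err)  # e.g. "Cannot change the dtype of a normal RV from floatX to int32"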

pytensor/tensor/random/rewriting/basic.py

+6 -6

@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
     if not isinstance(node.op, RandomVariable):
         return
 
-    rng, size, dtype, *dist_params = node.inputs
+    rng, size, *dist_params = node.inputs
 
     dist_params = broadcast_params(dist_params, node.op.ndims_params)
 
@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
     else:
         return
 
-    new_node = node.op.make_node(rng, None, dtype, *dist_params)
+    new_node = node.op.make_node(rng, None, *dist_params)
 
     if config.compute_test_value != "off":
         compute_test_value(new_node)
@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
         return False
 
     rv_op = rv_node.op
-    rng, size, dtype, *dist_params = rv_node.inputs
+    rng, size, *dist_params = rv_node.inputs
     rv = rv_node.default_output()
 
     # Check that Dimshuffle does not affect support dims
@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
         )
         new_dist_params.append(param.dimshuffle(param_new_order))
 
-    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_node = rv_op.make_node(rng, new_size, *new_dist_params)
 
     if config.compute_test_value != "off":
         compute_test_value(new_node)
@@ -233,7 +233,7 @@ def is_nd_advanced_idx(idx, dtype):
         return None
 
     rv_op = rv_node.op
-    rng, size, dtype, *dist_params = rv_node.inputs
+    rng, size, *dist_params = rv_node.inputs
 
     # Parse indices
     idx_list = getattr(subtensor_op, "idx_list", None)
@@ -346,7 +346,7 @@ def is_nd_advanced_idx(idx, dtype):
         new_dist_params.append(batch_param[tuple(batch_indices)])
 
     # Create new RV
-    new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
+    new_node = rv_op.make_node(rng, new_size, *new_dist_params)
     new_rv = new_node.default_output()
 
     copy_stack_trace(rv, new_rv)

tests/tensor/random/rewriting/test_basic.py

+2 -2

@@ -111,8 +111,8 @@ def __init__(self, extra, *args, **kwargs):
         self.extra = extra
         super().__init__(*args, **kwargs)
 
-    def make_node(self, rng, size, dtype, sigma):
-        return super().make_node(rng, size, dtype, sigma)
+    def make_node(self, rng, size, sigma):
+        return super().make_node(rng, size, sigma)
 
     def rng_fn(self, rng, sigma, size):
         return rng.normal(scale=sigma, size=size)

tests/tensor/random/test_basic.py

+2 -2

@@ -1407,12 +1407,12 @@ def test_choice_samples():
 
 def test_choice_infer_shape():
     node = choice([0, 1]).owner
-    res = node.op._infer_shape((), node.inputs[3:], None)
+    res = node.op._infer_shape((), node.inputs[2:], None)
     assert tuple(res.eval()) == ()
 
     node = choice([0, 1]).owner
     res = node.op._infer_shape(
-        (), node.inputs[3:], (node.inputs[3].shape, node.inputs[4].shape)
+        (), node.inputs[2:], (node.inputs[2].shape, node.inputs[3].shape)
     )
     assert tuple(res.eval()) == ()
 

tests/tensor/random/test_op.py

+14 -12

@@ -3,15 +3,14 @@
 
 import pytensor.tensor as pt
 from pytensor import config, function
-from pytensor.gradient import NullTypeGradError, grad
 from pytensor.graph.replace import vectorize_graph
 from pytensor.raise_op import Assert
 from pytensor.tensor.math import eq
 from pytensor.tensor.random import normal
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomState, RandomVariable, default_rng
 from pytensor.tensor.shape import specify_shape
-from pytensor.tensor.type import all_dtypes, iscalar, tensor
+from pytensor.tensor.type import iscalar, tensor
 
 
 @pytest.fixture(scope="function", autouse=False)
@@ -72,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
     rv_shape = rv._infer_shape(pt.constant([]), (), [])
     assert rv_shape.equals(pt.constant([], dtype="int64"))
 
-    # Integer-specified `dtype`
-    dtype_1 = all_dtypes[1]
-    rv_node = rv.make_node(None, None, 1)
-    rv_out = rv_node.outputs[1]
-    rv_out.tag.test_value = 1
-
-    assert rv_out.dtype == dtype_1
-
-    with pytest.raises(NullTypeGradError):
-        grad(rv_out, [rv_node.inputs[0]])
+    # `dtype` is respected
+    rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
+    with config.change_flags(compute_test_value="off"):
+        rv_out = rv()
+        assert rv_out.dtype == "int32"
+        rv_out = rv(dtype="int64")
+        assert rv_out.dtype == "int64"
+
+        with pytest.raises(
+            ValueError,
+            match="Cannot change the dtype of a normal RV from int32 to float32",
+        ):
+            assert rv(dtype="float32").dtype == "float32"
 
 
 def test_RandomVariable_bcast(strict_test_value_flags):
