
Commit 13807b4

Distinguish between size=None and size=() in RandomVariables
1 parent bb5053b commit 13807b4
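
This commit makes `RandomVariable` keep `size=None` as a genuine "no size" (the symbolic `NoneConst`) instead of encoding it as a zero-length vector, so an explicit `size=()` can keep NumPy's scalar-draw semantics (see https://github.com/pymc-devs/pytensor/issues/568). As a rough, NumPy-only illustration of the behaviour being matched (not code from this commit):

import numpy as np

rng = np.random.default_rng(0)

# size=None: the output shape is inferred by broadcasting the parameters.
rng.normal(np.zeros(3), 1.0, size=None).shape   # -> (3,)

# size=(): an explicit scalar draw; with a length-3 loc this should raise a
# ValueError rather than silently fall back to shape (3,).
# rng.normal(np.zeros(3), 1.0, size=())

# With scalar parameters, size=() should yield a 0-d result rather than a float.
np.shape(rng.normal(0.0, 1.0, size=()))         # -> ()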


6 files changed: +100 −88


pytensor/link/numba/dispatch/random.py

+32 −14

@@ -21,6 +21,7 @@
 )
 from pytensor.tensor.basic import get_vector_length
 from pytensor.tensor.random.type import RandomStateType
+from pytensor.tensor.type_other import NoneTypeT
 
 
 class RandomStateNumbaType(types.Type):
@@ -100,9 +101,13 @@ def make_numba_random_fn(node, np_random_func):
     if not isinstance(rng_param.type, RandomStateType):
         raise TypeError("Numba does not support NumPy `Generator`s")
 
-    tuple_size = int(get_vector_length(node.op.size_param(node)))
+    size_param = node.op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(node.op.size_param(node)))
+    )
     dist_params = node.op.dist_params(node)
-    size_dims = tuple_size - max(i.ndim for i in dist_params)
 
     # Make a broadcast-capable version of the Numba supported scalar sampling
     # function
@@ -118,7 +123,7 @@ def make_numba_random_fn(node, np_random_func):
         "np_random_func",
         "numba_vectorize",
         "to_fixed_tuple",
-        "tuple_size",
+        "size_len",
         "size_dims",
         "rng",
         "size",
@@ -154,10 +159,12 @@ def {bcast_fn_name}({bcast_fn_input_names}):
         "out_dtype": out_dtype,
     }
 
-    if tuple_size > 0:
+    if size_len is not None:
+        size_dims = size_len - max(i.ndim for i in dist_params)
+
         random_fn_body = dedent(
             f"""
-        size = to_fixed_tuple(size, tuple_size)
+        size = to_fixed_tuple(size, size_len)
 
         data = np.empty(size, dtype=out_dtype)
         for i in np.ndindex(size[:size_dims]):
@@ -169,7 +176,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
         {
             "np": np,
             "to_fixed_tuple": numba_ndarray.to_fixed_tuple,
-            "tuple_size": tuple_size,
+            "size_len": size_len,
             "size_dims": size_dims,
         }
     )
@@ -305,19 +312,24 @@ def body_fn(a):
 @numba_funcify.register(ptr.CategoricalRV)
 def numba_funcify_CategoricalRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
-    size_len = int(get_vector_length(node.inputs[1]))
+    size_param = node.op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(size_param))
+    )
     p_ndim = node.inputs[-1].ndim
 
     @numba_basic.numba_njit
     def categorical_rv(rng, size, p):
-        if not size_len:
+        if size_len is None:
             size_tpl = p.shape[:-1]
         else:
             size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
             p = np.broadcast_to(p, size_tpl + p.shape[-1:])
 
         # Workaround https://github.com/numba/numba/issues/8975
-        if not size_len and p_ndim == 1:
+        if size_len is None and p_ndim == 1:
             unif_samples = np.asarray(np.random.uniform(0, 1))
         else:
             unif_samples = np.random.uniform(0, 1, size_tpl)
@@ -336,22 +348,27 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
     alphas_ndim = node.op.dist_params(node)[0].type.ndim
     neg_ind_shape_len = -alphas_ndim + 1
-    size_len = int(get_vector_length(node.op.size_param(node)))
+    size_param = node.op.size_param(node)
+    size_len = (
+        None
+        if isinstance(size_param.type, NoneTypeT)
+        else int(get_vector_length(size_param))
+    )
 
     if alphas_ndim > 1:
 
         @numba_basic.numba_njit
         def dirichlet_rv(rng, size, alphas):
-            if size_len > 0:
+            if size_len is None:
+                samples_shape = alphas.shape
+            else:
                 size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
                 if (
                     0 < alphas.ndim - 1 <= len(size_tpl)
                     and size_tpl[neg_ind_shape_len:] != alphas.shape[:-1]
                 ):
                     raise ValueError("Parameters shape and size do not match.")
                 samples_shape = size_tpl + alphas.shape[-1:]
-            else:
-                samples_shape = alphas.shape
 
             res = np.empty(samples_shape, dtype=out_dtype)
             alphas_bcast = np.broadcast_to(alphas, samples_shape)
@@ -365,7 +382,8 @@ def dirichlet_rv(rng, size, alphas):
 
         @numba_basic.numba_njit
        def dirichlet_rv(rng, size, alphas):
-            size = numba_ndarray.to_fixed_tuple(size, size_len)
+            if size_len is not None:
+                size = numba_ndarray.to_fixed_tuple(size, size_len)
             return (rng, np.random.dirichlet(alphas, size))
 
     return dirichlet_rv
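
For reference, a NumPy-only sketch of the case split the Numba sampler above now performs for `CategoricalRV` (`categorical_draw` is a hypothetical helper, not the generated Numba code): when no size is given, the batch shape is taken from `p`; otherwise the explicit size, possibly `()`, is used as-is.

import numpy as np

def categorical_draw(rng, p, size=None):
    if size is None:
        size_tpl = p.shape[:-1]                        # batch shape comes from p
    else:
        size_tpl = tuple(size)                         # explicit size, possibly ()
        p = np.broadcast_to(p, size_tpl + p.shape[-1:])
    unif = rng.uniform(0, 1, size_tpl)
    # inverse-CDF draw: first category whose cumulative probability exceeds unif
    return np.argmax(np.cumsum(p, axis=-1) > unif[..., None], axis=-1)

rng = np.random.default_rng(0)
p = np.array([[0.2, 0.8], [0.5, 0.5]])
categorical_draw(rng, p).shape               # (2,): inferred from p
categorical_draw(rng, p, size=(3, 2)).shape  # (3, 2): explicit size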

pytensor/tensor/random/basic.py

+8 −9

@@ -874,12 +874,12 @@ def rng_fn(cls, rng, mean, cov, size):
         # multivariate normals (or any other multivariate distributions),
         # so we need to implement that here
 
-        size = tuple(size or ())
-        if size:
+        if size is None:
+            mean, cov = broadcast_params([mean, cov], [1, 2])
+        else:
+            size = tuple(size)
             mean = np.broadcast_to(mean, size + mean.shape[-1:])
             cov = np.broadcast_to(cov, size + cov.shape[-2:])
-        else:
-            mean, cov = broadcast_params([mean, cov], [1, 2])
 
         res = np.empty(mean.shape)
         for idx in np.ndindex(mean.shape[:-1]):
@@ -1760,13 +1760,12 @@ def __call__(self, n, p, size=None, **kwargs):
     @classmethod
     def rng_fn(cls, rng, n, p, size):
         if n.ndim > 0 or p.ndim > 1:
-            size = tuple(size or ())
-
-            if size:
+            if size is None:
+                n, p = broadcast_params([n, p], [0, 1])
+            else:
+                size = tuple(size)
                 n = np.broadcast_to(n, size)
                 p = np.broadcast_to(p, size + p.shape[-1:])
-            else:
-                n, p = broadcast_params([n, p], [0, 1])
 
             res = np.empty(p.shape, dtype=cls.dtype)
             for idx in np.ndindex(p.shape[:-1]):
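
Both rewritten `rng_fn` branches follow the same rule; a minimal NumPy sketch of the multivariate-normal case (hypothetical `mvnormal_batch_shapes`, not the PyTensor implementation):

import numpy as np

def mvnormal_batch_shapes(mean, cov, size=None):
    if size is None:
        # batch shape inferred by broadcasting the parameters against each other
        batch = np.broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
    else:
        # explicit size wins; parameters are broadcast up to it
        batch = tuple(size)
    mean = np.broadcast_to(mean, batch + mean.shape[-1:])
    cov = np.broadcast_to(cov, batch + cov.shape[-2:])
    return mean.shape, cov.shape

mean, cov = np.zeros((5, 3)), np.eye(3)
mvnormal_batch_shapes(mean, cov)               # ((5, 3), (5, 3, 3))
mvnormal_batch_shapes(mean, cov, size=(2, 5))  # ((2, 5, 3), (2, 5, 3, 3))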

pytensor/tensor/random/op.py

+12 −30

@@ -15,7 +15,6 @@
     as_tensor_variable,
     concatenate,
     constant,
-    get_underlying_scalar_constant_value,
     get_vector_length,
     infer_static_shape,
 )
@@ -27,7 +26,7 @@
 )
 from pytensor.tensor.shape import shape_tuple
 from pytensor.tensor.type import TensorType
-from pytensor.tensor.type_other import NoneConst
+from pytensor.tensor.type_other import NoneConst, NoneTypeT
 from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable
 
@@ -198,10 +197,10 @@ def batch_ndim(self, node):
 
     def _infer_shape(
         self,
-        size: TensorVariable,
+        size: TensorVariable | Variable,
         dist_params: Sequence[TensorVariable],
         param_shapes: Sequence[tuple[Variable, ...]] | None = None,
-    ) -> TensorVariable | tuple[ScalarVariable, ...]:
+    ) -> tuple[ScalarVariable | TensorVariable, ...]:
         """Compute the output shape given the size and distribution parameters.
 
         Parameters
@@ -227,9 +226,9 @@ def _infer_shape(
             self._supp_shape_from_params(dist_params, param_shapes=param_shapes)
         )
 
-        size_len = get_vector_length(size)
+        if not isinstance(size.type, NoneTypeT):
+            size_len = get_vector_length(size)
 
-        if size_len > 0:
             # Fail early when size is incompatible with parameters
             for i, (param, param_ndim_supp) in enumerate(
                 zip(dist_params, self.ndims_params)
@@ -283,21 +282,11 @@ def extract_batch_shape(p, ps, n):
 
         shape = batch_shape + supp_shape
 
-        if not shape:
-            shape = constant([], dtype="int64")
-
         return shape
 
     def infer_shape(self, fgraph, node, input_shapes):
         _, size, *dist_params = node.inputs
-        _, size_shape, *param_shapes = input_shapes
-
-        try:
-            size_len = get_vector_length(size)
-        except ValueError:
-            size_len = get_underlying_scalar_constant_value(size_shape[0])
-
-        size = tuple(size[n] for n in range(size_len))
+        _, _, *param_shapes = input_shapes
 
         shape = self._infer_shape(size, dist_params, param_shapes=param_shapes)
 
@@ -369,8 +358,8 @@ def make_node(self, rng, size, *dist_params):
                 "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
             )
 
-        shape = self._infer_shape(size, dist_params)
-        _, static_shape = infer_static_shape(shape)
+        inferred_shape = self._infer_shape(size, dist_params)
+        _, static_shape = infer_static_shape(inferred_shape)
 
         dtype = self.dtype
         out_var = TensorType(dtype=dtype, shape=static_shape)()
@@ -397,16 +386,7 @@ def perform(self, node, inputs, outputs):
 
         rng, size, *args = inputs
 
-        # If `size == []`, that means no size is enforced, and NumPy is trusted
-        # to draw the appropriate number of samples, NumPy uses `size=None` to
-        # represent that. Otherwise, NumPy expects a tuple.
-        if np.size(size) == 0:
-            size = None
-        else:
-            size = tuple(size)
-
-        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng`
-        # otherwise.
+        # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
         if not self.inplace:
             rng = copy(rng)
 
@@ -474,7 +454,9 @@ def vectorize_random_variable(
 
     original_dist_params = op.dist_params(node)
     old_size = op.size_param(node)
-    len_old_size = get_vector_length(old_size)
+    len_old_size = (
+        None if isinstance(old_size.type, NoneTypeT) else get_vector_length(old_size)
+    )
 
     original_expanded_dist_params = explicit_expand_dims(
         original_dist_params, op.ndims_params, len_old_size
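
The net effect on shape inference can be summarised in a small sketch (hypothetical `infer_output_shape`, not `RandomVariable._infer_shape` itself): `size=None` derives the batch shape from the parameters, while any explicit size, including the empty tuple, is used verbatim.

import numpy as np

def infer_output_shape(size, param_batch_shapes, supp_shape):
    if size is None:
        # no size: batch shape comes from broadcasting the parameter batch shapes
        batch_shape = np.broadcast_shapes(*param_batch_shapes)
    else:
        # explicit size (possibly ()): it becomes the batch shape as-is
        batch_shape = tuple(size)
    return batch_shape + tuple(supp_shape)

infer_output_shape(None, [(5,), ()], (3,))    # (5, 3): inferred from the parameters
infer_output_shape((), [()], (3,))            # (3,): a single core draw
infer_output_shape((2, 5), [(5,), ()], (3,))  # (2, 5, 3)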

pytensor/tensor/random/utils.py

+17 −19

@@ -9,8 +9,8 @@
 from pytensor.compile.sharedvalue import shared
 from pytensor.graph.basic import Constant, Variable
 from pytensor.scalar import ScalarVariable
-from pytensor.tensor import get_vector_length
-from pytensor.tensor.basic import as_tensor_variable, cast, constant
+from pytensor.tensor import NoneConst, get_vector_length
+from pytensor.tensor.basic import as_tensor_variable, cast
 from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
@@ -124,17 +124,15 @@ def broadcast_params(params, ndims_params):
 def explicit_expand_dims(
     params: Sequence[TensorVariable],
     ndim_params: Sequence[int],
-    size_length: int = 0,
+    size_length: int | None = None,
 ) -> list[TensorVariable]:
     """Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
 
     batch_dims = [
         param.type.ndim - ndim_param for param, ndim_param in zip(params, ndim_params)
     ]
 
-    if size_length:
-        # NOTE: PyTensor is currently treating zero-length size as size=None, which is not what Numpy does
-        # See: https://github.com/pymc-devs/pytensor/issues/568
+    if size_length is not None:
         max_batch_dims = size_length
     else:
         max_batch_dims = max(batch_dims, default=0)
@@ -152,37 +150,37 @@ def explicit_expand_dims(
 def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
     params = explicit_expand_dims(params, ndims_params)
     batch_params = [
-        param[(..., *((0,) for _ in range(core_ndim)))]
+        param[(..., *(0,) * core_ndim)]
         for param, core_ndim in zip(params, ndims_params)
     ]
     return broadcast_arrays(*batch_params)[0].shape
 
 
 def normalize_size_param(
-    size: int | np.ndarray | Variable | Sequence | None,
+    shape: int | np.ndarray | Variable | Sequence | None,
 ) -> Variable:
     """Create an PyTensor value for a ``RandomVariable`` ``size`` parameter."""
-    if size is None:
-        size = constant([], dtype="int64")
-    elif isinstance(size, int):
-        size = as_tensor_variable([size], ndim=1)
-    elif not isinstance(size, np.ndarray | Variable | Sequence):
+    if shape is None or NoneConst.equals(shape):
+        return NoneConst
+    elif isinstance(shape, int):
+        shape = as_tensor_variable([shape], ndim=1)
+    elif not isinstance(shape, np.ndarray | Variable | Sequence):
         raise TypeError(
             "Parameter size must be None, an integer, or a sequence with integers."
         )
     else:
-        size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
+        shape = cast(as_tensor_variable(shape, ndim=1, dtype="int64"), "int64")
 
-    if not isinstance(size, Constant):
+    if not isinstance(shape, Constant):
         # This should help ensure that the length of non-constant `size`s
         # will be available after certain types of cloning (e.g. the kind
         # `Scan` performs)
-        size = specify_shape(size, (get_vector_length(size),))
+        shape = specify_shape(shape, (get_vector_length(shape),))
 
-    assert not any(s is None for s in size.type.shape)
-    assert size.dtype in int_dtypes
+    assert not any(s is None for s in shape.type.shape)
+    assert shape.dtype in int_dtypes
 
-    return size
+    return shape
 
 
 class RandomStream:
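
A hedged usage sketch of the new `normalize_size_param` contract (assuming the function remains importable from `pytensor.tensor.random.utils`): `None` maps to the symbolic `NoneConst`, while `()` and other sequences become int64 vectors, so a zero-length size is no longer conflated with "no size".

from pytensor.tensor.random.utils import normalize_size_param

normalize_size_param(None)    # NoneConst: no size, the parameters decide the shape
normalize_size_param(())      # length-0 int64 vector: an explicit empty size
normalize_size_param((2, 3))  # length-2 int64 vector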
