Skip to content

Don't try to infer support shape of multivariate RVs by default #388

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 60 additions & 21 deletions pytensor/tensor/random/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
import scipy.stats as stats

import pytensor
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.random.op import RandomVariable, default_supp_shape_from_params
from pytensor.tensor.basic import arange, as_tensor_variable
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from pytensor.tensor.random.utils import broadcast_params
from pytensor.tensor.random.utils import (
broadcast_params,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.random.var import (
RandomGeneratorSharedVariable,
RandomStateSharedVariable,
Expand Down Expand Up @@ -855,6 +858,14 @@ class MvNormalRV(RandomVariable):
dtype = "floatX"
_print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")

def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the first parameter (the mean vector).

    A multivariate normal draw is 1D; its length is the trailing dimension
    of the mean, so the mean (parameter index 0) is used as the reference.
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )

def __call__(self, mean=None, cov=None, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.

Expand Down Expand Up @@ -933,6 +944,14 @@ class DirichletRV(RandomVariable):
dtype = "floatX"
_print_name = ("Dirichlet", "\\operatorname{Dirichlet}")

def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the concentration parameter `alphas`.

    A Dirichlet draw has the same trailing dimension as its concentration
    vector, so parameter index 0 serves as the shape reference.
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )

def __call__(self, alphas, size=None, **kwargs):
r"""Draw samples from a dirichlet distribution.

Expand Down Expand Up @@ -1776,9 +1795,12 @@ def __call__(self, n, p, size=None, **kwargs):
"""
return super().__call__(n, p, size=size, **kwargs)

def _supp_shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
return default_supp_shape_from_params(
self.ndim_supp, dist_params, rep_param_idx, param_shapes
def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the probability vector `p`.

    The draw has one count per category, so its trailing dimension matches
    that of `p` — the second parameter (index 1), not `n` (index 0).
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=1,
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )

@classmethod
Expand Down Expand Up @@ -2050,18 +2072,15 @@ class PermutationRV(RandomVariable):

@classmethod
def rng_fn(cls, rng, x, size):
return rng.permutation(x if x.ndim > 0 else x.item())

def _infer_shape(self, size, dist_params, param_shapes=None):
param_shapes = param_shapes or [p.shape for p in dist_params]

(x,) = dist_params
(x_shape,) = param_shapes

if x.ndim == 0:
return (x,)
else:
return x_shape
return rng.permutation(x)

def _supp_shape_from_params(self, dist_params, param_shapes=None):
    """Infer the support shape from the array being permuted.

    A permutation has exactly the shape of its input `x`
    (parameter index 0).
    """
    return supp_shape_from_ref_param_shape(
        ref_param_idx=0,
        ndim_supp=self.ndim_supp,
        dist_params=dist_params,
        param_shapes=param_shapes,
    )

def __call__(self, x, **kwargs):
r"""Randomly permute a sequence or a range of values.
Expand All @@ -2074,15 +2093,35 @@ def __call__(self, x, **kwargs):
Parameters
----------
x
If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
shuffle its elements randomly.
Elements to be shuffled.

"""
x = as_tensor_variable(x)
return super().__call__(x, dtype=x.dtype, **kwargs)


permutation = PermutationRV()
_permutation = PermutationRV()


def permutation(x, **kwargs):
    r"""Randomly permute a sequence or a range of values.

    Signature
    ---------

    `(x) -> (x)`

    Parameters
    ----------
    x
        If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
        shuffle its elements randomly.

    """
    tensor_x = as_tensor_variable(x)
    # A scalar integer means "permute arange(x)": materialize the range here so
    # the underlying RV only ever receives an array-valued input.
    rv_input = arange(tensor_x) if tensor_x.type.ndim == 0 else tensor_x
    return _permutation(rv_input, **kwargs)


__all__ = [
Expand Down
125 changes: 39 additions & 86 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,64 +24,6 @@
from pytensor.tensor.var import TensorVariable


def default_supp_shape_from_params(
    ndim_supp: int,
    dist_params: Sequence["Variable"],
    rep_param_idx: int = 0,
    param_shapes: Optional[Sequence[Tuple["ScalarVariable", ...]]] = None,
) -> Union["TensorVariable", Tuple["ScalarVariable", ...]]:
    """Infer the dimensions for the output of a `RandomVariable`.

    This is a function that derives a random variable's support
    shape/dimensions from one of its parameters.

    XXX: It's not always possible to determine a random variable's support
    shape from its parameters, so this function has fundamentally limited
    applicability and must be replaced by custom logic in such cases.

    XXX: This function is not expected to handle `ndim_supp = 0` (i.e.
    scalars), since that is already definitively handled in the `Op` that
    calls this.

    TODO: Consider using `pytensor.compile.ops.shape_i` alongside `ShapeFeature`.

    Parameters
    ----------
    ndim_supp: int
        Total number of dimensions for a single draw of the random variable
        (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
    dist_params: list of `pytensor.graph.basic.Variable`
        The distribution parameters.
    rep_param_idx: int (optional)
        The index of the distribution parameter to use as a reference
        In other words, a parameter in `dist_param` with a shape corresponding
        to the support's shape.
        The default is the first parameter (i.e. the value 0).
    param_shapes: list of tuple of `ScalarVariable` (optional)
        Symbolic shapes for each distribution parameter. These will
        be used in place of distribution parameter-generated shapes.

    Returns
    -------
    out: a tuple representing the support shape for a distribution with the
    given `dist_params`.

    """
    if ndim_supp <= 0:
        raise ValueError("ndim_supp must be greater than 0")
    if param_shapes is not None:
        ref_param = param_shapes[rep_param_idx]
        # BUG FIX: take the last `ndim_supp` entries of the reference shape.
        # The previous implementation returned only the single entry
        # `(ref_param[-ndim_supp],)`, which is wrong whenever `ndim_supp > 1`
        # and inconsistent with the `dist_params` branch below.
        return tuple(ref_param[i] for i in range(-ndim_supp, 0))
    else:
        ref_param = dist_params[rep_param_idx]
        if ref_param.ndim < ndim_supp:
            raise ValueError(
                "Reference parameter does not match the "
                f"expected dimensions; {ref_param} has less than {ndim_supp} dim(s)."
            )
        return ref_param.shape[-ndim_supp:]


class RandomVariable(Op):
"""An `Op` that produces a sample from a random variable.

Expand Down Expand Up @@ -151,15 +93,29 @@ def __init__(
if self.inplace:
self.destroy_map = {0: [0]}

def _supp_shape_from_params(self, dist_params, **kwargs):
"""Determine the support shape of a `RandomVariable`'s output given its parameters.
def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.

This does *not* consider the extra dimensions added by the `size` parameter
or independent (batched) parameters.

Defaults to `param_supp_shape_fn`.
When provided, `param_shapes` should be given preference over `[d.shape for d in dist_params]`,
as it will avoid redundancies in PyTensor shape inference.

Examples
--------
Common multivariate `RandomVariable`s derive their support shapes implicitly from the
last dimension of some of their parameters. For example `multivariate_normal` support shape
corresponds to the last dimension of the mean or covariance parameters, `support_shape=(mu.shape[-1])`.
For this case the helper `pytensor.tensor.random.utils.supp_shape_from_ref_param_shape` can be used.

Other variables have fixed support shape such as `support_shape=(2,)` or it is determined by the
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
"""
return default_supp_shape_from_params(self.ndim_supp, dist_params, **kwargs)
raise NotImplementedError(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
)

def rng_fn(self, rng, *args, **kwargs) -> Union[int, float, np.ndarray]:
"""Sample a numeric random variate."""
Expand Down Expand Up @@ -191,6 +147,8 @@ def _infer_shape(

"""

from pytensor.tensor.extra_ops import broadcast_shape_iter

size_len = get_vector_length(size)

if size_len > 0:
Expand All @@ -216,57 +174,52 @@ def _infer_shape(

# Broadcast the parameters
param_shapes = params_broadcast_shapes(
param_shapes or [shape_tuple(p) for p in dist_params], self.ndims_params
param_shapes or [shape_tuple(p) for p in dist_params],
self.ndims_params,
)

def slice_ind_dims(p, ps, n):
def extract_batch_shape(p, ps, n):
shape = tuple(ps)

if n == 0:
return (p, shape)
return shape

ind_slice = (slice(None),) * (p.ndim - n) + (0,) * n
ind_shape = [
batch_shape = [
s if b is False else constant(1, "int64")
for s, b in zip(shape[:-n], p.broadcastable[:-n])
for s, b in zip(shape[:-n], p.type.broadcastable[:-n])
]
return (
p[ind_slice],
ind_shape,
)
return batch_shape

# These are versions of our actual parameters with the anticipated
# dimensions (i.e. support dimensions) removed so that only the
# independent variate dimensions are left.
params_ind_slice = tuple(
slice_ind_dims(p, ps, n)
params_batch_shape = tuple(
extract_batch_shape(p, ps, n)
for p, ps, n in zip(dist_params, param_shapes, self.ndims_params)
)

if len(params_ind_slice) == 1:
_, shape_ind = params_ind_slice[0]
elif len(params_ind_slice) > 1:
if len(params_batch_shape) == 1:
[batch_shape] = params_batch_shape
elif len(params_batch_shape) > 1:
# If there are multiple parameters, the dimensions of their
# independent variates should broadcast together.
p_slices, p_shapes = zip(*params_ind_slice)

shape_ind = pytensor.tensor.extra_ops.broadcast_shape_iter(
p_shapes, arrays_are_shapes=True
batch_shape = broadcast_shape_iter(
params_batch_shape,
arrays_are_shapes=True,
)

else:
# Distribution has no parameters
shape_ind = ()
batch_shape = ()

if self.ndim_supp == 0:
shape_supp = ()
supp_shape = ()
else:
shape_supp = self._supp_shape_from_params(
supp_shape = self._supp_shape_from_params(
dist_params,
param_shapes=param_shapes,
)

shape = tuple(shape_ind) + tuple(shape_supp)
shape = tuple(batch_shape) + tuple(supp_shape)
if not shape:
shape = constant([], dtype="int64")

Expand Down
51 changes: 49 additions & 2 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Sequence, Tuple, Union

import numpy as np

from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
from pytensor.scalar import ScalarVariable
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor_variable, cast, constant
from pytensor.tensor.extra_ops import broadcast_to
Expand Down Expand Up @@ -285,3 +285,50 @@ def gen(self, op: "RandomVariable", *args, **kwargs) -> TensorVariable:
rng.default_update = new_rng

return out


def supp_shape_from_ref_param_shape(
    *,
    ndim_supp: int,
    dist_params: Sequence[Variable],
    param_shapes: Optional[Sequence[Tuple[ScalarVariable, ...]]] = None,
    ref_param_idx: int,
) -> Union[TensorVariable, Tuple[ScalarVariable, ...]]:
    """Infer the support shape of a multivariate `RandomVariable` from a reference parameter.

    Many multivariate `RandomVariable`s take their support shape from the
    trailing dimensions of one of their parameters. For example,
    `multivariate_normal(zeros(5, 3), eye(3))` has a support shape of ``(3,)``
    given by the last dimension of either the mean or the covariance.

    Parameters
    ----------
    ndim_supp: int
        Support dimensionality of the `RandomVariable`.
        (e.g. a multivariate normal draw is 1D, so `ndim_supp = 1`).
    dist_params: list of `pytensor.graph.basic.Variable`
        The distribution parameters.
    param_shapes: list of tuple of `ScalarVariable` (optional)
        Symbolic shapes for each distribution parameter; when given, these
        take precedence over shapes derived from `dist_params`.
    ref_param_idx: int
        Index of the parameter whose trailing dimensions give the support shape.

    Returns
    -------
    out: tuple
        Representing the support shape for a `RandomVariable` with the given `dist_params`.

    """
    if ndim_supp <= 0:
        raise ValueError("ndim_supp must be greater than 0")

    if param_shapes is None:
        # Derive the shape symbolically from the parameter itself; this
        # requires the parameter to have at least `ndim_supp` dimensions.
        ref_param = dist_params[ref_param_idx]
        if ref_param.ndim < ndim_supp:
            raise ValueError(
                "Reference parameter does not match the expected dimensions; "
                f"{ref_param} has less than {ndim_supp} dim(s)."
            )
        return ref_param.shape[-ndim_supp:]

    # Pre-computed shapes were supplied: slice out the trailing `ndim_supp`
    # entries of the reference parameter's shape.
    ref_shape = param_shapes[ref_param_idx]
    return tuple(ref_shape[dim] for dim in range(-ndim_supp, 0))
8 changes: 8 additions & 0 deletions tests/tensor/random/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,14 @@ def test_permutation_samples():
compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))


def test_permutation_shape():
    # (random variable, expected static shape) pairs covering scalar, vector,
    # and matrix inputs, with and without an explicit `size` batch shape.
    cases = [
        (permutation(5), (5,)),
        (permutation(np.arange(5)), (5,)),
        (permutation(np.arange(10).reshape(2, 5)), (2, 5)),
        (permutation(5, size=(2, 3)), (2, 3, 5)),
        (permutation(np.arange(5), size=(2, 3)), (2, 3, 5)),
    ]
    for rv, expected_shape in cases:
        assert tuple(rv.shape.eval()) == expected_shape


@config.change_flags(compute_test_value="off")
def test_pickle():
# This is an interesting `Op` case, because it has `None` types and a
Expand Down
Loading