Skip to content

Commit 9aeb6b5

Browse files
Add ZeroSumNormal distribution (#6121)
Also: * Refactor get_steps to work with multivariate support shapes * Replace get_steps by get_support_shape_1d in timeseries.py Co-authored-by: Ricardo Vieira <[email protected]>
1 parent faebc60 commit 9aeb6b5

File tree

10 files changed

+769
-141
lines changed

10 files changed

+769
-141
lines changed

docs/source/api/distributions/multivariate.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Multivariate
88

99
MvNormal
1010
MvStudentT
11+
ZeroSumNormal
1112
Dirichlet
1213
Multinomial
1314
DirichletMultinomial

docs/source/api/distributions/transforms.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Specific Transform Classes
3333
LogExpM1
3434
Ordered
3535
SumTo1
36+
ZeroSumTransform
3637

3738

3839
Transform Composition Classes

pymc/distributions/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@
9999
StickBreakingWeights,
100100
Wishart,
101101
WishartBartlett,
102+
ZeroSumNormal,
102103
)
103104
from pymc.distributions.simulator import Simulator
104105
from pymc.distributions.timeseries import (
@@ -116,8 +117,8 @@
116117
"Uniform",
117118
"Flat",
118119
"HalfFlat",
119-
"TruncatedNormal",
120120
"Normal",
121+
"TruncatedNormal",
121122
"Beta",
122123
"Kumaraswamy",
123124
"Exponential",
@@ -160,6 +161,7 @@
160161
"Continuous",
161162
"Discrete",
162163
"MvNormal",
164+
"ZeroSumNormal",
163165
"MatrixNormal",
164166
"KroneckerNormal",
165167
"MvStudentT",

pymc/distributions/multivariate.py

Lines changed: 206 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import warnings
1919

2020
from functools import reduce
21+
from typing import Optional
2122

2223
import aesara
2324
import aesara.tensor as at
@@ -63,15 +64,17 @@
6364
_change_dist_size,
6465
broadcast_dist_samples_to,
6566
change_dist_size,
67+
get_support_shape,
6668
rv_size_is_none,
6769
to_tuple,
6870
)
69-
from pymc.distributions.transforms import Interval, _default_transform
71+
from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
7072
from pymc.math import kron_diag, kron_dot
7173
from pymc.util import check_dist_not_registered
7274

7375
__all__ = [
7476
"MvNormal",
77+
"ZeroSumNormal",
7578
"MvStudentT",
7679
"Dirichlet",
7780
"Multinomial",
@@ -2380,3 +2383,205 @@ def logp(value, alpha, K):
23802383
K > 0,
23812384
msg="alpha > 0, K > 0",
23822385
)
2386+
2387+
2388+
class ZeroSumNormalRV(SymbolicRandomVariable):
    """Symbolic random variable op that backs the ``ZeroSumNormal`` distribution."""

    # Index of the op output handed back by default when the op is applied.
    default_output = 0
    # Pair of display names: presumably (plain text, LaTeX) -- confirm against
    # the pretty-printing code that consumes ``_print_name``.
    _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
2394+
2395+
class ZeroSumNormal(Distribution):
2396+
r"""
2397+
ZeroSumNormal distribution, i.e Normal distribution where one or
2398+
several axes are constrained to sum to zero.
2399+
By default, the last axis is constrained to sum to zero.
2400+
See `zerosum_axes` kwarg for more details.
2401+
2402+
.. math::
2403+
2404+
\begin{align*}
2405+
ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
2406+
\text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
2407+
n = \text{nbr of zero-sum axes}
2408+
\end{align*}
2409+
2410+
Parameters
2411+
----------
2412+
sigma : tensor_like of float
2413+
Scale parameter (sigma > 0).
2414+
It's actually the standard deviation of the underlying, unconstrained Normal distribution.
2415+
Defaults to 1 if not specified.
2416+
For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2417+
zerosum_axes: int, defaults to 1
2418+
Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
2419+
Defaults to 1, i.e the rightmost axis.
2420+
dims: sequence of strings, optional
2421+
Dimension names of the distribution. Works the same as for other PyMC distributions.
2422+
Necessary if ``shape`` is not passed.
2423+
shape: tuple of integers, optional
2424+
Shape of the distribution. Works the same as for other PyMC distributions.
2425+
Necessary if ``dims`` or ``observed`` is not passed.
2426+
2427+
Warnings
2428+
--------
2429+
``sigma`` has to be a scalar, to ensure the zero-sum constraint.
2430+
The ability to specifiy a vector of ``sigma`` may be added in future versions.
2431+
2432+
``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
2433+
just use ``pm.Normal``.
2434+
2435+
Examples
2436+
--------
2437+
Define a `ZeroSumNormal` variable, with `sigma=1` and
2438+
`zerosum_axes=1` by default::
2439+
2440+
COORDS = {
2441+
"regions": ["a", "b", "c"],
2442+
"answers": ["yes", "no", "whatever", "don't understand question"],
2443+
}
2444+
with pm.Model(coords=COORDS) as m:
2445+
# the zero sum axis will be 'answers'
2446+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"))
2447+
2448+
with pm.Model(coords=COORDS) as m:
2449+
# the zero sum axes will be 'answers' and 'regions'
2450+
v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)
2451+
2452+
with pm.Model(coords=COORDS) as m:
2453+
# the zero sum axes will be the last two
2454+
v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
2455+
"""
2456+
rv_type = ZeroSumNormalRV
2457+
2458+
def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
2459+
if dims is not None or kwargs.get("observed") is not None:
2460+
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
2461+
2462+
support_shape = get_support_shape(
2463+
support_shape=support_shape,
2464+
shape=None, # Shape will be checked in `cls.dist`
2465+
dims=dims,
2466+
observed=kwargs.get("observed", None),
2467+
ndim_supp=zerosum_axes,
2468+
)
2469+
2470+
return super().__new__(
2471+
cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
2472+
)
2473+
2474+
@classmethod
2475+
def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
2476+
zerosum_axes = cls.check_zerosum_axes(zerosum_axes)
2477+
2478+
sigma = at.as_tensor_variable(floatX(sigma))
2479+
if sigma.ndim > 0:
2480+
raise ValueError("sigma has to be a scalar")
2481+
2482+
support_shape = get_support_shape(
2483+
support_shape=support_shape,
2484+
shape=kwargs.get("shape"),
2485+
ndim_supp=zerosum_axes,
2486+
)
2487+
2488+
if support_shape is None:
2489+
if zerosum_axes > 0:
2490+
raise ValueError("You must specify dims, shape or support_shape parameter")
2491+
# TODO: edge-case doesn't work for now, because at.stack in get_support_shape fails
2492+
# else:
2493+
# support_shape = () # because it's just a Normal in that case
2494+
support_shape = at.as_tensor_variable(intX(support_shape))
2495+
2496+
assert zerosum_axes == at.get_vector_length(
2497+
support_shape
2498+
), "support_shape has to be as long as zerosum_axes"
2499+
2500+
return super().dist(
2501+
[sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
2502+
)
2503+
2504+
@classmethod
2505+
def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
2506+
if zerosum_axes is None:
2507+
zerosum_axes = 1
2508+
if not isinstance(zerosum_axes, int):
2509+
raise TypeError("zerosum_axes has to be an integer")
2510+
if not zerosum_axes > 0:
2511+
raise ValueError("zerosum_axes has to be > 0")
2512+
return zerosum_axes
2513+
2514+
@classmethod
2515+
def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
2516+
2517+
shape = to_tuple(size) + tuple(support_shape)
2518+
normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))
2519+
2520+
if zerosum_axes > normal_dist.ndim:
2521+
raise ValueError("Shape of distribution is too small for the number of zerosum axes")
2522+
2523+
normal_dist_, sigma_, support_shape_ = (
2524+
normal_dist.type(),
2525+
sigma.type(),
2526+
support_shape.type(),
2527+
)
2528+
2529+
# Zerosum-normaling is achieved by substracting the mean along the given zerosum_axes
2530+
zerosum_rv_ = normal_dist_
2531+
for axis in range(zerosum_axes):
2532+
zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True)
2533+
2534+
return ZeroSumNormalRV(
2535+
inputs=[normal_dist_, sigma_, support_shape_],
2536+
outputs=[zerosum_rv_, support_shape_],
2537+
ndim_supp=zerosum_axes,
2538+
)(normal_dist, sigma, support_shape)
2539+
2540+
2541+
@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):
    """Re-create a ZeroSumNormal variable with a different ``size``.

    Parameters
    ----------
    op
        The ``ZeroSumNormalRV`` op; its ``ndim_supp`` gives the number of
        zero-sum axes.
    normal_dist
        The variable being resized; the inputs of its owner node are unpacked
        as ``(underlying normal, sigma, support_shape)``.
    new_size
        The requested size.
    expand : bool
        When true, ``new_size`` is prepended to the leading (non-support)
        dimensions of the underlying normal instead of replacing them.
    """
    base_normal, sigma, support_shape = normal_dist.owner.inputs

    size = new_size
    if expand:
        current_shape = tuple(base_normal.shape)
        # Everything left of the trailing support dimensions is the old size.
        n_leading = len(current_shape) - op.ndim_supp
        size = tuple(new_size) + current_shape[:n_leading]

    return ZeroSumNormal.rv_op(
        sigma=sigma,
        zerosum_axes=op.ndim_supp,
        support_shape=support_shape,
        size=size,
    )
2554+
2555+
2556+
@_moment.register(ZeroSumNormalRV)
def zerosumnormal_moment(op, rv, *rv_inputs):
    """Return the moment of a ZeroSumNormal: a tensor of zeros shaped like ``rv``."""
    zeros = at.zeros_like(rv)
    return zeros
2559+
2560+
2561+
@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
    """Return a ``ZeroSumTransform`` over the last ``op.ndim_supp`` axes.

    The axes are passed as negative indices, e.g. ``(-2, -1)`` when
    ``op.ndim_supp == 2``.
    """
    n_axes = op.ndim_supp
    trailing_axes = np.arange(-n_axes, 0)
    return ZeroSumTransform(tuple(trailing_axes))
2565+
2566+
2567+
@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    """Log-probability of a ZeroSumNormal value.

    The element-wise log-probability under the underlying normal is rescaled
    by ``prod(shape with the zero-sum axes shortened by one) / prod(shape)``
    and summed over the zero-sum axes. The result is wrapped in a parameter
    check that the mean of ``value`` along each zero-sum axis is close to zero.
    """
    (value,) = values
    n_axes = op.ndim_supp
    value_shape = value.shape

    # Shape of the value with each zero-sum axis shortened by one element.
    reduced_shape = at.inc_subtensor(value_shape[-n_axes:], -1)
    n_elements = at.prod(value_shape)
    n_free = at.prod(reduced_shape)

    # One check per constrained axis, walking from the rightmost axis inwards.
    mean_is_zero = []
    for offset in range(n_axes):
        axis_mean = at.mean(value, axis=-offset - 1)
        mean_is_zero.append(at.all(at.isclose(axis_mean, 0, atol=1e-9)))

    elementwise_logp = pm.logp(normal_dist, value) * n_free / n_elements
    logp_sum = at.sum(elementwise_logp, axis=tuple(np.arange(-n_axes, 0)))

    return check_parameters(
        logp_sum, *mean_is_zero, msg="at.mean(value, axis=zerosum_axes) == 0"
    )

0 commit comments

Comments
 (0)