|
18 | 18 | import warnings
|
19 | 19 |
|
20 | 20 | from functools import reduce
|
| 21 | +from typing import Optional |
21 | 22 |
|
22 | 23 | import aesara
|
23 | 24 | import aesara.tensor as at
|
|
63 | 64 | _change_dist_size,
|
64 | 65 | broadcast_dist_samples_to,
|
65 | 66 | change_dist_size,
|
| 67 | + get_support_shape, |
66 | 68 | rv_size_is_none,
|
67 | 69 | to_tuple,
|
68 | 70 | )
|
69 |
| -from pymc.distributions.transforms import Interval, _default_transform |
| 71 | +from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform |
70 | 72 | from pymc.math import kron_diag, kron_dot
|
71 | 73 | from pymc.util import check_dist_not_registered
|
72 | 74 |
|
73 | 75 | __all__ = [
|
74 | 76 | "MvNormal",
|
| 77 | + "ZeroSumNormal", |
75 | 78 | "MvStudentT",
|
76 | 79 | "Dirichlet",
|
77 | 80 | "Multinomial",
|
@@ -2380,3 +2383,205 @@ def logp(value, alpha, K):
|
2380 | 2383 | K > 0,
|
2381 | 2384 | msg="alpha > 0, K > 0",
|
2382 | 2385 | )
|
| 2386 | + |
| 2387 | + |
class ZeroSumNormalRV(SymbolicRandomVariable):
    """Symbolic random-variable op backing the ``ZeroSumNormal`` distribution.

    Instances are built in ``ZeroSumNormal.rv_op`` with two outputs: the
    zero-sum variable and the support shape.
    """

    # Index of the op output treated as the default one; output 0 is the
    # zero-sum variable built in ``ZeroSumNormal.rv_op``.
    # NOTE(review): exact semantics come from the base op class -- confirm.
    default_output = 0

    # Pair of display names; the second entry is written as a LaTeX operator.
    _print_name = ("ZeroSumNormal", "\\operatorname{ZeroSumNormal}")
| 2393 | + |
| 2394 | + |
class ZeroSumNormal(Distribution):
    r"""
    ZeroSumNormal distribution, i.e. Normal distribution where one or
    several axes are constrained to sum to zero.
    By default, the last axis is constrained to sum to zero.
    See `zerosum_axes` kwarg for more details.

    .. math::

        \begin{align*}
            ZSN(\sigma) = N \Big( 0, \sigma^2 (I - \tfrac{1}{n}J) \Big) \\
            \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
            n = \text{number of elements along the zero-sum axis}
        \end{align*}

    The formula above describes a single zero-sum axis. With several zero-sum
    axes, the mean is subtracted along each of them in turn.

    Parameters
    ----------
    sigma : tensor_like of float
        Scale parameter (sigma > 0).
        It's actually the standard deviation of the underlying, unconstrained Normal distribution.
        Defaults to 1 if not specified.
        For now, ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
    zerosum_axes: int, defaults to 1
        Number of axes along which the zero-sum constraint is enforced, starting from the rightmost position.
        Defaults to 1, i.e. the rightmost axis.
        Python integers and NumPy integer scalars are both accepted.
    dims: sequence of strings, optional
        Dimension names of the distribution. Works the same as for other PyMC distributions.
        Necessary if ``shape`` is not passed.
    shape: tuple of integers, optional
        Shape of the distribution. Works the same as for other PyMC distributions.
        Necessary if ``dims`` or ``observed`` is not passed.

    Warnings
    --------
    ``sigma`` has to be a scalar, to ensure the zero-sum constraint.
    The ability to specify a vector of ``sigma`` may be added in future versions.

    ``zerosum_axes`` has to be > 0. If you want the behavior of ``zerosum_axes = 0``,
    just use ``pm.Normal``.

    Examples
    --------
    Define a `ZeroSumNormal` variable, with `sigma=1` and
    `zerosum_axes=1` by default::

        COORDS = {
            "regions": ["a", "b", "c"],
            "answers": ["yes", "no", "whatever", "don't understand question"],
        }
        with pm.Model(coords=COORDS) as m:
            # the zero sum axis will be 'answers'
            v = pm.ZeroSumNormal("v", dims=("regions", "answers"))

        with pm.Model(coords=COORDS) as m:
            # the zero sum axes will be 'answers' and 'regions'
            v = pm.ZeroSumNormal("v", dims=("regions", "answers"), zerosum_axes=2)

        with pm.Model(coords=COORDS) as m:
            # the zero sum axes will be the last two
            v = pm.ZeroSumNormal("v", shape=(3, 4, 5), zerosum_axes=2)
    """
    rv_type = ZeroSumNormalRV

    def __new__(cls, *args, zerosum_axes=None, support_shape=None, dims=None, **kwargs):
        """Create the model variable, inferring ``support_shape`` when possible.

        The support shape is only derived here when ``dims`` or ``observed``
        is given; ``shape`` is handled later in ``cls.dist``.
        """
        observed = kwargs.get("observed")
        if dims is not None or observed is not None:
            zerosum_axes = cls.check_zerosum_axes(zerosum_axes)

            support_shape = get_support_shape(
                support_shape=support_shape,
                shape=None,  # Shape will be checked in `cls.dist`
                dims=dims,
                observed=observed,
                ndim_supp=zerosum_axes,
            )

        return super().__new__(
            cls, *args, zerosum_axes=zerosum_axes, support_shape=support_shape, dims=dims, **kwargs
        )

    @classmethod
    def dist(cls, sigma=1, zerosum_axes=None, support_shape=None, **kwargs):
        """Build the distribution after validating ``sigma`` and the support shape.

        Raises
        ------
        TypeError
            If ``zerosum_axes`` is not an integer.
        ValueError
            If ``zerosum_axes`` is not > 0, if ``sigma`` is not a scalar, or if
            no support shape could be determined.
        AssertionError
            If ``support_shape`` does not have ``zerosum_axes`` entries.
        """
        zerosum_axes = cls.check_zerosum_axes(zerosum_axes)

        sigma = at.as_tensor_variable(floatX(sigma))
        if sigma.ndim > 0:
            raise ValueError("sigma has to be a scalar")

        support_shape = get_support_shape(
            support_shape=support_shape,
            shape=kwargs.get("shape"),
            ndim_supp=zerosum_axes,
        )

        if support_shape is None:
            # `zerosum_axes` is always > 0 at this point (see
            # `check_zerosum_axes`), so a missing support shape is an error.
            # TODO: the `zerosum_axes == 0` edge case (a plain Normal with
            # `support_shape = ()`) is not supported for now, because
            # `at.stack` in `get_support_shape` fails for it.
            raise ValueError("You must specify dims, shape or support_shape parameter")
        support_shape = at.as_tensor_variable(intX(support_shape))

        # NOTE(review): `assert` is stripped under `python -O`; it is kept so
        # the exception type seen by existing callers does not change.
        assert zerosum_axes == at.get_vector_length(
            support_shape
        ), "support_shape has to be as long as zerosum_axes"

        return super().dist(
            [sigma], zerosum_axes=zerosum_axes, support_shape=support_shape, **kwargs
        )

    @classmethod
    def check_zerosum_axes(cls, zerosum_axes: Optional[int]) -> int:
        """Validate ``zerosum_axes`` and return it as a plain Python ``int``.

        ``None`` maps to the default of 1. NumPy integer scalars are accepted
        and converted, so the returned value is always a builtin ``int``.

        Raises
        ------
        TypeError
            If ``zerosum_axes`` is not an integer.
        ValueError
            If ``zerosum_axes`` is not > 0.
        """
        if zerosum_axes is None:
            return 1
        if not isinstance(zerosum_axes, (int, np.integer)):
            raise TypeError("zerosum_axes has to be an integer")
        if not zerosum_axes > 0:
            raise ValueError("zerosum_axes has to be > 0")
        return int(zerosum_axes)

    @classmethod
    def rv_op(cls, sigma, zerosum_axes, support_shape, size=None):
        """Build the symbolic zero-sum variable from an unconstrained Normal.

        Raises
        ------
        ValueError
            If the underlying Normal has fewer dimensions than ``zerosum_axes``.
        """
        shape = to_tuple(size) + tuple(support_shape)
        normal_dist = ignore_logprob(pm.Normal.dist(sigma=sigma, shape=shape))

        if zerosum_axes > normal_dist.ndim:
            raise ValueError("Shape of distribution is too small for the number of zerosum axes")

        # Fresh placeholder variables of matching types, used as the inputs
        # of the inner graph of the op.
        normal_dist_, sigma_, support_shape_ = (
            normal_dist.type(),
            sigma.type(),
            support_shape.type(),
        )

        # Zerosum-normaling is achieved by subtracting the mean along the given zerosum_axes
        zerosum_rv_ = normal_dist_
        for axis in range(zerosum_axes):
            zerosum_rv_ = zerosum_rv_ - zerosum_rv_.mean(axis=-axis - 1, keepdims=True)

        return ZeroSumNormalRV(
            inputs=[normal_dist_, sigma_, support_shape_],
            outputs=[zerosum_rv_, support_shape_],
            ndim_supp=zerosum_axes,
        )(normal_dist, sigma, support_shape)
| 2539 | + |
| 2540 | + |
@_change_dist_size.register(ZeroSumNormalRV)
def change_zerosum_size(op, normal_dist, new_size, expand=False):
    """Rebuild a ``ZeroSumNormalRV`` variable with a different size.

    ``normal_dist`` is the variable produced by ``op``; its owner inputs are
    the underlying Normal, ``sigma`` and the support shape. With ``expand``,
    ``new_size`` is prepended to the existing non-support dimensions.
    """
    base_normal, sigma, support_shape = normal_dist.owner.inputs
    n_support_dims = op.ndim_supp

    if expand:
        full_shape = tuple(base_normal.shape)
        n_batch_dims = len(full_shape) - n_support_dims
        new_size = tuple(new_size) + full_shape[:n_batch_dims]

    return ZeroSumNormal.rv_op(
        sigma=sigma,
        zerosum_axes=n_support_dims,
        support_shape=support_shape,
        size=new_size,
    )
| 2554 | + |
| 2555 | + |
@_moment.register(ZeroSumNormalRV)
def zerosumnormal_moment(op, rv, *rv_inputs):
    """Return the moment of a ``ZeroSumNormalRV``: zeros shaped like ``rv``."""
    zero_moment = at.zeros_like(rv)
    return zero_moment
| 2559 | + |
| 2560 | + |
@_default_transform.register(ZeroSumNormalRV)
def zerosum_default_transform(op, rv):
    """Return a ``ZeroSumTransform`` over the last ``op.ndim_supp`` axes."""
    n_axes = op.ndim_supp
    # Negative indices -n_axes, ..., -1 address the trailing axes.
    trailing_axes = np.arange(-n_axes, 0)
    return ZeroSumTransform(tuple(trailing_axes))
| 2565 | + |
| 2566 | + |
@_logprob.register(ZeroSumNormalRV)
def zerosumnormal_logp(op, values, normal_dist, sigma, support_shape, **kwargs):
    """Log-probability of ``value`` under a ``ZeroSumNormalRV``.

    The elementwise logp of the underlying Normal is scaled by the ratio
    ``prod(shape with trailing axes reduced by 1) / prod(shape)`` and summed
    over the zero-sum axes. The result is wrapped in a parameter check that
    the mean of ``value`` along each zero-sum axis is close to zero.
    """
    (value,) = values
    n_axes = op.ndim_supp
    value_shape = value.shape
    support_axes = tuple(np.arange(-n_axes, 0))

    # One check per zero-sum axis: the mean along that axis must be ~0.
    constraint_checks = []
    for offset in range(1, n_axes + 1):
        axis_mean = at.mean(value, axis=-offset)
        constraint_checks.append(at.all(at.isclose(axis_mean, 0, atol=1e-9)))

    # Each zero-sum axis has one element fewer that is free to vary.
    n_elements = at.prod(value_shape)
    reduced_shape = at.inc_subtensor(value_shape[-n_axes:], -1)
    n_free = at.prod(reduced_shape)

    elementwise_logp = pm.logp(normal_dist, value)
    scaled_logp = elementwise_logp * n_free / n_elements
    out = at.sum(scaled_logp, axis=support_axes)

    return check_parameters(
        out, *constraint_checks, msg="at.mean(value, axis=zerosum_axes) == 0"
    )
0 commit comments