-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Dirichlet multinomial (continued) #4373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 56 commits
b7492d2
2106f7c
487fc8a
24d7ec8
4fbd1d9
685a428
ad8e77e
8fa717a
4db6b1c
01d359b
4892355
bc5f3bf
ffa705c
c801ef1
c8921ee
e801568
3483ab5
23ba2e4
fe018ec
25fd41a
28b0a62
d363f96
9b6828c
d438dfc
dde5c45
49b432d
83fbda6
9748a9d
7b20680
66c83b0
672ef56
2d5d20e
922515b
aa89d0a
2343004
a08bc51
22beead
7bad831
9bbddba
086459f
f8499d3
1cd2a9f
ef00fe1
f2ac8e9
f5dcdc3
c4e017a
d46dd50
3ab518d
cdd6d27
24447a4
c5e9b67
0bd6c3d
f919456
c082f00
ea0ae59
b451967
128d5cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,15 +42,17 @@ | |
) | ||
from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple | ||
from pymc3.distributions.special import gammaln, multigammaln | ||
from pymc3.exceptions import ShapeError | ||
from pymc3.math import kron_diag, kron_dot, kron_solve_lower, kronecker | ||
from pymc3.model import Deterministic | ||
from pymc3.theanof import floatX | ||
from pymc3.theanof import floatX, intX | ||
|
||
__all__ = [ | ||
"MvNormal", | ||
"MvStudentT", | ||
"Dirichlet", | ||
"Multinomial", | ||
"DirichletMultinomial", | ||
"Wishart", | ||
"WishartBartlett", | ||
"LKJCorr", | ||
|
@@ -690,6 +692,160 @@ def logp(self, x): | |
) | ||
|
||
|
||
class DirichletMultinomial(Discrete): | ||
R"""Dirichlet Multinomial log-likelihood. | ||
|
||
Dirichlet mixture of Multinomials distribution, with a marginalized PMF. | ||
|
||
.. math:: | ||
|
||
f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)} | ||
{\Gamma(n + \sum a_k)} | ||
\prod_{k=1}^K | ||
\frac{\Gamma(x_k + a_k)} | ||
{\Gamma(x_k + 1)\Gamma(a_k)} | ||
|
||
========== =========================================== | ||
Support :math:`x \in \{0, 1, \ldots, n\}` such that | ||
:math:`\sum x_i = n` | ||
Mean :math:`n \frac{a_i}{\sum{a_k}}` | ||
========== =========================================== | ||
|
||
Parameters | ||
---------- | ||
n : int or array | ||
Total counts in each replicate. If n is an array its shape must be (N,) | ||
with N = a.shape[0] | ||
|
||
a : one- or two-dimensional array | ||
Dirichlet parameter. Elements must be strictly positive. | ||
The number of categories is given by the length of the last axis. | ||
|
||
shape : integer tuple | ||
Sayam753 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Describes shape of distribution. For example if n=array([5, 10]), and | ||
a=array([1, 1, 1]), shape should be (2, 3). | ||
""" | ||
|
||
def __init__(self, n, a, shape, *args, **kwargs): | ||
|
||
super().__init__(shape=shape, defaults=("_defaultval",), *args, **kwargs) | ||
Comment on lines
+730
to
+731
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Dirichlet distribution makes use of Ping @brandonwillard to ask how does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think it is as simple, since the shape can be influenced by the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also the Dirichlet functionality is wrapped in a DeprecationWarning (even though I don't seem to be able to trigger it), which suggests that they wanted to abandon that approach at some point. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @ricardoV94 , just a follow up, it indeed makes sense to avoid the use of |
||
|
||
n = intX(n) | ||
a = floatX(a) | ||
if len(self.shape) > 1: | ||
self.n = tt.shape_padright(n) | ||
self.a = tt.as_tensor_variable(a) if a.ndim > 1 else tt.shape_padleft(a) | ||
else: | ||
# n is a scalar, a is a 1d array | ||
self.n = tt.as_tensor_variable(n) | ||
self.a = tt.as_tensor_variable(a) | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
p = self.a / self.a.sum(-1, keepdims=True) | ||
|
||
self.mean = self.n * p | ||
# Mode is only an approximation. Exact computation requires a complex | ||
# iterative algorithm as described in https://doi.org/10.1016/j.spl.2009.09.013 | ||
mode = tt.cast(tt.round(self.mean), "int32") | ||
diff = self.n - tt.sum(mode, axis=-1, keepdims=True) | ||
inc_bool_arr = tt.abs_(diff) > 0 | ||
mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()]) | ||
self._defaultval = mode | ||
|
||
def _random(self, n, a, size=None): | ||
# numpy will cast dirichlet and multinomial samples to float64 by default | ||
original_dtype = a.dtype | ||
|
||
# Thanks to the default shape handling done in generate_values, the last | ||
# axis of n is a dummy axis that allows it to broadcast well with `a` | ||
n = np.broadcast_to(n, size) | ||
a = np.broadcast_to(a, size) | ||
n = n[..., 0] | ||
|
||
# np.random.multinomial needs `n` to be a scalar int and `a` a | ||
# sequence so we semi flatten them and iterate over them | ||
n_ = n.reshape([-1]) | ||
a_ = a.reshape([-1, a.shape[-1]]) | ||
p_ = np.array([np.random.dirichlet(aa) for aa in a_]) | ||
samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)]) | ||
samples = samples.reshape(a.shape) | ||
Sayam753 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# We cast back to the original dtype | ||
return samples.astype(original_dtype) | ||
|
||
def random(self, point=None, size=None): | ||
""" | ||
Draw random values from Dirichlet-Multinomial distribution. | ||
|
||
Parameters | ||
---------- | ||
point: dict, optional | ||
Dict of variable values on which random values are to be | ||
conditioned (uses default point if not specified). | ||
AlexAndorra marked this conversation as resolved.
Show resolved
Hide resolved
|
||
size: int, optional | ||
Desired size of random sample (returns one sample if not | ||
specified). | ||
|
||
Returns | ||
------- | ||
array | ||
""" | ||
n, a = draw_values([self.n, self.a], point=point, size=size) | ||
samples = generate_samples( | ||
self._random, | ||
n, | ||
a, | ||
dist_shape=self.shape, | ||
size=size, | ||
) | ||
|
||
# If distribution is initialized with .dist(), valid init shape is not asserted. | ||
# Under normal use in a model context valid init shape is asserted at start. | ||
expected_shape = to_tuple(size) + to_tuple(self.shape) | ||
sample_shape = tuple(samples.shape) | ||
if sample_shape != expected_shape: | ||
raise ShapeError( | ||
f"Expected sample shape was {expected_shape} but got {sample_shape}. " | ||
"This may reflect an invalid initialization shape." | ||
) | ||
|
||
return samples | ||
|
||
def logp(self, value): | ||
""" | ||
Calculate log-probability of DirichletMultinomial distribution | ||
at specified value. | ||
|
||
Parameters | ||
---------- | ||
value: integer array | ||
Value for which log-probability is calculated. | ||
|
||
Returns | ||
------- | ||
TensorVariable | ||
""" | ||
a = self.a | ||
n = self.n | ||
sum_a = a.sum(axis=-1, keepdims=True) | ||
|
||
const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a) | ||
series = gammaln(value + a) - (gammaln(value + 1) + gammaln(a)) | ||
result = const + series.sum(axis=-1, keepdims=True) | ||
# Bounds checking to confirm parameters and data meet all constraints | ||
# and that each observation value_i sums to n_i. | ||
return bound( | ||
result, | ||
tt.all(tt.ge(value, 0)), | ||
tt.all(tt.gt(a, 0)), | ||
tt.all(tt.ge(n, 0)), | ||
tt.all(tt.eq(value.sum(axis=-1, keepdims=True), n)), | ||
broadcast_conditions=False, | ||
) | ||
|
||
def _distr_parameters_for_repr(self): | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return ["n", "a"] | ||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
def posdef(AA): | ||
try: | ||
linalg.cholesky(AA) | ||
|
Uh oh!
There was an error while loading. Please reload this page.