Dirichlet multinomial (continued) #4373
@@ -51,6 +51,7 @@
     "MvStudentT",
     "Dirichlet",
     "Multinomial",
+    "DirichletMultinomial",
     "Wishart",
     "WishartBartlett",
     "LKJCorr",
@@ -690,6 +691,138 @@ def logp(self, x):
        )

class DirichletMultinomial(Discrete):
    R"""Dirichlet Multinomial log-likelihood.

    Dirichlet mixture of Multinomials distribution, with a marginalized PMF.

    .. math::

        f(x \mid n, a) = \frac{\Gamma(n + 1)\Gamma(\sum a_k)}
                              {\Gamma(n + \sum a_k)}
                         \prod_{k=1}^K
                         \frac{\Gamma(x_k + a_k)}
                              {\Gamma(x_k + 1)\Gamma(a_k)}

    ==========  ===========================================
    Support     :math:`x \in \{0, 1, \ldots, n\}` such that
                :math:`\sum x_i = n`
    Mean        :math:`n \frac{a_i}{\sum a_k}`
    ==========  ===========================================

    Parameters
    ----------
    n : int or array
        Total counts in each replicate. If n is an array its shape must be (N,)
        with N = a.shape[0].

    a : one- or two-dimensional array
        Dirichlet parameter. Elements must be non-negative.
        The dimension of each element of the distribution is the length
        of the second dimension of *a*.

    shape : integer tuple
        Describes the shape of the distribution. For example, if n=array([5, 10])
        and a=array([1, 1, 1]), shape should be (2, 3).
    """
    def __init__(self, n, a, shape, *args, **kwargs):
        super().__init__(shape, *args, **kwargs)

        if len(self.shape) > 1:
            self.n = tt.shape_padright(n)
            self.a = tt.as_tensor_variable(a) if np.ndim(a) > 1 else tt.shape_padleft(a)
        else:
            # n is a scalar, a is a 1d array
            self.n = tt.as_tensor_variable(n)
            self.a = tt.as_tensor_variable(a)

        p = self.a / self.a.sum(-1, keepdims=True)

        self.mean = self.n * p
        # The mode is the rounded mean, with any rounding residual folded back
        # in so that the entries still sum to n.
        mode = tt.cast(tt.round(self.mean), "int32")
        diff = self.n - tt.sum(mode, axis=-1, keepdims=True)
        inc_bool_arr = tt.abs_(diff) > 0
        mode = tt.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])
        self.mode = mode
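The mode adjustment is easier to see in plain NumPy, for a single 1-d case (a sketch, not part of the diff):

import numpy as np

n, p = 10, np.array([0.35, 0.35, 0.3])
mode = np.round(n * p).astype(int)  # rounding can give [4, 4, 3], which sums to 11
diff = n - mode.sum()               # residual left over after rounding
if diff != 0:
    mode[0] += diff                 # fold the residual back into one entry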
    def _random(self, n, a, size=None, raw_size=None):
        # numpy will cast dirichlet and multinomial samples to float64 by default
        original_dtype = a.dtype

        # Thanks to the default shape handling done in generate_values, the last
        # axis of n is a dummy axis that allows it to broadcast well with `a`
        n = np.broadcast_to(n, size)
        a = np.broadcast_to(a, size)
        n = n[..., 0]

        # np.random.multinomial needs `n` to be a scalar int and `a` a
        # sequence, so we semi-flatten them and iterate over them
        n_ = n.reshape([-1])
        a_ = a.reshape([-1, a.shape[-1]])
        p_ = np.array([np.random.dirichlet(aa) for aa in a_])
        samples = np.array([np.random.multinomial(nn, pp) for nn, pp in zip(n_, p_)])
        samples = samples.reshape(a.shape)

        # We cast back to the original dtype
        return samples.astype(original_dtype)
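For intuition, this is the compound sampling scheme _random implements, shown for a single draw (a sketch, not part of the diff):

import numpy as np

a = np.array([0.5, 1.0, 2.0])
p = np.random.dirichlet(a)        # stage 1: mixture weights p ~ Dirichlet(a)
x = np.random.multinomial(10, p)  # stage 2: counts x ~ Multinomial(n=10, p)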
    def random(self, point=None, size=None):
        """
        Draw random values from the Dirichlet-Multinomial distribution.

        Parameters
        ----------
        point: dict, optional
            Dict of variable values on which random values are to be
            conditioned (uses default point if not specified).
        size: int, optional
            Desired size of random sample (returns one sample if not
            specified).

        Returns
        -------
        array
        """
        n, a = draw_values([self.n, self.a], point=point, size=size)
        samples = generate_samples(
            self._random,
            n,
            a,
            dist_shape=self.shape,
            not_broadcast_kwargs={"raw_size": size},
            size=size,
        )

        if size is not None:
            expect_shape = (size, *self.shape)
        else:
            expect_shape = self.shape
        assert tuple(samples.shape) == tuple(expect_shape)
Review discussion on the assert:

> Do we want this here? Is any other distribution doing the same?

> I feel like it is unnecessary during runtime, since we are already testing this quite a lot in the unittests.

> Asserts are not evil, but also not necessarily the best option (also see https://stackoverflow.com/a/13534633). To validate external inputs, use an explicit exception. This has the advantage that a coverage check can tell you if the tests cover the exception. If you want to validate an internal assumption and make sure that your code did not mess up, the assert is the right thing to do. It can help tremendously to debug or understand code. In https://github.com/michaelosthege/pyrff/blob/master/pyrff/rff.py#L230-L261 I did both, because it took me weeks to understand the shapes of the code I was re-implementing there. In this case, if you're already testing it a lot, maybe a comment instead of an assert is more appropriate.

> Yes, I think this is being sufficiently tested for reasonable inputs; I believe @ricardoV94 demonstrated already that these asserts were passed for a large variety of shapes (even when some of the actual outputs were somewhat unintuitive). I'm more worried about users "abusing" weirdly shaped inputs, in which case I like your explicit ShapeError to catch corner cases we didn't even think of. On the other hand, Multinomial doesn't do this now. If we wanted to add this check, I think we should do it for both distributions in parallel.

> I agree with this.
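The explicit-exception alternative being discussed would look roughly like this (a sketch; the exact snippet from the comment did not survive extraction, and the ShapeError call signature is an assumption based on pymc3.exceptions):

from pymc3.exceptions import ShapeError

# Assumed signature: ShapeError(message, actual=..., expected=...)
if tuple(samples.shape) != tuple(expect_shape):
    raise ShapeError(
        "Inconsistent shapes in DirichletMultinomial.random",
        actual=samples.shape,
        expected=expect_shape,
    )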
        return samples

    def logp(self, x):
        a = self.a
        n = self.n
        sum_a = a.sum(axis=-1, keepdims=True)

        const = (gammaln(n + 1) + gammaln(sum_a)) - gammaln(n + sum_a)
        series = gammaln(x + a) - (gammaln(x + 1) + gammaln(a))
        result = const + series.sum(axis=-1, keepdims=True)
        return bound(
            result,
            tt.all(tt.ge(x, 0)),
            tt.all(tt.gt(a, 0)),
            tt.all(tt.ge(n, 0)),
            tt.all(tt.eq(x.sum(axis=-1, keepdims=True), n)),
            broadcast_conditions=False,
        )
    def _distr_parameters_for_repr(self):
        return ["n", "a"]
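To illustrate the intended API, here is a minimal usage sketch matching the docstring's shape example (the prior and data are made up for illustration):

import numpy as np
import pymc3 as pm

counts = np.array([[2, 1, 2], [4, 3, 3]])  # rows sum to n = [5, 10]
with pm.Model():
    a = pm.Gamma("a", alpha=2.0, beta=1.0, shape=3)
    x = pm.DirichletMultinomial(
        "x", n=np.array([5, 10]), a=a, shape=(2, 3), observed=counts
    )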
def posdef(AA):
    try:
        linalg.cholesky(AA)