-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Dirichlet multinomial (continued) #4373
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 32 commits
b7492d2
2106f7c
487fc8a
24d7ec8
4fbd1d9
685a428
ad8e77e
8fa717a
4db6b1c
01d359b
4892355
bc5f3bf
ffa705c
c801ef1
c8921ee
e801568
3483ab5
23ba2e4
fe018ec
25fd41a
28b0a62
d363f96
9b6828c
d438dfc
dde5c45
49b432d
83fbda6
9748a9d
7b20680
66c83b0
672ef56
2d5d20e
922515b
aa89d0a
2343004
a08bc51
22beead
7bad831
9bbddba
086459f
f8499d3
1cd2a9f
ef00fe1
f2ac8e9
f5dcdc3
c4e017a
d46dd50
3ab518d
cdd6d27
24447a4
c5e9b67
0bd6c3d
f919456
c082f00
ea0ae59
b451967
128d5cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,7 @@ | |
Constant, | ||
DensityDist, | ||
Dirichlet, | ||
DirichletMultinomial, | ||
DiscreteUniform, | ||
DiscreteWeibull, | ||
ExGaussian, | ||
|
@@ -265,6 +266,21 @@ def multinomial_logpdf(value, n, p): | |
return -inf | ||
|
||
|
||
def dirichlet_multinomial_logpmf(value, n, a):
    """Reference log-pmf of the Dirichlet-multinomial distribution.

    Parameters
    ----------
    value : 1-d array-like of per-category counts.
    n : scalar total number of trials.
    a : 1-d array-like of Dirichlet concentration parameters,
        same shape as ``value``.

    Returns the log-probability of ``value``, or ``-inf`` when ``value``
    is not a valid outcome (a count outside ``[0, n]`` or counts not
    summing to ``n``).
    """
    value = np.asarray(value)
    n = np.asarray(n)
    a = np.asarray(a)
    assert value.ndim == 1
    assert n.ndim == 0
    assert a.shape == value.shape
    gammaln = scipy.special.gammaln

    # A draw is valid only if every count lies in [0, n] and the counts
    # sum to exactly n.
    is_valid = (0 <= value).all() and (value <= n).all() and value.sum() == n
    if not is_valid:
        return -inf

    sum_a = a.sum(axis=-1)
    # Normalizing constant plus the per-category series of the DM pmf.
    const = gammaln(n + 1) + gammaln(sum_a) - gammaln(n + sum_a)
    series = gammaln(value + a) - gammaln(value + 1) - gammaln(a)
    return const + series.sum(axis=-1)
|
||
|
||
def beta_mu_sigma(value, mu, sigma): | ||
kappa = mu * (1 - mu) / sigma ** 2 - 1 | ||
if kappa > 0: | ||
|
@@ -1724,6 +1740,172 @@ def test_batch_multinomial(self): | |
sample = dist.random(size=2) | ||
assert_allclose(sample, np.stack([vals, vals], axis=0)) | ||
|
||
@pytest.mark.parametrize("n", [2, 3])
def test_dirichlet_multinomial(self, n):
    # Compare DirichletMultinomial's logp against the reference
    # dirichlet_multinomial_logpmf over a grid of n-vector counts
    # ("value") and positive concentration vectors ("a") with
    # natural total counts ("n").
    self.pymc3_matches_scipy(
        DirichletMultinomial,
        Vector(Nat, n),
        {"a": Vector(Rplus, n), "n": Nat},
        dirichlet_multinomial_logpmf,
    )
|
||
def test_dirichlet_multinomial_matches_beta_binomial(self):
    """A two-category Dirichlet-multinomial reduces to a beta-binomial,
    so their log-probabilities must agree on every possible outcome."""
    alpha, beta, trials = 2, 1, 5
    successes = np.arange(trials + 1)
    # Convert each success count k to the two-category vector [k, trials - k].
    counts_2d = np.vstack((successes, trials - successes)).T
    beta_binom = pm.BetaBinomial.dist(n=trials, alpha=alpha, beta=beta)
    dirich_mult = pm.DirichletMultinomial.dist(n=trials, a=[alpha, beta])
    bb_logp = beta_binom.logp(successes).tag.test_value
    dm_logp = dirich_mult.logp(counts_2d).tag.test_value.ravel()
    assert_allclose(bb_logp, dm_logp)
|
||
@pytest.mark.parametrize(
    "a, n",
    [
        [[0.25, 0.25, 0.25, 0.25], 1],
        [[0.3, 0.6, 0.05, 0.05], 2],
        [[0.3, 0.6, 0.05, 0.05], 10],
    ],
)
def test_dirichlet_multinomial_mode(self, a, n):
    # The mode of a Dirichlet-multinomial is a vector of counts, so its
    # components must sum to the total number of trials n.
    _a = np.array(a)
    with Model() as model:
        m = DirichletMultinomial("m", n, _a, _a.shape)
    assert_allclose(m.distribution.mode.eval().sum(), n)
    # Repeat with a batch of two identical concentration vectors: each
    # row of the mode must again sum to n.
    _a = np.array([a, a])
    with Model() as model:
        m = DirichletMultinomial("m", n, _a, _a.shape)
    assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
|
||
@pytest.mark.parametrize(
    "a, shape, n",
    [
        [[0.25, 0.25, 0.25, 0.25], 4, 2],
        [[0.25, 0.25, 0.25, 0.25], (1, 4), 3],
        [[0.25, 0.25, 0.25, 0.25], (10, 4), [2] * 10],
        [[0.25, 0.25, 0.25, 0.25], (10, 1, 4), 5],
        [[[0.25, 0.25, 0.25, 0.25]], (2, 4), [7, 11]],
        [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), 13],
        [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (1, 2, 4), [23, 29]],
        [
            [[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]],
            (10, 2, 4),
            [31, 37],
        ],
        [[[0.25, 0.25, 0.25, 0.25], [0.25, 0.25, 0.25, 0.25]], (2, 4), [17, 19]],
    ],
)
def test_dirichlet_multinomial_random(self, a, shape, n):
    # Smoke test: drawing from the distribution must succeed for scalar
    # and vector n combined with 1-d, 2-d and 3-d concentration arrays
    # and shapes. Only that .random() runs is checked here; the sampled
    # values themselves are not asserted.
    a = np.asarray(a)
    with Model() as model:
        m = DirichletMultinomial("m", n=n, a=a, shape=shape)
    m.random()
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you want to increase your code coverage, you should add a One other thing, I would strongly recommend that you There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Hmm...are you sure? Running def test_dirichlet_multinomial_random(self, a, shape, n):
a = np.asarray(a)
with Model() as model:
m = DirichletMultinomial("m", n=n, a=a, shape=shape)
m.random()
m.random(size=2) # Try w/ and w/out this line That last line doesn't seem to affect my coverage at all. And the test still passes with 😕 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Something like this? if size is not None:
expect_shape = (size, *self.shape)
else:
expect_shape = self.shape
assert tuple(samples.shape) == tuple(expect_shape)
return samples |
||
|
||
def test_dirichlet_multinomial_mode_with_shape(self):
    # With batched n and a (shape (2, 4)), each row of the mode must sum
    # to the corresponding entry of n.
    n = [1, 10]
    a = np.asarray([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]])
    with Model() as model:
        m = DirichletMultinomial("m", n=n, a=a, shape=(2, 4))
    assert_allclose(m.distribution.mode.eval().sum(axis=-1), n)
|
||
def test_dirichlet_multinomial_vec(self):
    # Shared concentration vector `a` and scalar `n`: the vectorized
    # model must agree with the reference logpmf, with the elementwise
    # logp, and with a single-row model evaluated per row.
    vals = np.array([[2, 4, 4], [3, 3, 4]])
    a = np.array([0.2, 0.3, 0.5])
    n = 10

    with Model() as model_single:
        DirichletMultinomial("m", n=n, a=a, shape=len(a))

    with Model() as model_many:
        DirichletMultinomial("m", n=n, a=a, shape=vals.shape)

    # Single-row model evaluated per row matches the reference logpmf.
    assert_almost_equal(
        np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
        np.asarray([model_single.fastlogp({"m": val}) for val in vals]),
        decimal=4,
    )

    # Elementwise logp of the batched model matches the reference logpmf.
    assert_almost_equal(
        np.asarray([dirichlet_multinomial_logpmf(v, n, a) for v in vals]),
        model_many.free_RVs[0].logp_elemwise({"m": vals}).squeeze(),
        decimal=4,
    )

    # Total logp of the batched model is the sum over rows.
    assert_almost_equal(
        sum([model_single.fastlogp({"m": val}) for val in vals]),
        model_many.fastlogp({"m": vals}),
        decimal=4,
    )
|
||
def test_dirichlet_multinomial_vec_1d_n(self):
    """Vector of total counts `n` (one per observation row) with a single
    shared concentration vector: the model's total logp must equal the
    sum of the per-row reference log-pmfs."""
    observed = np.array([[2, 4, 4], [4, 3, 4]])
    conc = np.array([0.2, 0.3, 0.5])
    totals = np.array([10, 11])

    with Model() as model:
        DirichletMultinomial("m", n=totals, a=conc, shape=observed.shape)

    expected = 0
    for row, total in zip(observed, totals):
        expected += dirichlet_multinomial_logpmf(row, total, conc)

    assert_almost_equal(
        expected,
        model.fastlogp({"m": observed}),
        decimal=4,
    )
|
||
def test_dirichlet_multinomial_vec_1d_n_2d_a(self):
    # Both `n` and `a` vary per observation row: the model's total logp
    # must equal the sum of the per-row reference log-pmfs.
    vals = np.array([[2, 4, 4], [4, 3, 4]])
    as_ = np.array([[0.2, 0.3, 0.5], [0.9, 0.09, 0.01]])
    ns = np.array([10, 11])

    with Model() as model:
        DirichletMultinomial("m", n=ns, a=as_, shape=vals.shape)

    assert_almost_equal(
        sum([dirichlet_multinomial_logpmf(val, n, a) for val, n, a in zip(vals, ns, as_)]),
        model.fastlogp({"m": vals}),
        decimal=4,
    )
|
||
def test_dirichlet_multinomial_vec_2d_a(self):
    """Row-specific concentration vectors with a scalar `n`: the model's
    total logp must equal the sum of the per-row reference log-pmfs."""
    observed = np.array([[2, 4, 4], [3, 3, 4]])
    concs = np.array([[0.2, 0.3, 0.5], [0.3, 0.3, 0.4]])
    total = 10

    with Model() as model:
        DirichletMultinomial("m", n=total, a=concs, shape=observed.shape)

    expected = sum(
        dirichlet_multinomial_logpmf(row, total, conc)
        for row, conc in zip(observed, concs)
    )
    assert_almost_equal(
        expected,
        model.fastlogp({"m": observed}),
        decimal=4,
    )
|
||
def test_batch_dirichlet_multinomial(self):
    # Test that DM can handle a 3d array for `a`
    n = 10
    # Create an almost deterministic DM by setting a to 0.001 everywhere
    # except for one randomly chosen category per batch entry, which is
    # given the value of 100.
    vals = np.zeros((4, 5, 3), dtype="int32")
    a = np.zeros_like(vals, dtype=theano.config.floatX) + 0.001
    inds = np.random.randint(vals.shape[-1], size=vals.shape[:-1])[..., None]
    np.put_along_axis(vals, inds, n, axis=-1)
    np.put_along_axis(a, inds, 100, axis=-1)

    dist = DirichletMultinomial.dist(n=n, a=a, shape=vals.shape)

    # TODO: Test logp is as expected (not as simple as the Multinomial case)
    # value = tt.tensor3(dtype="int32")
    # value.tag.test_value = np.zeros_like(vals, dtype="int32")
    # logp = tt.exp(dist.logp(value))
    # f = theano.function(inputs=[value], outputs=logp)
    # assert_almost_equal(
    #     f(vals),
    #     np.ones(vals.shape[:-1] + (1,)),
    #     decimal=select_by_precision(float64=6, float32=3),
    # )

    # Samples should be equal given the almost deterministic DM
    sample = dist.random(size=2)
    assert_allclose(sample, np.stack([vals, vals], axis=0))
|
||
def test_categorical_bounds(self): | ||
with Model(): | ||
x = Categorical("x", p=np.array([0.2, 0.3, 0.5])) | ||
|
Uh oh!
There was an error while loading. Please reload this page.