Skip to content

Commit f07c273

Browse files
Remove Dirichlet distribution type restrictions (#4000)
* Remove Dirichlet distribution type restrictions Closes #3999. * Add missing Dirichlet shape parameters to tests * Remove Dirichlet positive concentration parameter constructor tests This test can't be performed in the constructor if we're allowing Theano-type distribution parameters. * Add a hack to statically infer Dirichlet argument shapes Co-authored-by: Brandon T. Willard <[email protected]>
1 parent b2c682e commit f07c273

File tree

5 files changed

+41
-43
lines changed

5 files changed

+41
-43
lines changed

pymc3/distributions/multivariate.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from scipy import stats, linalg
2525

26+
from theano.gof.op import get_test_value
2627
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
2728
from theano.tensor.slinalg import Cholesky
2829
import pymc3 as pm
@@ -487,22 +488,23 @@ class Dirichlet(Continuous):
487488
def __init__(self, a, transform=transforms.stick_breaking,
488489
*args, **kwargs):
489490

490-
if not isinstance(a, pm.model.TensorVariable):
491-
if not isinstance(a, list) and not isinstance(a, np.ndarray):
492-
raise TypeError(
493-
'The vector of concentration parameters (a) must be a python list '
494-
'or numpy array.')
495-
a = np.array(a)
496-
if (a <= 0).any():
497-
raise ValueError("All concentration parameters (a) must be > 0.")
498-
499-
shape = np.atleast_1d(a.shape)[-1]
491+
if kwargs.get('shape') is None:
492+
warnings.warn(
493+
(
494+
"Shape not explicitly set. "
495+
"Please, set the value using the `shape` keyword argument. "
496+
"Using the test value to infer the shape."
497+
),
498+
DeprecationWarning
499+
)
500+
try:
501+
kwargs['shape'] = get_test_value(tt.shape(a))
502+
except AttributeError:
503+
pass
500504

501-
kwargs.setdefault("shape", shape)
502505
super().__init__(transform=transform, *args, **kwargs)
503506

504507
self.size_prefix = tuple(self.shape[:-1])
505-
self.k = tt.as_tensor_variable(shape)
506508
self.a = a = tt.as_tensor_variable(a)
507509
self.mean = a / tt.sum(a)
508510

@@ -569,14 +571,13 @@ def logp(self, value):
569571
-------
570572
TensorVariable
571573
"""
572-
k = self.k
573574
a = self.a
574575

575576
# only defined for sum(value) == 1
576577
return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
577578
+ gammaln(tt.sum(a, axis=-1)),
578579
tt.all(value >= 0), tt.all(value <= 1),
579-
k > 1, tt.all(a > 0),
580+
np.logical_not(a.broadcastable), tt.all(a > 0),
580581
broadcast_conditions=False)
581582

582583
def _repr_latex_(self, name=None, dist=None):

pymc3/tests/test_dist_math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,11 @@ def test_multinomial_bound():
126126
n = x.sum()
127127

128128
with pm.Model() as modelA:
129-
p_a = pm.Dirichlet('p', floatX(np.ones(2)))
129+
p_a = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
130130
MultinomialA('x', n, p_a, observed=x)
131131

132132
with pm.Model() as modelB:
133-
p_b = pm.Dirichlet('p', floatX(np.ones(2)))
133+
p_b = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
134134
MultinomialB('x', n, p_b, observed=x)
135135

136136
assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),

pymc3/tests/test_distributions.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,17 +1328,14 @@ def test_dirichlet(self, n):
13281328
Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
13291329
)
13301330

1331-
@pytest.mark.parametrize("n", [3, 4])
1332-
def test_dirichlet_init_fail(self, n):
1333-
with Model():
1334-
with pytest.raises(
1335-
ValueError, match=r"All concentration parameters \(a\) must be > 0."
1336-
):
1337-
_ = Dirichlet("x", a=np.zeros(n), shape=n)
1338-
with pytest.raises(
1339-
ValueError, match=r"All concentration parameters \(a\) must be > 0."
1340-
):
1341-
_ = Dirichlet("x", a=np.array([-1.0] * n), shape=n)
1331+
def test_dirichlet_shape(self):
1332+
a = tt.as_tensor_variable(np.r_[1, 2])
1333+
with pytest.warns(DeprecationWarning):
1334+
dir_rv = Dirichlet.dist(a)
1335+
assert dir_rv.shape == (2,)
1336+
1337+
with pytest.warns(DeprecationWarning), theano.change_flags(compute_test_value="ignore"):
1338+
dir_rv = Dirichlet.dist(tt.vector())
13421339

13431340
def test_dirichlet_2D(self):
13441341
self.pymc3_matches_scipy(

pymc3/tests/test_distributions_random.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -912,15 +912,15 @@ def test_mixture_random_shape():
912912
nr.poisson(9, size=10)])
913913
with pm.Model() as m:
914914
comp0 = pm.Poisson.dist(mu=np.ones(2))
915-
w0 = pm.Dirichlet('w0', a=np.ones(2))
915+
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
916916
like0 = pm.Mixture('like0',
917917
w=w0,
918918
comp_dists=comp0,
919919
observed=y)
920920

921921
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
922922
shape=(20, 2))
923-
w1 = pm.Dirichlet('w1', a=np.ones(2))
923+
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
924924
like1 = pm.Mixture('like1',
925925
w=w1,
926926
comp_dists=comp1,
@@ -967,15 +967,15 @@ def test_mixture_random_shape_fast():
967967
nr.poisson(9, size=10)])
968968
with pm.Model() as m:
969969
comp0 = pm.Poisson.dist(mu=np.ones(2))
970-
w0 = pm.Dirichlet('w0', a=np.ones(2))
970+
w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
971971
like0 = pm.Mixture('like0',
972972
w=w0,
973973
comp_dists=comp0,
974974
observed=y)
975975

976976
comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
977977
shape=(20, 2))
978-
w1 = pm.Dirichlet('w1', a=np.ones(2))
978+
w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
979979
like1 = pm.Mixture('like1',
980980
w=w1,
981981
comp_dists=comp1,

pymc3/tests/test_mixture.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def test_dimensions(self):
7979

8080
def test_mixture_list_of_normals(self):
8181
with Model() as model:
82-
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
82+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
8383
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
8484
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
8585
Mixture('x_obs', w,
@@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):
9898

9999
def test_normal_mixture(self):
100100
with Model() as model:
101-
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
101+
w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
102102
mu = Normal('mu', 0., 10., shape=self.norm_w.size)
103103
tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
104104
NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
@@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
135135
with Model() as model0:
136136
mus = Normal('mus', shape=comp_shape)
137137
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
138-
ws = Dirichlet('ws', np.ones(ncomp))
138+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
139139
mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
140140
comp_shape=comp_shape)
141141
obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
@@ -145,7 +145,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
145145
with Model() as model1:
146146
mus = Normal('mus', shape=comp_shape)
147147
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
148-
ws = Dirichlet('ws', np.ones(ncomp))
148+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
149149
comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
150150
shape=nd)
151151
for i in range(ncomp)]
@@ -163,7 +163,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
163163
# comp_dists.
164164
mus = Normal('mus', shape=comp_shape)
165165
taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
166-
ws = Dirichlet('ws', np.ones(ncomp))
166+
ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
167167
if len(nd) > 1:
168168
if nd[-1] != ncomp:
169169
with pytest.raises(ValueError):
@@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
208208

209209
def test_poisson_mixture(self):
210210
with Model() as model:
211-
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
211+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
212212
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
213213
Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
214214
step = Metropolis()
@@ -224,7 +224,7 @@ def test_poisson_mixture(self):
224224

225225
def test_mixture_list_of_poissons(self):
226226
with Model() as model:
227-
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
227+
w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
228228
mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
229229
Mixture('x_obs', w,
230230
[Poisson.dist(mu[0]), Poisson.dist(mu[1])],
@@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
247247
cov2 = np.diag([2.5, 3.5])
248248
obs = np.asarray([[.5, .5], mu1, mu2])
249249
with Model() as model:
250-
w = Dirichlet('w', floatX(np.ones(2)), transform=None)
250+
w = Dirichlet('w', floatX(np.ones(2)), transform=None, shape=(2,))
251251
mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
252252
mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
253253
y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
@@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
291291
sigma=1,
292292
shape=nbr)
293293
# weight vector for the mixtures
294-
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
295-
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
294+
g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
295+
l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
296296
# mixture components
297297
g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
298298
l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
299299
# mixture of mixtures
300-
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
300+
mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None, shape=(2,))
301301
mix = Mixture('mix', w=mix_w,
302302
comp_dists=[g_mix, l_mix],
303303
observed=np.exp(self.norm_x))
@@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
378378
X, y = build_toy_dataset(N, K)
379379

380380
with pm.Model() as model:
381-
pi = pm.Dirichlet('pi', np.ones(K))
381+
pi = pm.Dirichlet('pi', np.ones(K), shape=(K,))
382382

383383
comp_dist = []
384384
mu = []

0 commit comments

Comments
 (0)