Skip to content

Commit 438217f

Browse files
Add a hack to statically infer Dirichlet argument shapes
1 parent a115eec commit 438217f

File tree

2 files changed

+25
-0
lines changed

2 files changed

+25
-0
lines changed

pymc3/distributions/multivariate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from scipy import stats, linalg
2525

26+
from theano.gof.op import get_test_value
2627
from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
2728
from theano.tensor.slinalg import Cholesky
2829
import pymc3 as pm
@@ -486,6 +487,21 @@ class Dirichlet(Continuous):
486487

487488
def __init__(self, a, transform=transforms.stick_breaking,
488489
*args, **kwargs):
490+
491+
if kwargs.get('shape') is None:
492+
warnings.warn(
493+
(
494+
"Shape not explicitly set. "
495+
"Please, set the value using the `shape` keyword argument. "
496+
"Using the test value to infer the shape."
497+
),
498+
DeprecationWarning
499+
)
500+
try:
501+
kwargs['shape'] = get_test_value(tt.shape(a))
502+
except AttributeError:
503+
pass
504+
489505
super().__init__(transform=transform, *args, **kwargs)
490506

491507
self.size_prefix = tuple(self.shape[:-1])

pymc3/tests/test_distributions.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,6 +1328,15 @@ def test_dirichlet(self, n):
13281328
Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
13291329
)
13301330

1331+
def test_dirichlet_shape(self):
1332+
a = tt.as_tensor_variable(np.r_[1, 2])
1333+
with pytest.warns(DeprecationWarning):
1334+
dir_rv = Dirichlet.dist(a)
1335+
assert dir_rv.shape == (2,)
1336+
1337+
with pytest.warns(DeprecationWarning), theano.change_flags(compute_test_value="ignore"):
1338+
dir_rv = Dirichlet.dist(tt.vector())
1339+
13311340
def test_dirichlet_2D(self):
13321341
self.pymc3_matches_scipy(
13331342
Dirichlet,

0 commit comments

Comments
 (0)