Skip to content

Commit a607709

Browse files
authored
Fixed Dirichlet.random returning output in wrong shapes (#4416)
1 parent 4e2c099 commit a607709

File tree

3 files changed

+23
-23
lines changed

3 files changed

+23
-23
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
3535
- Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call (see [#4393](https://github.com/pymc-devs/pymc3/pull/4393))
3636
- Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407))
3737
- Issue UserWarning when doing prior or posterior predictive sampling with models containing Potential factors (see [#4419](https://github.com/pymc-devs/pymc3/pull/4419))
38+
- Dirichlet distribution's `random` method is now optimized and gives outputs in the correct shape (see [#4416](https://github.com/pymc-devs/pymc3/pull/4416))
3839

3940
## PyMC3 3.10.0 (7 December 2020)
4041

pymc3/distributions/multivariate.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -471,33 +471,11 @@ def __init__(self, a, transform=transforms.stick_breaking, *args, **kwargs):
471471

472472
super().__init__(transform=transform, *args, **kwargs)
473473

474-
self.size_prefix = tuple(self.shape[:-1])
475474
self.a = a = tt.as_tensor_variable(a)
476475
self.mean = a / tt.sum(a)
477476

478477
self.mode = tt.switch(tt.all(a > 1), (a - 1) / tt.sum(a - 1), np.nan)
479478

480-
def _random(self, a, size=None):
481-
gen = stats.dirichlet.rvs
482-
shape = tuple(np.atleast_1d(self.shape))
483-
if size[-len(shape) :] == shape:
484-
real_size = size[: -len(shape)]
485-
else:
486-
real_size = size
487-
if self.size_prefix:
488-
if real_size and real_size[0] == 1:
489-
real_size = real_size[1:] + self.size_prefix
490-
else:
491-
real_size = real_size + self.size_prefix
492-
493-
if a.ndim == 1:
494-
samples = gen(alpha=a, size=real_size)
495-
else:
496-
unrolled = a.reshape((np.prod(a.shape[:-1]), a.shape[-1]))
497-
samples = np.array([gen(alpha=aa, size=1) for aa in unrolled])
498-
samples = samples.reshape(a.shape)
499-
return samples
500-
501479
def random(self, point=None, size=None):
502480
"""
503481
Draw random values from Dirichlet distribution.
@@ -516,7 +494,10 @@ def random(self, point=None, size=None):
516494
array
517495
"""
518496
a = draw_values([self.a], point=point, size=size)[0]
519-
samples = generate_samples(self._random, a=a, dist_shape=self.shape, size=size)
497+
output_shape = to_tuple(size) + to_tuple(self.shape)
498+
a = broadcast_dist_samples_to(to_shape=output_shape, samples=[a], size=size)[0]
499+
samples = stats.gamma.rvs(a=a, size=output_shape)
500+
samples = samples / samples.sum(-1, keepdims=True)
520501
return samples
521502

522503
def logp(self, value):

pymc3/tests/test_distributions_random.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,24 @@ def test_probability_vector_shape(self):
542542
assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7)
543543

544544

545+
class TestDirichlet(SeededTest):
546+
@pytest.mark.parametrize(
547+
"shape, size",
548+
[
549+
((2), (1)),
550+
((2), (2)),
551+
((2, 2), (2, 100)),
552+
((3, 4), (3, 4)),
553+
((3, 4), (3, 4, 100)),
554+
((3, 4), (100)),
555+
((3, 4), (1)),
556+
],
557+
)
558+
def test_dirichlet_random_shape(self, shape, size):
559+
out_shape = to_tuple(size) + to_tuple(shape)
560+
assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
561+
562+
545563
class TestScalarParameterSamples(SeededTest):
546564
def test_bounded(self):
547565
# A bit crude...

0 commit comments

Comments
 (0)