Fixed Dirichlet.random returning output of wrong shapes #4416

Merged 1 commit on Jan 18, 2021
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
@@ -35,6 +35,7 @@ It also brings some dreadfully awaited fixes, so be sure to go through the chang
 - Update the `logcdf` method of several continuous distributions to return -inf for invalid parameters and values, and raise an informative error when multiple values cannot be evaluated in a single call. (see [4393](https://github.com/pymc-devs/pymc3/pull/4393))
 - Improve numerical stability in `logp` and `logcdf` methods of `ExGaussian` (see [#4407](https://github.com/pymc-devs/pymc3/pull/4407))
 - Issue UserWarning when doing prior or posterior predictive sampling with models containing Potential factors (see [#4419](https://github.com/pymc-devs/pymc3/pull/4419))
+- Dirichlet distribution's `random` method is now optimized and gives outputs in correct shape (see [#4416](https://github.com/pymc-devs/pymc3/pull/4416))

 ## PyMC3 3.10.0 (7 December 2020)

27 changes: 4 additions & 23 deletions pymc3/distributions/multivariate.py
@@ -471,33 +471,11 @@ def __init__(self, a, transform=transforms.stick_breaking, *args, **kwargs):

         super().__init__(transform=transform, *args, **kwargs)

-        self.size_prefix = tuple(self.shape[:-1])
         self.a = a = tt.as_tensor_variable(a)
         self.mean = a / tt.sum(a)

         self.mode = tt.switch(tt.all(a > 1), (a - 1) / tt.sum(a - 1), np.nan)

-    def _random(self, a, size=None):
-        gen = stats.dirichlet.rvs
-        shape = tuple(np.atleast_1d(self.shape))
-        if size[-len(shape) :] == shape:
-            real_size = size[: -len(shape)]
-        else:
-            real_size = size
-        if self.size_prefix:
-            if real_size and real_size[0] == 1:
-                real_size = real_size[1:] + self.size_prefix
-            else:
-                real_size = real_size + self.size_prefix
-
-        if a.ndim == 1:
-            samples = gen(alpha=a, size=real_size)
-        else:
-            unrolled = a.reshape((np.prod(a.shape[:-1]), a.shape[-1]))
-            samples = np.array([gen(alpha=aa, size=1) for aa in unrolled])
-            samples = samples.reshape(a.shape)
-        return samples
-
     def random(self, point=None, size=None):
         """
         Draw random values from Dirichlet distribution.
@@ -516,7 +494,10 @@ def random(self, point=None, size=None):
         array
         """
         a = draw_values([self.a], point=point, size=size)[0]
-        samples = generate_samples(self._random, a=a, dist_shape=self.shape, size=size)
+        output_shape = to_tuple(size) + to_tuple(self.shape)
+        a = broadcast_dist_samples_to(to_shape=output_shape, samples=[a], size=size)[0]
+        samples = stats.gamma.rvs(a=a, size=output_shape)
+        samples = samples / samples.sum(-1, keepdims=True)
         return samples

     def logp(self, value):
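The new `random` body relies on a standard identity: independent Gamma(a_i, 1) draws, divided by their sum along the last axis, are Dirichlet(a) distributed. The following is a minimal standalone sketch of that idea using only NumPy, not the PyMC3 code itself. The function name `dirichlet_via_gamma` and its `rng` argument are hypothetical, and it assumes the distribution shape equals `a.shape`, whereas the PR uses `self.shape` together with `broadcast_dist_samples_to`.

```python
import numpy as np


def dirichlet_via_gamma(a, size=None, rng=None):
    """Hypothetical helper: Dirichlet draws from normalized Gamma(a_i, 1) draws."""
    rng = np.random.default_rng() if rng is None else rng
    a = np.asarray(a, dtype=float)

    # Normalize `size` to a tuple: None -> (), int -> (n,), tuple unchanged.
    if size is None:
        size = ()
    elif np.ndim(size) == 0:
        size = (int(size),)
    else:
        size = tuple(size)

    # Assumption: the distribution shape is a.shape, so the output shape
    # is the requested size followed by a.shape.
    out_shape = size + a.shape

    # Broadcast the concentration parameters across the leading size dims.
    a_b = np.broadcast_to(a, out_shape)

    # One vectorized gamma draw replaces the per-row loop of the old code.
    g = rng.gamma(shape=a_b)

    # Normalizing along the last axis puts each vector on the simplex.
    return g / g.sum(axis=-1, keepdims=True)


samples = dirichlet_via_gamma(np.ones((3, 4)), size=(3, 4, 100))
print(samples.shape)
```

Because the gamma draw is a single vectorized call, the cost no longer depends on a Python-level loop over rows of `a`, which is presumably the optimization the release note refers to.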
18 changes: 18 additions & 0 deletions pymc3/tests/test_distributions_random.py
@@ -542,6 +542,24 @@ def test_probability_vector_shape(self):
         assert pm.Categorical.dist(p=p).random(size=4).shape == (4, 3, 7)


+class TestDirichlet(SeededTest):
+    @pytest.mark.parametrize(
+        "shape, size",
+        [
+            ((2), (1)),
+            ((2), (2)),
+            ((2, 2), (2, 100)),
+            ((3, 4), (3, 4)),
+            ((3, 4), (3, 4, 100)),
+            ((3, 4), (100)),
+            ((3, 4), (1)),
+        ],
+    )
+    def test_dirichlet_random_shape(self, shape, size):
+        out_shape = to_tuple(size) + to_tuple(shape)
+        assert pm.Dirichlet.dist(a=np.ones(shape)).random(size=size).shape == out_shape
+
+
 class TestScalarParameterSamples(SeededTest):
     def test_bounded(self):
         # A bit crude...
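A note on the parametrization: `(2)`, `(1)` and `(100)` are plain integers in Python, not one-element tuples, so the test depends on `to_tuple` accepting both ints and tuples. The sketch below spells out the expected shapes for each case. `as_tuple` is a hypothetical stand-in that is only assumed to behave like PyMC3's `to_tuple` for these inputs.

```python
def as_tuple(x):
    """Hypothetical stand-in for to_tuple: None -> (), int -> (x,), iterable -> tuple."""
    if x is None:
        return ()
    try:
        return tuple(x)
    except TypeError:
        # Plain ints such as (2) are not iterable, so wrap them.
        return (x,)


# Same (shape, size) pairs as the parametrized test.
cases = [
    ((2), (1)),
    ((2), (2)),
    ((2, 2), (2, 100)),
    ((3, 4), (3, 4)),
    ((3, 4), (3, 4, 100)),
    ((3, 4), (100)),
    ((3, 4), (1)),
]

# Expected output shape: the requested size followed by the distribution shape.
expected = [as_tuple(size) + as_tuple(shape) for shape, size in cases]
for (shape, size), out_shape in zip(cases, expected):
    print(f"shape={shape!r}, size={size!r} -> {out_shape}")
```

Under this reading, `size` always contributes the leading dimensions and the distribution shape the trailing ones, even when `size` repeats the distribution shape as in the `((3, 4), (3, 4))` case.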