|
4 | 4 |
|
5 | 5 | from pymc3.util import get_variable_name
|
6 | 6 | from .continuous import get_tau_sigma, Normal, Flat
|
| 7 | +from .shape_utils import to_tuple |
7 | 8 | from . import multivariate
|
8 | 9 | from . import distribution
|
9 | 10 |
|
@@ -189,7 +190,10 @@ class GaussianRandomWalk(distribution.Continuous):
|
189 | 190 |
|
190 | 191 | def __init__(self, tau=None, init=Flat.dist(), sigma=None, mu=0.,
|
191 | 192 | sd=None, *args, **kwargs):
|
| 193 | + kwargs.setdefault('shape', 1) |
192 | 194 | super().__init__(*args, **kwargs)
|
| 195 | + if sum(self.shape) == 0: |
| 196 | + raise TypeError("GaussianRandomWalk must be supplied a non-zero shape argument!") |
193 | 197 | if sd is not None:
|
194 | 198 | sigma = sd
|
195 | 199 | tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
|
@@ -247,12 +251,17 @@ def random(self, point=None, size=None):
|
247 | 251 | """
|
248 | 252 | sigma, mu = distribution.draw_values([self.sigma, self.mu], point=point, size=size)
|
249 | 253 | return distribution.generate_samples(self._random, sigma=sigma, mu=mu, size=size,
|
250 |
| - dist_shape=self.shape) |
| 254 | + dist_shape=self.shape, |
| 255 | + not_broadcast_kwargs={"sample_shape": to_tuple(size)}) |
251 | 256 |
|
252 |
| - def _random(self, sigma, mu, size): |
| 257 | + def _random(self, sigma, mu, size, sample_shape): |
253 | 258 | """Implement a Gaussian random walk as a cumulative sum of normals."""
|
| 259 | + if size[len(sample_shape)] == sample_shape: |
| 260 | + axis = len(sample_shape) |
| 261 | + else: |
| 262 | + axis = 0 |
254 | 263 | rv = stats.norm(mu, sigma)
|
255 |
| - data = rv.rvs(size).cumsum(axis=0) |
| 264 | + data = rv.rvs(size).cumsum(axis=axis) |
256 | 265 | data = data - data[0] # TODO: this should be a draw from `init`, if available
|
257 | 266 | return data
|
258 | 267 |
|
|
0 commit comments