Skip to content

Commit 6ebf7ba

Browse files
committed
Aggregate along the timeseries axis
1 parent b3388a9 commit 6ebf7ba

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

pymc3/distributions/timeseries.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pymc3.util import get_variable_name
66
from .continuous import get_tau_sigma, Normal, Flat
7+
from .shape_utils import to_tuple
78
from . import multivariate
89
from . import distribution
910

@@ -189,7 +190,10 @@ class GaussianRandomWalk(distribution.Continuous):
189190

190191
def __init__(self, tau=None, init=Flat.dist(), sigma=None, mu=0.,
191192
sd=None, *args, **kwargs):
193+
kwargs.setdefault('shape', 1)
192194
super().__init__(*args, **kwargs)
195+
if sum(self.shape) == 0:
196+
raise TypeError("GaussianRandomWalk must be supplied a non-zero shape argument!")
193197
if sd is not None:
194198
sigma = sd
195199
tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
@@ -247,12 +251,17 @@ def random(self, point=None, size=None):
247251
"""
248252
sigma, mu = distribution.draw_values([self.sigma, self.mu], point=point, size=size)
249253
return distribution.generate_samples(self._random, sigma=sigma, mu=mu, size=size,
250-
dist_shape=self.shape)
254+
dist_shape=self.shape,
255+
not_broadcast_kwargs={"sample_shape": to_tuple(size)})
251256

252-
def _random(self, sigma, mu, size):
257+
def _random(self, sigma, mu, size, sample_shape):
253258
"""Implement a Gaussian random walk as a cumulative sum of normals."""
259+
if size[len(sample_shape)] == sample_shape:
260+
axis = len(sample_shape)
261+
else:
262+
axis = 0
254263
rv = stats.norm(mu, sigma)
255-
data = rv.rvs(size).cumsum(axis=0)
264+
data = rv.rvs(size).cumsum(axis=axis)
256265
data = data - data[0] # TODO: this should be a draw from `init`, if available
257266
return data
258267

pymc3/tests/test_distributions_random.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,12 @@ def get_random_variable(self, shape, with_vector_params=False, name=None):
174174
if shape is None:
175175
return self.distribution(name, transform=None, **params)
176176
else:
177-
return self.distribution(name, shape=shape, transform=None, **params)
177+
try:
178+
return self.distribution(name, shape=shape, transform=None, **params)
179+
except TypeError:
180+
if np.sum(np.atleast_1d(shape)) == 0:
181+
pytest.skip("Timeseries must have positive shape")
182+
raise
178183

179184
@staticmethod
180185
def sample_random_variable(random_variable, size):

0 commit comments

Comments
 (0)