Commit 2b7f302

Merge pull request #3682 from ColCarroll/random-walk
Random walk, random method
2 parents 6c17578 + 6ebf7ba commit 2b7f302

File tree

5 files changed: +3300 −495 lines changed


RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@
 - Sampling from variational approximation now allows for alternative trace backends [#3550].
 - Infix `@` operator now works with random variables and deterministics [#3619](https://github.com/pymc-devs/pymc3/pull/3619).
 - [ArviZ](https://arviz-devs.github.io/arviz/) is now a requirement, and handles plotting, diagnostics, and statistical checks.
+- Can use GaussianRandomWalk in sample_prior_predictive and sample_posterior_predictive [#3682](https://github.com/pymc-devs/pymc3/pull/3682)
+- Now 11 years of S&P returns in the data set [#3682](https://github.com/pymc-devs/pymc3/pull/3682)
 
 ### Maintenance
 - Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)
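The first release note above is the user-facing change in this commit: because `GaussianRandomWalk` now implements a `random` method, it can be drawn during prior (and posterior) predictive sampling. A minimal sketch, assuming the PyMC3 3.x API; the model, variable names, and shapes below are illustrative and not part of this commit:

```python
import pymc3 as pm

with pm.Model():
    sigma = pm.Exponential("sigma", 1.0)
    # A length-100 walk; before this commit, GaussianRandomWalk had no `random`
    # method, so prior predictive draws of this variable were not possible.
    walk = pm.GaussianRandomWalk("walk", sigma=sigma, shape=100)
    prior = pm.sample_prior_predictive(samples=50)

print(prior["walk"].shape)  # expected: (50, 100)
```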

docs/source/notebooks/stochastic_volatility.ipynb

Lines changed: 315 additions & 79 deletions
Large diffs are not rendered by default.

pymc3/distributions/timeseries.py

Lines changed: 61 additions & 12 deletions
@@ -1,8 +1,10 @@
+from scipy import stats
 import theano.tensor as tt
 from theano import scan
 
 from pymc3.util import get_variable_name
 from .continuous import get_tau_sigma, Normal, Flat
+from .shape_utils import to_tuple
 from . import multivariate
 from . import distribution
 
@@ -166,33 +168,50 @@ def logp(self, value):
 
 
 class GaussianRandomWalk(distribution.Continuous):
-    R"""
-    Random Walk with Normal innovations
+    R"""Random Walk with Normal innovations
 
     Parameters
     ----------
     mu: tensor
         innovation drift, defaults to 0.0
+        For vector valued mu, first dimension must match shape of the random walk, and
+        the first element will be discarded (since there is no innovation in the first timestep)
     sigma : tensor
         sigma > 0, innovation standard deviation (only required if tau is not specified)
+        For vector valued sigma, first dimension must match shape of the random walk, and
+        the first element will be discarded (since there is no innovation in the first timestep)
     tau : tensor
         tau > 0, innovation precision (only required if sigma is not specified)
+        For vector valued tau, first dimension must match shape of the random walk, and
+        the first element will be discarded (since there is no innovation in the first timestep)
     init : distribution
         distribution for initial value (Defaults to Flat())
     """
 
     def __init__(self, tau=None, init=Flat.dist(), sigma=None, mu=0.,
                  sd=None, *args, **kwargs):
+        kwargs.setdefault('shape', 1)
         super().__init__(*args, **kwargs)
+        if sum(self.shape) == 0:
+            raise TypeError("GaussianRandomWalk must be supplied a non-zero shape argument!")
         if sd is not None:
             sigma = sd
         tau, sigma = get_tau_sigma(tau=tau, sigma=sigma)
-        self.tau = tau = tt.as_tensor_variable(tau)
-        self.sigma = self.sd = sigma = tt.as_tensor_variable(sigma)
-        self.mu = mu = tt.as_tensor_variable(mu)
+        self.tau = tt.as_tensor_variable(tau)
+        sigma = tt.as_tensor_variable(sigma)
+        self.sigma = self.sd = sigma
+        self.mu = tt.as_tensor_variable(mu)
         self.init = init
         self.mean = tt.as_tensor_variable(0.)
 
+    def _mu_and_sigma(self, mu, sigma):
+        """Helper to get mu and sigma if they are high dimensional."""
+        if sigma.ndim > 0:
+            sigma = sigma[1:]
+        if mu.ndim > 0:
+            mu = mu[1:]
+        return mu, sigma
+
     def logp(self, x):
         """
         Calculate log-probability of Gaussian Random Walk distribution at specified value.
@@ -206,15 +225,45 @@ def logp(self, x):
         -------
         TensorVariable
         """
-        sigma = self.sigma
-        mu = self.mu
-        init = self.init
+        if x.ndim > 0:
+            x_im1 = x[:-1]
+            x_i = x[1:]
+            mu, sigma = self._mu_and_sigma(self.mu, self.sigma)
+            innov_like = Normal.dist(mu=x_im1 + mu, sigma=sigma).logp(x_i)
+            return self.init.logp(x[0]) + tt.sum(innov_like)
+        return self.init.logp(x)
 
-        x_im1 = x[:-1]
-        x_i = x[1:]
+    def random(self, point=None, size=None):
+        """Draw random values from GaussianRandomWalk.
+
+        Parameters
+        ----------
+        point : dict, optional
+            Dict of variable values on which random values are to be
+            conditioned (uses default point if not specified).
+        size : int, optional
+            Desired size of random sample (returns one sample if not
+            specified).
 
-        innov_like = Normal.dist(mu=x_im1 + mu, sigma=sigma).logp(x_i)
-        return init.logp(x[0]) + tt.sum(innov_like)
+        Returns
+        -------
+        array
+        """
+        sigma, mu = distribution.draw_values([self.sigma, self.mu], point=point, size=size)
+        return distribution.generate_samples(self._random, sigma=sigma, mu=mu, size=size,
+                                             dist_shape=self.shape,
+                                             not_broadcast_kwargs={"sample_shape": to_tuple(size)})
+
+    def _random(self, sigma, mu, size, sample_shape):
+        """Implement a Gaussian random walk as a cumulative sum of normals."""
+        if size[len(sample_shape)] == sample_shape:
+            axis = len(sample_shape)
+        else:
+            axis = 0
+        rv = stats.norm(mu, sigma)
+        data = rv.rvs(size).cumsum(axis=axis)
+        data = data - data[0]  # TODO: this should be a draw from `init`, if available
+        return data
 
     def _repr_latex_(self, name=None, dist=None):
         if dist is None:
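The rewritten `logp` keeps the same factorization as the code it replaces: the log-density of the walk is the `init` log-density of the first value plus the sum of Normal innovation terms `Normal.dist(mu=x[:-1] + mu, sigma=sigma).logp(x[1:])`. A quick numerical sketch of the innovation sum with SciPy; the values are illustrative, and with the default `Flat()` init the initial term contributes zero:

```python
import numpy as np
from scipy import stats

x = np.array([0.0, 0.5, 0.3, 1.1])   # an observed walk (illustrative values)
mu, sigma = 0.0, 1.0

# Sum of Normal log-densities of each step given the previous value plus drift.
innov_logp = stats.norm(x[:-1] + mu, sigma).logpdf(x[1:]).sum()
total_logp = 0.0 + innov_logp        # 0.0 stands in for the Flat() term at x[0]
```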

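The new `random`/`_random` pair draws the walk as a cumulative sum of Normal innovations, shifted so the first value is zero (standing in for a draw from `init`, as the TODO notes). A standalone NumPy/SciPy sketch of that construction, outside PyMC3's `draw_values`/`generate_samples` machinery; parameter values and length are illustrative:

```python
import numpy as np
from scipy import stats

mu, sigma, n_steps = 0.0, 1.0, 100
innovations = stats.norm(mu, sigma).rvs(size=n_steps)  # independent Normal steps
walk = innovations.cumsum()                            # running sum = random walk
walk = walk - walk[0]                                  # anchor at zero, like `data - data[0]` above
```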