Skip to content

Commit 55394ba

Browse files
committed
Additional edits
1 parent c267885 commit 55394ba

File tree

2 files changed

+474
-304
lines changed

2 files changed

+474
-304
lines changed

examples/time_series/Euler-Maruyama_and_SDEs.ipynb

Lines changed: 346 additions & 277 deletions
Large diffs are not rendered by default.

examples/time_series/Euler-Maruyama_and_SDEs.myst.md

Lines changed: 128 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,12 @@ run_control:
3232
slideshow:
3333
slide_type: '-'
3434
---
35-
import warnings
36-
3735
import arviz as az
3836
import matplotlib.pyplot as plt
3937
import numpy as np
4038
import pymc as pm
4139
import pytensor.tensor as pt
4240
import scipy as sp
43-
44-
# Ignore UserWarnings
45-
warnings.filterwarnings("ignore", category=UserWarning)
46-
47-
RANDOM_SEED = 8927
48-
np.random.seed(RANDOM_SEED)
4941
```
5042

5143
```{code-cell} ipython3
@@ -104,19 +96,16 @@ run_control:
10496
slideshow:
10597
slide_type: subslide
10698
---
107-
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))
108-
109-
ax1.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
110-
ax1.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
111-
ax1.set_title("Transient")
112-
ax1.legend()
113-
114-
ax2.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
115-
ax2.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
116-
ax2.set_title("All time")
117-
ax2.legend()
118-
119-
plt.tight_layout()
99+
plt.figure(figsize=(10, 3))
100+
plt.plot(x_t[:30], "k", label="$x(t)$", alpha=0.5)
101+
plt.plot(z_t[:30], "r", label="$z(t)$", alpha=0.5)
102+
plt.title("Transient")
103+
plt.legend()
104+
plt.subplot(122)
105+
plt.plot(x_t[30:], "k", label="$x(t)$", alpha=0.5)
106+
plt.plot(z_t[30:], "r", label="$z(t)$", alpha=0.5)
107+
plt.title("All time")
108+
plt.legend();
120109
```
121110

122111
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
@@ -134,7 +123,7 @@ new_sheet: false
134123
run_control:
135124
read_only: false
136125
---
137-
def lin_sde(x, lam, s2):
126+
def lin_sde(x, lam):
138127
return lam * x, s2
139128
```
140129

@@ -155,12 +144,11 @@ slideshow:
155144
---
156145
with pm.Model() as model:
157146
# uniform prior, but we know it must be negative
158-
l = pm.HalfCauchy("l", beta=1)
159-
s = pm.Uniform("s", 0.005, 0.5)
147+
l = pm.Flat("l")
160148
161149
# "hidden states" following a linear SDE distribution
162150
# parametrized by time step (det. variable) and lam (random variable)
163-
xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(-l, s**2), shape=N, initval=x_t)
151+
xh = pm.EulerMaruyama("xh", dt=dt, sde_fn=lin_sde, sde_pars=(l,), shape=N)
164152
165153
# predicted observation
166154
zh = pm.Normal("zh", mu=xh, sigma=5e-3, observed=z_t)
@@ -178,7 +166,7 @@ run_control:
178166
read_only: false
179167
---
180168
with model:
181-
trace = pm.sample(nuts_sampler="nutpie", random_seed=RANDOM_SEED, target_accept=0.99)
169+
trace = pm.sample()
182170
```
183171

184172
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
@@ -197,7 +185,7 @@ plt.plot(x_t, "r", label="$x(t)$")
197185
plt.legend()
198186
199187
plt.subplot(122)
200-
plt.hist(-1 * az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
188+
plt.hist(az.extract(trace.posterior)["l"], 30, label=r"$\hat{\lambda}$", alpha=0.5)
201189
plt.axvline(lam, color="r", label=r"$\lambda$", alpha=0.5)
202190
plt.legend();
203191
```
@@ -230,6 +218,119 @@ plt.legend();
230218

231219
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
232220

221+
Note that
222+
223+
- inference also estimates the initial conditions
224+
- the observed data $z(t)$ lies fully within the 95% interval of the PPC.
225+
- there are many other ways of evaluating fit
226+
227+
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "slide"}}
228+
229+
### Toy model 2
230+
231+
As the next model, let's use a 2D deterministic oscillator,
232+
\begin{align}
233+
\dot{x} &= \tau (x - x^3/3 + y) \\
234+
\dot{y} &= \frac{1}{\tau} (a - x)
235+
\end{align}
236+
237+
with noisy observation $z(t) = m x + (1 - m) y + N(0, 0.05)$.
238+
239+
```{code-cell} ipython3
240+
N, tau, a, m, s2 = 200, 3.0, 1.05, 0.2, 1e-1
241+
xs, ys = [0.0], [1.0]
242+
for i in range(N):
243+
x, y = xs[-1], ys[-1]
244+
dx = tau * (x - x**3.0 / 3.0 + y)
245+
dy = (1.0 / tau) * (a - x)
246+
xs.append(x + dt * dx + np.sqrt(dt) * s2 * np.random.randn())
247+
ys.append(y + dt * dy + np.sqrt(dt) * s2 * np.random.randn())
248+
xs, ys = np.array(xs), np.array(ys)
249+
zs = m * xs + (1 - m) * ys + np.random.randn(xs.size) * 0.1
250+
251+
plt.figure(figsize=(10, 2))
252+
plt.plot(xs, label="$x(t)$")
253+
plt.plot(ys, label="$y(t)$")
254+
plt.plot(zs, label="$z(t)$")
255+
plt.legend()
256+
```
257+
258+
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
259+
260+
Now, estimate the hidden states $x(t)$ and $y(t)$, as well as parameters $\tau$, $a$ and $m$.
261+
262+
As before, we rewrite our SDE as a function returning drift & diffusion coefficients:
263+
264+
```{code-cell} ipython3
265+
---
266+
button: false
267+
new_sheet: false
268+
run_control:
269+
read_only: false
270+
---
271+
def osc_sde(xy, tau, a):
272+
x, y = xy[:, 0], xy[:, 1]
273+
dx = tau * (x - x**3.0 / 3.0 + y)
274+
dy = (1.0 / tau) * (a - x)
275+
dxy = pt.stack([dx, dy], axis=0).T
276+
return dxy, s2
277+
```
278+
279+
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
280+
281+
As before, the Euler-Maruyama discretization of the SDE is written as a prediction of the state at step $i+1$ based on the state at step $i$.
282+
283+
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}, "slideshow": {"slide_type": "subslide"}}
284+
285+
We can now write our statistical model as before, with uninformative priors on $\tau$, $a$ and $m$:
286+
287+
```{code-cell} ipython3
288+
---
289+
button: false
290+
new_sheet: false
291+
run_control:
292+
read_only: false
293+
---
294+
xys = np.c_[xs, ys]
295+
296+
with pm.Model() as model:
297+
tau_h = pm.Uniform("tau_h", lower=0.1, upper=5.0)
298+
a_h = pm.Uniform("a_h", lower=0.5, upper=1.5)
299+
m_h = pm.Uniform("m_h", lower=0.0, upper=1.0)
300+
xy_h = pm.EulerMaruyama(
301+
"xy_h", dt=dt, sde_fn=osc_sde, sde_pars=(tau_h, a_h), shape=xys.shape, initval=xys
302+
)
303+
zh = pm.Normal("zh", mu=m_h * xy_h[:, 0] + (1 - m_h) * xy_h[:, 1], sigma=0.1, observed=zs)
304+
```
305+
306+
```{code-cell} ipython3
307+
pm.__version__
308+
```
309+
310+
```{code-cell} ipython3
311+
---
312+
button: false
313+
new_sheet: false
314+
run_control:
315+
read_only: false
316+
---
317+
with model:
318+
pm.sample_posterior_predictive(trace, extend_inferencedata=True)
319+
```
320+
321+
```{code-cell} ipython3
322+
plt.figure(figsize=(10, 3))
323+
plt.plot(
324+
trace.posterior_predictive.quantile((0.025, 0.975), dim=("chain", "draw"))["zh"].values.T,
325+
"k",
326+
label=r"$z_{95\% PP}(t)$",
327+
)
328+
plt.plot(z_t, "r", label="$z(t)$")
329+
plt.legend();
330+
```
331+
332+
+++ {"button": false, "new_sheet": false, "run_control": {"read_only": false}}
333+
233334
Note that the initial conditions are also estimated, and that most of the observed data $z(t)$ lies within the 95% interval of the PPC.
234335

235336
Another approach is to look at draws from the sampling distribution of the data relative to the observed data. This too shows a good fit across the range of observations -- the posterior predictive mean almost perfectly tracks the data.

0 commit comments

Comments
 (0)