Commit faebc60

Update docstrings of pm.set_data and model.Data (#6087)
* To explain how to avoid shape errors when doing posterior predictive sampling
* Rewrite docstring for pm.set_data, fix other comments.
1 parent 2f0af64 commit faebc60

2 files changed: +55 -17 lines changed


pymc/data.py

Lines changed: 6 additions & 1 deletion
@@ -592,7 +592,12 @@ def Data(
     :func:`pymc.set_data`.
 
     To set the value of the data container variable, check out
-    :func:`pymc.Model.set_data`.
+    :meth:`pymc.Model.set_data`.
+
+    When making predictions or doing posterior predictive sampling, the shape of the
+    registered data variable will most likely need to be changed. If you encounter an
+    Aesara shape mismatch error, refer to the documentation for
+    :meth:`pymc.model.set_data`.
 
     For more information, read the notebook :ref:`nb:data_container`.
 

pymc/model.py

Lines changed: 49 additions & 16 deletions
@@ -1847,7 +1847,9 @@ def point_logps(self, point=None, round_vals=2):
 
 
 def set_data(new_data, model=None, *, coords=None):
-    """Sets the value of one or more data container variables.
+    """Sets the value of one or more data container variables. Note that the shape is also
+    dynamic, it is updated when the value is changed. See the examples below for two common
+    use-cases that take advantage of this behavior.
 
     Parameters
     ----------
@@ -1860,25 +1862,56 @@ def set_data(new_data, model=None, *, coords=None):
     Examples
     --------
 
-    .. code:: ipython
+    This example shows how to change the shape of the likelihood to correspond automatically with
+    `x`, the predictor in a regression model.
 
-        >>> import pymc as pm
-        >>> with pm.Model() as model:
-        ...     x = pm.MutableData('x', [1., 2., 3.])
-        ...     y = pm.MutableData('y', [1., 2., 3.])
-        ...     beta = pm.Normal('beta', 0, 1)
-        ...     obs = pm.Normal('obs', x * beta, 1, observed=y)
-        ...     idata = pm.sample(1000, tune=1000)
+    .. code-block:: python
+
+        import pymc as pm
+
+        with pm.Model() as model:
+            x = pm.MutableData('x', [1., 2., 3.])
+            y = pm.MutableData('y', [1., 2., 3.])
+            beta = pm.Normal('beta', 0, 1)
+            obs = pm.Normal('obs', x * beta, 1, observed=y, shape=x.shape)
+            idata = pm.sample()
+
+    Then change the value of `x` to predict on new data.
+
+    .. code-block:: python
+
+        with model:
+            pm.set_data({'x': [5., 6., 9., 12., 15.]})
+            y_test = pm.sample_posterior_predictive(idata)
+
+        print(y_test.posterior_predictive['obs'].mean(('chain', 'draw')))
 
-    Set the value of `x` to predict on new data.
+        >>> array([4.6088569 , 5.54128318, 8.32953844, 11.14044852, 13.94178173])
 
-    .. code:: ipython
+    This example shows how to reuse the same model without recompiling on a new data set. The
+    shape of the likelihood, `obs`, automatically tracks the shape of the observed data, `y`.
+
+    .. code-block:: python
+
+        import numpy as np
+        import pymc as pm
+
+        rng = np.random.default_rng()
+        data = rng.normal(loc=1.0, scale=2.0, size=100)
+
+        with pm.Model() as model:
+            y = pm.MutableData('y', data)
+            theta = pm.Normal('theta', mu=0.0, sigma=10.0)
+            obs = pm.Normal('obs', theta, 2.0, observed=y, shape=y.shape)
+            idata = pm.sample()
+
+    Now update the model with a new data set.
+
+    .. code-block:: python
 
-        >>> with model:
-        ...     pm.set_data({'x': [5., 6., 9.]})
-        ...     y_test = pm.sample_posterior_predictive(idata)
-        >>> y_test.posterior_predictive['obs'].mean(('chain', 'draw'))
-        array([4.6088569 , 5.54128318, 8.32953844])
+        with model:
+            pm.set_data({'y': rng.normal(loc=1.0, scale=2.0, size=200)})
+            idata = pm.sample()
     """
     model = modelcontext(model)
 

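Note on the shape mismatch the new docstrings describe: the failure mode can be illustrated without PyMC. The block below is a minimal NumPy-only sketch of the idea, not PyMC's or Aesara's actual machinery. The helper `draw_obs`, the fixed `beta` value, and the arrays `x_train` and `x_new` are hypothetical names introduced for illustration. It assumes the problem amounts to a mean vector of one length being drawn into an output whose shape was fixed at another length, which is roughly what `shape=x.shape` in the updated example is meant to avoid.

```python
import numpy as np

rng = np.random.default_rng(0)

beta = 0.9  # hypothetical stand-in for a single posterior draw of `beta`
x_train = np.array([1.0, 2.0, 3.0])
x_new = np.array([5.0, 6.0, 9.0, 12.0, 15.0])


def draw_obs(x, size):
    """Draw from Normal(x * beta, 1) with an explicitly requested output shape."""
    return rng.normal(loc=x * beta, scale=1.0, size=size)


# Output shape frozen at the training length: the length-5 mean vector
# cannot broadcast to the requested shape (3,), so NumPy raises ValueError.
try:
    draw_obs(x_new, size=x_train.shape)
    frozen_ok = True
except ValueError as err:
    frozen_ok = False
    print("fixed shape failed:", err)

# Output shape tied to the predictor: it follows whatever `x` currently holds.
y_new = draw_obs(x_new, size=x_new.shape)
print(y_new.shape)
```

Tying the requested shape to the predictor plays the same role here as passing `shape=x.shape` to the likelihood in the docstring example: after the data are swapped, the output length follows the new input instead of the length seen at model-definition time.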