Skip to content

Probability distributions guide update #7671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions docs/source/guides/Probability_Distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,32 @@ A variable requires at least a ``name`` argument, and zero or more model paramet

Probability distributions are all subclasses of ``Distribution``, which in turn has two major subclasses: ``Discrete`` and ``Continuous``. In terms of data types, a ``Continuous`` random variable is given whichever floating point type is defined by ``pytensor.config.floatX``, while ``Discrete`` variables are given ``int16`` types when ``pytensor.config.floatX`` is ``float32``, and ``int64`` otherwise.

All distributions in ``pm.distributions`` will have two important methods: ``random()`` and ``logp()`` with the following signatures:
All distributions in ``pm.distributions`` are associated with two key functions:

1. ``logp(dist, value)`` - Calculates log-probability at given value
2. ``draw(dist, size=...)`` - Generates random samples

For example, with a normal distribution:

::

class SomeDistribution(Continuous):
with pm.Model():
x = pm.Normal('x', mu=0, sigma=1)

# Calculate log-probability
log_prob = pm.logp(x, 0.5)

# Generate samples
samples = pm.draw(x, size=100)

def random(self, point=None, size=None):
...
return random_samples
Custom distributions using ``CustomDist`` should provide logp via the ``dist`` parameter:

::

def logp(self, value):
...
return total_log_prob
def custom_logp(value, mu):
return -0.5 * (value - mu)**2

PyMC expects the ``logp()`` method to return a log-probability evaluated at the passed ``value`` argument. This method is used internally by all of the inference methods to calculate the model log-probability that is used for fitting models. The ``random()`` method is used to simulate values from the variable, and is used internally for posterior predictive checks.
custom_dist = pm.CustomDist('custom', dist=custom_logp, mu=0)


Custom distributions
Expand All @@ -58,7 +69,7 @@ An exponential survival function, where :math:`c=0` denotes failure (or non-surv
f(c, t) = \left\{ \begin{array}{l} \exp(-\lambda t), \text{if c=1} \\
\lambda \exp(-\lambda t), \text{if c=0} \end{array} \right.

Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``DensityDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.
Such a function can be implemented as a PyMC distribution by writing a function that specifies the log-probability, then passing that function as a keyword argument to the ``CustomDist`` function, which creates an instance of a PyMC distribution with the custom function as its log-probability.

For the exponential survival function, this is:

Expand All @@ -67,7 +78,7 @@ For the exponential survival function, this is:
def logp(value, t, lam):
return (value * log(lam) - lam * t).sum()

exp_surv = pm.DensityDist('exp_surv', t, lam, logp=logp, observed=failure)
exp_surv = pm.CustomDist('exp_surv', dist=logp, t=t, lam=lam, observed=failure)

Similarly, if a random number generator is required, a function returning random numbers corresponding to the probability distribution can be passed as the ``random`` argument.

Expand Down Expand Up @@ -98,10 +109,10 @@ This allows for probabilities to be calculated and random numbers to be drawn.

::

>>> y.logp(4).eval()
>>> pm.logp(y, 4).eval()
array(-1.5843639373779297, dtype=float32)

>>> y.random(size=3)
>>> pm.draw(y, size=3)
array([5, 4, 3])


Expand Down