
Refactor old Distribution base class #5308

Open
@ricardoV94


PyMC distribution classes are weird objects. They bundle RandomVariables together with logp, logcdf and moment methods (basically doing runtime dispatching). They also manage most of the non-RandomVariable kwargs that users are familiar with (observed, transformed, size/dims), as well as behind-the-scenes actions such as registering the variable in the model.

This design exists mostly for backwards compatibility with V3 and to ease developer refactoring, but the current result is far from pretty.

We need to figure out a more elegant, permanent architecture now that many of the V3 limitations this design was built to accommodate no longer hold.

Distribution

Distribution is currently performing the following tasks:

`class Distribution(metaclass=DistributionMeta):`

  1. Validating inputs:
    1. Raising a FutureWarning for the deprecated testval kwarg
    2. Raising a TypeError when a distribution is initialized outside of a Model context
    3. Raising a TypeError when no name is given to a distribution
    4. Raising a ValueError when more than one of dims/shape/size is given
  2. Converting alternative parametrizations to the standard parametrization (e.g., tau -> sigma). This is done by the .dist methods.
  3. Adding informative AttributeErrors for the deprecated logp, logcdf and random methods
  4. Resizing the final RV based on observed, shape, dims or size
  5. Providing the .dist() API, which creates an unnamed RV that is not registered in the model. Such variables are needed in Potentials and in distribution factories that use RVs as building blocks, such as Bound and Censored, as well as Mixtures and Timeseries once they are refactored for V4
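To make these responsibilities concrete, here is a toy, PyMC-free sketch of how one `__new__`/`.dist()` pair could implement the validation, reparametrization and registration steps listed above. Everything in it is hypothetical: `Model`, `ToyRV` and the simplified `Distribution`/`Normal` classes only mimic the described behavior. Resizing (task 4) and the deprecated-method AttributeErrors (task 3) are left out, and the real PyMC code differs in detail.

```python
import warnings
from dataclasses import dataclass
from typing import Optional


@dataclass
class ToyRV:
    """Hypothetical stand-in for the output variable of a RandomVariable."""
    op: str
    params: dict
    name: Optional[str] = None


class Model:
    """Hypothetical minimal model context that records named variables."""

    _stack = []

    def __init__(self):
        self.named_vars = {}

    def __enter__(self):
        Model._stack.append(self)
        return self

    def __exit__(self, *exc_info):
        Model._stack.pop()
        return False

    @classmethod
    def get_context(cls):
        return cls._stack[-1] if cls._stack else None


class Distribution:
    def __new__(cls, name, *args, testval=None, dims=None, shape=None, size=None, **kwargs):
        # 1.1 Warn about the deprecated kwarg
        if testval is not None:
            warnings.warn("testval is deprecated, use initval instead", FutureWarning)
        # 1.2 Named variables must be created inside a model context
        model = Model.get_context()
        if model is None:
            raise TypeError("No model on context stack")
        # 1.3 The name must be a string
        if not isinstance(name, str):
            raise TypeError(f"Name needs to be a string but got: {name!r}")
        # 1.4 At most one of dims/shape/size (no resizing is done in this toy)
        if sum(arg is not None for arg in (dims, shape, size)) > 1:
            raise ValueError("Pass at most one of dims, shape or size")
        # 5. Build an unnamed RV with .dist(), then name and register it
        rv = cls.dist(*args, **kwargs)
        rv.name = name
        model.named_vars[name] = rv
        return rv

    @classmethod
    def dist(cls, *args, **kwargs):
        raise NotImplementedError


class Normal(Distribution):
    @classmethod
    def dist(cls, mu=0.0, sigma=None, tau=None):
        # 2. Convert the precision parametrization (tau) to sigma
        if sigma is not None and tau is not None:
            raise ValueError("Pass either sigma or tau, not both")
        if sigma is None:
            sigma = 1.0 if tau is None else tau ** -0.5
        return ToyRV(op="normal", params={"mu": mu, "sigma": sigma})
```

Calling `Normal("x", mu=1.0, tau=4.0)` inside `with Model():` returns a named RV registered in that model, while `Normal.dist(mu=0.0)` returns an unnamed, unregistered RV that other factories could use as a building block.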

DistributionMeta

In addition, there is a DistributionMeta metaclass that does the following:

`class DistributionMeta(ABCMeta):`

  1. Dispatches the logp, logcdf, moment and default_transform methods defined in the old-style PyMC distribution classes to the respective rv_op
  2. Registers the rv_op type as a subclass of the old-style PyMC distribution class, apparently so that V3 Discrete/Continuous subclass checks still work:
`isinstance(pm.Normal.dist().owner.op, pm.distributions.Continuous)  # True`
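Both metaclass tasks can be illustrated with a standalone sketch that uses only the standard library. `NormalRV`, the `logprob` dispatcher and the simplified classes are hypothetical stand-ins, not PyMC or Aesara objects. The sketch assumes `logp` is written as a plain function without `self`. It shows (1) registering a class-level `logp` against the type of `rv_op` and (2) using `ABCMeta.register` so that the op instance passes `isinstance` checks against the distribution class and its parents.

```python
import math
from abc import ABCMeta
from functools import singledispatch


class NormalRV:
    """Hypothetical stand-in for an Aesara RandomVariable op type."""


@singledispatch
def logprob(op, value, *params):
    """Hypothetical global dispatcher, keyed on the type of the op."""
    raise NotImplementedError(f"No logp registered for {type(op).__name__}")


class DistributionMeta(ABCMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        cls = super().__new__(mcs, name, bases, namespace, **kwargs)
        rv_op = namespace.get("rv_op")
        if rv_op is not None:
            rv_type = type(rv_op)
            class_logp = namespace.get("logp")
            if class_logp is not None:
                # 1. Dispatch the class-level logp to the type of rv_op.
                @logprob.register(rv_type)
                def _logp(op, value, *params):
                    return class_logp(value, *params)

            # 2. Make rv_type a virtual subclass of this class; ABCMeta's
            #    recursive subclass check then also satisfies parent classes.
            cls.register(rv_type)
        return cls


class Distribution(metaclass=DistributionMeta):
    pass


class Continuous(Distribution):
    pass


class Normal(Continuous):
    rv_op = NormalRV()

    # Plain function without `self` (an assumption of this sketch).
    def logp(value, mu, sigma):
        z = (value - mu) / sigma
        return -0.5 * z**2 - math.log(sigma) - 0.5 * math.log(2 * math.pi)
```

With these definitions, `isinstance(Normal.rv_op, Continuous)` is `True` and `logprob(Normal.rv_op, 0.0, 0.0, 1.0)` evaluates the standard normal log-density at zero. The dispatch only happens as a side effect of defining a class, which is the indirection this issue proposes to remove.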

If we want to get rid of Distribution, we probably need to statically dispatch our methods to the respective rv_op. There is nothing special about that; it is how we have done it in aeppl from the get-go: https://github.com/aesara-devs/aeppl/blob/38d0c2ea4ecf8505f85317047089ab9999d2f78e/aeppl/logprob.py#L104-L130
