
Increase support for batched multivariate distributions #5383

Open
@ricardoV94

Description

The logp of several multivariate distributions does not work (or is not tested) for arbitrary batch dimensions. Cases I could confirm include:

Reproducible code:

```python
import numpy as np
import pymc as pm

size = (4, 3)  # batch shape; the core (event) dimension of each draw is 2
pm.logp(pm.MvNormal.dist(mu=np.ones(2), cov=np.eye(2), size=size), np.ones((*size, 2)))
# ValueError: Invalid dimension for value: 3
pm.logp(pm.MvStudentT.dist(nu=3, mu=np.ones(2), cov=np.eye(2), size=size), np.ones((*size, 2)))
# ValueError: Invalid dimension for value: 3
```
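For reference, both calls above should return a log-density with the batch shape `(4, 3)`: the last axis of the value is the core (event) dimension and gets reduced, while every leading axis is a batch axis. Below is a minimal NumPy sketch of that expected behaviour, assuming a single `mu`/`cov` shared across the whole batch. The helpers `mvnormal_logp` and `mvstudentt_logp` are hypothetical reference functions, not PyMC API, and the Student-t helper treats `cov` as the scale matrix Σ (an assumption for illustration).

```python
import math

import numpy as np


def _mahalanobis(value, mu, cov):
    # value: (*batch, k), mu: (k,), cov: (k, k) -> result: (*batch,)
    diff = value - mu
    prec = np.linalg.inv(cov)
    return np.einsum("...i,ij,...j->...", diff, prec, diff)


def mvnormal_logp(value, mu, cov):
    # Hypothetical reference helper: batched MvNormal log-density.
    # Only the last axis is the core dimension; all leading axes are batch axes.
    k = value.shape[-1]
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (k * np.log(2 * np.pi) + logdet + _mahalanobis(value, mu, cov))


def mvstudentt_logp(value, nu, mu, cov):
    # Hypothetical reference helper: batched MvStudentT log-density,
    # ASSUMING `cov` is the scale matrix Sigma (not the covariance).
    k = value.shape[-1]
    _, logdet = np.linalg.slogdet(cov)
    norm = (
        math.lgamma((nu + k) / 2)
        - math.lgamma(nu / 2)
        - 0.5 * k * np.log(nu * np.pi)
        - 0.5 * logdet
    )
    return norm - 0.5 * (nu + k) * np.log1p(_mahalanobis(value, mu, cov) / nu)


size = (4, 3)
mu, cov = np.ones(2), np.eye(2)
value = np.ones((*size, 2))
print(mvnormal_logp(value, mu, cov).shape)  # (4, 3)
print(mvstudentt_logp(value, 3, mu, cov).shape)  # (4, 3)
```

Because `einsum` with an ellipsis broadcasts over any number of leading axes, the same helpers work for `size=(4,)`, `(4, 3)`, or deeper batch shapes without special-casing, which is the behaviour the failing calls are missing.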

Distributions that already support (and have tests for) arbitrary shapes
