Description
The `logp` of several multivariate distributions does not work (or is not tested) for arbitrarily batched dimensions. Cases I could confirm include:
- MvNormal (see #6897, "Allow batched parameters in MvNormal and MvStudentT distributions")
- MvStudentT (see #6897)
- KroneckerNormal
- MatrixNormal
- LKJCorr
- LKJCholeskyCov
- StickBreakingWeights, for batched `alpha` (see #6042, "Allow for batched `alpha` in `StickBreakingWeights`")
- Wishart (may be fine to ignore)
- CARRV (maybe better to leave this one as is)
Reproducible code:
```python
import numpy as np
import pymc as pm

size = (4, 3)
pm.logp(pm.MvNormal.dist(mu=np.ones(2), cov=np.eye(2), size=size), np.ones((*size, 2)))
# ValueError: Invalid dimension for value: 3
pm.logp(pm.MvStudentT.dist(nu=3, mu=np.ones(2), cov=np.eye(2), size=size), np.ones((*size, 2)))
# ValueError: Invalid dimension for value: 3
```
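For reference, the expected behaviour is that a value of shape `(*size, k)` produces a logp of shape `size`, with only the trailing support axis reduced. The NumPy-only sketch below illustrates this for MvNormal. It is not PyMC code, and the function name `batched_mvnormal_logp` is hypothetical. It broadcasts `value`, `mu` and `cov` over their leading batch dimensions and relies on `np.linalg.slogdet` and `np.linalg.inv` operating on stacks of matrices.

```python
import numpy as np


def batched_mvnormal_logp(value, mu, cov):
    """Hypothetical reference: MvNormal log-density over arbitrary batch dims.

    value: (..., k), mu: (..., k), cov: (..., k, k); leading dims broadcast.
    Returns an array with the broadcast batch shape (the trailing k is reduced).
    """
    value = np.asarray(value, dtype=float)
    mu = np.asarray(mu, dtype=float)
    cov = np.asarray(cov, dtype=float)
    k = value.shape[-1]
    delta = value - mu  # broadcasts over batch dims, shape (..., k)
    # Common batch shape of the (value - mu) vectors and the covariance stack
    batch_shape = np.broadcast_shapes(delta.shape[:-1], cov.shape[:-2])
    delta = np.broadcast_to(delta, batch_shape + (k,))
    cov = np.broadcast_to(cov, batch_shape + (k, k))
    # slogdet and inv both accept stacks of matrices with shape (..., k, k)
    _, logdet = np.linalg.slogdet(cov)
    maha = np.einsum("...i,...ij,...j->...", delta, np.linalg.inv(cov), delta)
    return -0.5 * (k * np.log(2 * np.pi) + logdet + maha)


size = (4, 3)
out = batched_mvnormal_logp(np.ones((*size, 2)), mu=np.ones(2), cov=np.eye(2))
print(out.shape)  # (4, 3)
```

Explicit inversion is used here only for readability. A Cholesky-based solve would be the numerically preferable choice in a real implementation.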
Distributions that already support (and have tests for) arbitrary shapes:
- Dirichlet
- Multinomial
- DirichletMultinomial (see #5234, "Tweak DirichletMultinomial logp and refactor some multivariate logp tests")
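Compared with the failing cases, a Dirichlet log-density only reduces over its last (support) axis, so any number of leading batch dimensions passes through unchanged. The NumPy sketch below is a hypothetical reference (`batched_dirichlet_logp` is not a PyMC function) that shows this pattern. It uses a vectorised `math.lgamma` to avoid a SciPy dependency.

```python
import math

import numpy as np

# Vectorised log-gamma built from the standard library (avoids a SciPy dependency).
_lgamma = np.vectorize(math.lgamma, otypes=[float])


def batched_dirichlet_logp(value, a):
    """Hypothetical reference: Dirichlet log-density, reducing only the last axis.

    value: (..., k) points on the simplex, a: (..., k) concentrations.
    Leading dims broadcast; the result has the broadcast batch shape.
    """
    value = np.asarray(value, dtype=float)
    a = np.asarray(a, dtype=float)
    a = np.broadcast_to(a, np.broadcast_shapes(a.shape, value.shape))
    # log normalising constant: lgamma(sum a) - sum lgamma(a)
    norm = _lgamma(a.sum(axis=-1)) - _lgamma(a).sum(axis=-1)
    return norm + ((a - 1.0) * np.log(value)).sum(axis=-1)


size = (4, 3)
res = batched_dirichlet_logp(np.full((*size, 3), 1 / 3), np.ones(3))
print(res.shape)  # (4, 3)
```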