Marginalization is reset by freeze_dims_and_data  #383

Closed
@jessegrabowski


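I wasn't sure which repo this belongs in. If you marginalize a discrete variable with MarginalModel and then call freeze_dims_and_data, the marginalization is undone: the Bernoulli variable idx reappears in the logp graph of the frozen model, and logp() raises a ValueError: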

import pymc as pm
from pymc_experimental import MarginalModel
from pymc.model.transform.optimization import freeze_dims_and_data
import pytensor.tensor as pt

with MarginalModel() as m:
    p = pm.Beta('p', 1, 1)
    idx = pm.Bernoulli('idx', p=p, size=(100,))
    mu = pm.Normal('mu', 0, [1, 100])
    # idx == 0 selects mu[0], idx == 1 selects mu[1]
    x = pm.Normal('x', pm.math.switch(pt.eq(idx, 0), mu[0], mu[1]), 1)

m.marginalize(['idx'])
pm.inputvars(m.logp())   # [p_logodds__, mu, x]

pm.inputvars(freeze_dims_and_data(m).logp())  # Raises ValueError: Random variables detected in the logp graph
Full Traceback
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[19], line 1
----> 1 pm.inputvars(freeze_dims_and_data(m).logp())

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
    740 rv_logps: list[TensorVariable] = []
    741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
    743         rvs=rvs,
    744         rvs_to_values=self.rvs_to_values,
    745         rvs_to_transforms=self.rvs_to_transforms,
    746         jacobian=jacobian,
    747     )
    748     assert isinstance(rv_logps, list)
    750 # Replace random variables by their value variables in potential terms

File ~/mambaforge/envs/readystate-bonds/lib/python3.11/site-packages/pymc/logprob/basic.py:630, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    628 rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logp_terms_list)
    629 if rvs_in_logp_expressions:
--> 630     raise ValueError(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions)
    632 return logp_terms_list

ValueError: Random variables detected in the logp graph: {bernoulli_rv{"()->()"}.out}.
This can happen when DensityDist logp or Interval transform functions reference nonlocal variables,
or when not all rvs have a corresponding value variable.
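For reference, this is the term the marginalized model should contribute per element of x. Summing idx ~ Bernoulli(p) out of the model leaves a two-component mixture, log[(1 - p) N(x; mu[0], 1) + p N(x; mu[1], 1)], and the frozen model's logp would be expected to keep a term of this shape instead of a free bernoulli_rv node. The pure-Python sketch below only writes out that math under the assumption 0 < p < 1; it is not how MarginalModel builds its graph, and the helper names (normal_logpdf, marginal_logp_x, joint_logp_x) are hypothetical.

```python
import math
from collections.abc import Iterable


def normal_logpdf(x: float, mu: float, sigma: float = 1.0) -> float:
    """Log density of Normal(mu, sigma) evaluated at x."""
    z = (x - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)


def logsumexp2(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def marginal_logp_x(x: float, p: float, mu0: float, mu1: float) -> float:
    """log p(x | p, mu) with idx ~ Bernoulli(p) summed out (hypothetical helper).

    idx == 0 (probability 1 - p) selects mu0; idx == 1 (probability p) selects mu1.
    Assumes 0 < p < 1 so both log-weights are finite.
    """
    return logsumexp2(
        math.log1p(-p) + normal_logpdf(x, mu0),
        math.log(p) + normal_logpdf(x, mu1),
    )


def joint_logp_x(xs: Iterable[float], p: float, mu0: float, mu1: float) -> float:
    """Sum of the per-element marginal terms; the elements of idx are independent given p."""
    return sum(marginal_logp_x(x, p, mu0, mu1) for x in xs)
```

With p = 0.5 and mu0 == mu1 the two components are identical, so marginal_logp_x reduces to normal_logpdf(x, mu0), which is a quick sanity check on the formula.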
