BUG: pm.sample_prior_predictive reports it will sample all volatile basic RVs even when var_names is supplied #7703

Closed
@nataziel

Describe the issue:

When you pass a list of variable names to pm.sample_prior_predictive via var_names, the logged "Sampling: ..." message lists more variables than the ones you requested.

The function appears to constrain the output variables passed to compile_forward_sampling_function correctly, and the returned data contains only the requested variables, but the logged message suggests that all volatile basic RVs will be sampled.

Seems to be caused by this line:

_log.info(f"Sampling: {sorted(volatile_basic_rvs, key=lambda var: var.name)}") # type: ignore[arg-type, return-value]
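For illustration, here is a minimal sketch of a message that lists the requested variables separately from any other variables that get drawn. It is not PyMC code: `Var` and `format_sampling_message` are hypothetical stand-ins, and it assumes the extra names in the log are variables that are drawn in addition to the requested ones.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a model variable; only its name is used."""

    name: str


def format_sampling_message(volatile_basic_rvs, var_names=None):
    """Return a log line that lists requested variables separately from the rest."""
    sampled = sorted(var.name for var in volatile_basic_rvs)
    if var_names is None:
        # No filter requested: report every name, like the line quoted above.
        return f"Sampling: [{', '.join(sampled)}]"

    requested_set = set(var_names)
    requested = [name for name in sampled if name in requested_set]
    others = [name for name in sampled if name not in requested_set]

    message = f"Sampling: [{', '.join(requested)}]"
    if others:
        # Names that are drawn but were not asked for in var_names.
        message += f" (also drawing: [{', '.join(others)}])"
    return message


rvs = {Var("a"), Var("y_like")}
print(format_sampling_message(rvs, var_names=["y_like"]))
print(format_sampling_message(rvs))
```

With the two variables from the example in this report, the first call separates y_like from a, while the second call reproduces the message format that is logged today.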

Reproducible code example:

import pymc as pm


def main():
    data = [5.1, 5.2, 4.9, 4.8]
    with pm.Model() as model:
        # Observed data registered on the model with a named dimension
        target_value = pm.Data(name="target_y", value=data, dims=("x",))
        a = pm.Normal("a")

        y_hat = pm.Deterministic("y_hat", var=a + 5)

        y_like = pm.Normal("y_like", mu=y_hat, observed=target_value)

    # Request only y_like; the log line still reports [a, y_like]
    prior_trace = pm.sample_prior_predictive(model=model, var_names=["y_like"])

    print(prior_trace)


if __name__ == "__main__":
    main()

Error message:

Sampling: [a, y_like]
Inference data with groups:
        > prior_predictive
        > observed_data
        > constant_data

PyMC version information:

Currently running PyMC 5.21

Context for the issue:

Low priority, but I wanted to flag it (or be told that my understanding is incorrect).
