Skip to content

Add return type overload for sample_posterior_predictive #7710

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

nataziel
Copy link
Contributor

@nataziel nataziel commented Mar 4, 2025

add return type overload for sample_posterior_predictive

Description

Added overloads to sample_prior_predictive to signal to type checkers the return type of the function, based on the return_inferencedata boolean parameter.

Related Issue

  • Closes #
  • Related to #

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks)
  • If you are a pro: each commit corresponds to a [relevant logical change]

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7710.org.readthedocs.build/en/7710/

Copy link

codecov bot commented Mar 4, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.66%. Comparing base (cc90212) to head (76c3ff0).
Report is 8 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7710      +/-   ##
==========================================
- Coverage   92.67%   92.66%   -0.01%     
==========================================
  Files         107      107              
  Lines       18329    18333       +4     
==========================================
+ Hits        16986    16989       +3     
- Misses       1343     1344       +1     
Files with missing lines Coverage Δ
pymc/sampling/forward.py 96.38% <100.00%> (+0.05%) ⬆️

... and 1 file with indirect coverage changes

🚀 New features to boost your workflow:
  • Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 changed the title add return type overload for sample_posterior_predictive Add return type overload for sample_posterior_predictive Mar 4, 2025
@ricardoV94 ricardoV94 changed the title Add return type overload for sample_posterior_predictive Add return type overload for sample_posterior_predictive Mar 4, 2025
@ricardoV94 ricardoV94 added the docs label Mar 4, 2025
predictions: bool = False,
idata_kwargs: dict | None = None,
compile_kwargs: dict | None = None,
) -> dict[str, np.ndarray]: ...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a MultiTrace type alias?

Copy link
Contributor Author

@nataziel nataziel Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could import the class from backends/base. Is the return type actually a MultiTrace? Or do you want a local alias for MultiTrace = dict[str, np.ndarray]?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it was a multi-trace, maybe it's a dict indeed. Can you double check

Copy link
Contributor Author

@nataziel nataziel Mar 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's definitely dict[str, np.ndarray]. Can confirm with this:

import pymc as pm
import numpy as np


def main():
    data = [5.1, 5.2, 4.9, 4.8]
    with pm.Model() as model:
        target_value = pm.Data(name="target_y", value=data, dims=("x"))
        a = pm.Normal("a")

        y_hat = pm.Deterministic("y_hat", var=a + 5)

        y_like = pm.Normal("y_like", mu=y_hat, observed=target_value)

        my_model = model

    fit_trace = pm.sample(model=my_model, tune=10, draws=10, chains=4)

    print(fit_trace)
    print(type(fit_trace))

    y_posterior_trace = pm.sample_posterior_predictive(
        model=my_model, trace=fit_trace, return_inferencedata=False
    )

    print(y_posterior_trace)
    print(type(y_posterior_trace))


if __name__ == "__main__":
    main()
Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a]

  Progress                                   Draws   Divergences   Step size   Grad evals   Sampling Speed    Elapsed   Remaining
 ─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.84        3            1824.91 draws/s   0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.80        3            51.02 draws/s     0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.28        1            50.07 draws/s     0:00:00   0:00:00
  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━   20      0             0.15        7            14.61 draws/s     0:00:01   0:00:00

Sampling 4 chains for 10 tune and 10 draw iterations (40 + 40 draws total) took 9 seconds.
The number of samples is too small to check convergence reliably.
Inference data with groups:
        > posterior
        > sample_stats
        > observed_data
        > constant_data
<class 'arviz.data.inference_data.InferenceData'>
Sampling: [y_like]
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00
{'y_like': array([[[5.47525026, 6.7561118 , 6.75108787, 6.44631066],
        [3.08866384, 4.44911318, 2.11949771, 3.58852593],
        [5.78696067, 4.4653521 , 4.18029713, 4.25236639],
        [5.60767872, 4.8688315 , 3.44183177, 5.38761565],
        [4.84709411, 4.66473999, 5.96239288, 5.14296245],
        [4.45261563, 5.01725227, 6.03797216, 4.10837762],
        [6.46749518, 4.55590472, 6.02580747, 4.45176423],
        [3.97070401, 4.67813175, 4.55127384, 5.58474455],
        [5.23346119, 5.09860006, 5.40714872, 3.155104  ],
        [5.01437042, 4.73579148, 4.69969069, 4.73011345]],

       [[6.95786381, 5.44863886, 4.67877711, 6.07620478],
        [6.1479291 , 3.63454873, 4.26082608, 5.05364968],
        [4.94426014, 3.2264388 , 3.49036617, 3.51758425],
        [4.46372822, 4.97982756, 4.71369595, 4.28042535],
        [3.67790877, 4.61166178, 4.859494  , 3.90743623],
        [5.30186549, 5.51060686, 5.72649511, 5.44831013],
        [5.6824159 , 5.00966824, 5.81942202, 7.10113269],
        [6.69604693, 6.22185714, 4.66787917, 6.93183407],
        [2.69863789, 3.15122392, 3.93177678, 5.56284008],
        [5.42486489, 5.06666397, 5.31683066, 3.38231024]],

       [[7.55508934, 6.12556656, 5.7270704 , 7.26913077],
        [3.99491984, 4.29239014, 5.54759873, 4.23275301],
        [2.66913545, 4.87218678, 5.62928131, 5.30816862],
        [4.82329652, 5.23746588, 4.49146018, 5.63991739],
        [4.92370427, 5.95452363, 6.03808899, 5.94568277],
        [5.10762579, 7.12951932, 5.75341415, 5.84696291],
        [3.37979901, 3.62339437, 4.65771825, 3.7923768 ],
        [3.62867257, 2.94668575, 3.61802996, 4.1612957 ],
        [3.05870501, 4.31849566, 4.2617358 , 5.47176507],
        [4.20927715, 3.17447563, 3.20495462, 3.75464106]],

       [[5.35709662, 6.54309611, 3.40905019, 4.99514321],
        [4.92966164, 4.4250744 , 3.70201706, 4.31879397],
        [2.78342991, 4.88726609, 6.05864548, 4.551451  ],
        [3.37619167, 4.62058509, 5.80385571, 3.6307474 ],
        [3.92182737, 5.22549831, 5.04820779, 4.96424299],
        [5.59346296, 5.64233644, 6.81086527, 5.35190658],
        [5.13364673, 6.42014636, 5.2321678 , 3.9226347 ],
        [3.67539437, 3.52974516, 6.81766322, 5.42133207],
        [3.65774333, 4.45101107, 4.96830193, 3.83520201],
        [4.21230467, 3.95171007, 5.6626775 , 6.54087073]]])}
<class 'dict'>

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

@ricardoV94 ricardoV94 merged commit ba018b7 into pymc-devs:main Mar 9, 2025
26 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants