-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add return type overload for sample_posterior_predictive
#7710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add return type overload for sample_posterior_predictive
#7710
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #7710 +/- ##
==========================================
- Coverage 92.67% 92.66% -0.01%
==========================================
Files 107 107
Lines 18329 18333 +4
==========================================
+ Hits 16986 16989 +3
- Misses 1343 1344 +1
🚀 New features to boost your workflow:
|
sample_posterior_predictive
predictions: bool = False, | ||
idata_kwargs: dict | None = None, | ||
compile_kwargs: dict | None = None, | ||
) -> dict[str, np.ndarray]: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we have a MultiTrace type alias?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we could import the class from backends/base. Is the return type actually a MultiTrace
? Or do you want a local alias for MultiTrace
= dict[str, np.ndarray]
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I thought it was a multi-trace, maybe it's a dict indeed. Can you double check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's definitely dict[str, np.ndarray]
. Can confirm with this:
import pymc as pm
import numpy as np
def main():
data = [5.1, 5.2, 4.9, 4.8]
with pm.Model() as model:
target_value = pm.Data(name="target_y", value=data, dims=("x"))
a = pm.Normal("a")
y_hat = pm.Deterministic("y_hat", var=a + 5)
y_like = pm.Normal("y_like", mu=y_hat, observed=target_value)
my_model = model
fit_trace = pm.sample(model=my_model, tune=10, draws=10, chains=4)
print(fit_trace)
print(type(fit_trace))
y_posterior_trace = pm.sample_posterior_predictive(
model=my_model, trace=fit_trace, return_inferencedata=False
)
print(y_posterior_trace)
print(type(y_posterior_trace))
if __name__ == "__main__":
main()
Only 10 samples per chain. Reliable r-hat and ESS diagnostics require longer chains for accurate estimate.
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [a]
Progress Draws Divergences Step size Grad evals Sampling Speed Elapsed Remaining
─────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 20 0 0.84 3 1824.91 draws/s 0:00:00 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 20 0 0.80 3 51.02 draws/s 0:00:00 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 20 0 0.28 1 50.07 draws/s 0:00:00 0:00:00
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 20 0 0.15 7 14.61 draws/s 0:00:01 0:00:00
Sampling 4 chains for 10 tune and 10 draw iterations (40 + 40 draws total) took 9 seconds.
The number of samples is too small to check convergence reliably.
Inference data with groups:
> posterior
> sample_stats
> observed_data
> constant_data
<class 'arviz.data.inference_data.InferenceData'>
Sampling: [y_like]
Sampling ... ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100% 0:00:00 / 0:00:00
{'y_like': array([[[5.47525026, 6.7561118 , 6.75108787, 6.44631066],
[3.08866384, 4.44911318, 2.11949771, 3.58852593],
[5.78696067, 4.4653521 , 4.18029713, 4.25236639],
[5.60767872, 4.8688315 , 3.44183177, 5.38761565],
[4.84709411, 4.66473999, 5.96239288, 5.14296245],
[4.45261563, 5.01725227, 6.03797216, 4.10837762],
[6.46749518, 4.55590472, 6.02580747, 4.45176423],
[3.97070401, 4.67813175, 4.55127384, 5.58474455],
[5.23346119, 5.09860006, 5.40714872, 3.155104 ],
[5.01437042, 4.73579148, 4.69969069, 4.73011345]],
[[6.95786381, 5.44863886, 4.67877711, 6.07620478],
[6.1479291 , 3.63454873, 4.26082608, 5.05364968],
[4.94426014, 3.2264388 , 3.49036617, 3.51758425],
[4.46372822, 4.97982756, 4.71369595, 4.28042535],
[3.67790877, 4.61166178, 4.859494 , 3.90743623],
[5.30186549, 5.51060686, 5.72649511, 5.44831013],
[5.6824159 , 5.00966824, 5.81942202, 7.10113269],
[6.69604693, 6.22185714, 4.66787917, 6.93183407],
[2.69863789, 3.15122392, 3.93177678, 5.56284008],
[5.42486489, 5.06666397, 5.31683066, 3.38231024]],
[[7.55508934, 6.12556656, 5.7270704 , 7.26913077],
[3.99491984, 4.29239014, 5.54759873, 4.23275301],
[2.66913545, 4.87218678, 5.62928131, 5.30816862],
[4.82329652, 5.23746588, 4.49146018, 5.63991739],
[4.92370427, 5.95452363, 6.03808899, 5.94568277],
[5.10762579, 7.12951932, 5.75341415, 5.84696291],
[3.37979901, 3.62339437, 4.65771825, 3.7923768 ],
[3.62867257, 2.94668575, 3.61802996, 4.1612957 ],
[3.05870501, 4.31849566, 4.2617358 , 5.47176507],
[4.20927715, 3.17447563, 3.20495462, 3.75464106]],
[[5.35709662, 6.54309611, 3.40905019, 4.99514321],
[4.92966164, 4.4250744 , 3.70201706, 4.31879397],
[2.78342991, 4.88726609, 6.05864548, 4.551451 ],
[3.37619167, 4.62058509, 5.80385571, 3.6307474 ],
[3.92182737, 5.22549831, 5.04820779, 4.96424299],
[5.59346296, 5.64233644, 6.81086527, 5.35190658],
[5.13364673, 6.42014636, 5.2321678 , 3.9226347 ],
[3.67539437, 3.52974516, 6.81766322, 5.42133207],
[3.65774333, 4.45101107, 4.96830193, 3.83520201],
[4.21230467, 3.95171007, 5.6626775 , 6.54087073]]])}
<class 'dict'>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
add return type overload for sample_posterior_predictive
Description
Added overloads to
sample_prior_predictive
to signal to type checkers the return type of the function, based on thereturn_inferencedata
boolean parameter.Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7710.org.readthedocs.build/en/7710/