1
1
import warnings
2
- from typing import Sequence , Tuple , Union
2
+ from typing import Sequence
3
3
4
4
import numpy as np
5
5
import pymc
6
6
import pytensor .tensor as pt
7
- from arviz import dict_to_dataset
7
+ from arviz import InferenceData , dict_to_dataset
8
8
from pymc import SymbolicRandomVariable
9
9
from pymc .backends .arviz import coords_and_dims_for_inferencedata , dataset_to_point_list
10
10
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
16
from pymc .pytensorf import compile_pymc , constant_fold
17
- from pymc .util import _get_seeds_per_chain , treedict
17
+ from pymc .util import RandomState , _get_seeds_per_chain , treedict
18
18
from pytensor import Mode , scan
19
19
from pytensor .compile import SharedVariable
20
20
from pytensor .graph import Constant , FunctionGraph , ancestors , clone_replace
@@ -235,7 +235,7 @@ def clone(self):
235
235
236
236
def marginalize (
237
237
self ,
238
- rvs_to_marginalize : Union [ TensorVariable , str , Sequence [TensorVariable ], Sequence [str ] ],
238
+ rvs_to_marginalize : TensorVariable | Sequence [TensorVariable ] | str | Sequence [str ],
239
239
):
240
240
if not isinstance (rvs_to_marginalize , Sequence ):
241
241
rvs_to_marginalize = (rvs_to_marginalize ,)
@@ -292,7 +292,7 @@ def _to_transformed(self):
292
292
fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
293
293
return fn , transformed_names
294
294
295
- def unmarginalize (self , rvs_to_unmarginalize ):
295
+ def unmarginalize (self , rvs_to_unmarginalize : Sequence [ TensorVariable ] ):
296
296
for rv in rvs_to_unmarginalize :
297
297
self .marginalized_rvs .remove (rv )
298
298
if rv .name in self ._marginalized_named_vars_to_dims :
@@ -303,11 +303,11 @@ def unmarginalize(self, rvs_to_unmarginalize):
303
303
304
304
def recover_marginals (
305
305
self ,
306
- idata ,
307
- var_names = None ,
308
- return_samples = True ,
309
- extend_inferencedata = True ,
310
- random_seed = None ,
306
+ idata : InferenceData ,
307
+ var_names : Sequence [ str ] | None = None ,
308
+ return_samples : bool = True ,
309
+ extend_inferencedata : bool = True ,
310
+ random_seed : RandomState = None ,
311
311
):
312
312
"""Computes posterior log-probabilities and samples of marginalized variables
313
313
conditioned on parameters of the model given InferenceData with posterior group
@@ -648,7 +648,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
648
648
return rvs_to_marginalize , marginalized_rvs
649
649
650
650
651
- def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> Tuple [int , ...]:
651
+ def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> tuple [int , ...]:
652
652
op = rv .owner .op
653
653
if isinstance (op , Bernoulli ):
654
654
return (0 , 1 )
0 commit comments