Skip to content

Commit f4ae645

Browse files
committed
Add type hints to MarginalModel methods
1 parent 8158627 commit f4ae645

File tree

1 file changed

+11
-11
lines changed

1 file changed

+11
-11
lines changed

pymc_experimental/model/marginal_model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import warnings
2-
from typing import Sequence, Tuple, Union
2+
from typing import Sequence
33

44
import numpy as np
55
import pymc
66
import pytensor.tensor as pt
7-
from arviz import dict_to_dataset
7+
from arviz import InferenceData, dict_to_dataset
88
from pymc import SymbolicRandomVariable
99
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
1010
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
@@ -14,7 +14,7 @@
1414
from pymc.logprob.transforms import IntervalTransform
1515
from pymc.model import Model
1616
from pymc.pytensorf import compile_pymc, constant_fold
17-
from pymc.util import _get_seeds_per_chain, treedict
17+
from pymc.util import RandomState, _get_seeds_per_chain, treedict
1818
from pytensor import Mode, scan
1919
from pytensor.compile import SharedVariable
2020
from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
@@ -235,7 +235,7 @@ def clone(self):
235235

236236
def marginalize(
237237
self,
238-
rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
238+
rvs_to_marginalize: TensorVariable | Sequence[TensorVariable] | str | Sequence[str],
239239
):
240240
if not isinstance(rvs_to_marginalize, Sequence):
241241
rvs_to_marginalize = (rvs_to_marginalize,)
@@ -292,7 +292,7 @@ def _to_transformed(self):
292292
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
293293
return fn, transformed_names
294294

295-
def unmarginalize(self, rvs_to_unmarginalize):
295+
def unmarginalize(self, rvs_to_unmarginalize: Sequence[TensorVariable]):
296296
for rv in rvs_to_unmarginalize:
297297
self.marginalized_rvs.remove(rv)
298298
if rv.name in self._marginalized_named_vars_to_dims:
@@ -303,11 +303,11 @@ def unmarginalize(self, rvs_to_unmarginalize):
303303

304304
def recover_marginals(
305305
self,
306-
idata,
307-
var_names=None,
308-
return_samples=True,
309-
extend_inferencedata=True,
310-
random_seed=None,
306+
idata: InferenceData,
307+
var_names: Sequence[str] | None = None,
308+
return_samples: bool = True,
309+
extend_inferencedata: bool = True,
310+
random_seed: RandomState = None,
311311
):
312312
"""Computes posterior log-probabilities and samples of marginalized variables
313313
conditioned on parameters of the model given InferenceData with posterior group
@@ -648,7 +648,7 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
648648
return rvs_to_marginalize, marginalized_rvs
649649

650650

651-
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
651+
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
652652
op = rv.owner.op
653653
if isinstance(op, Bernoulli):
654654
return (0, 1)

0 commit comments

Comments
 (0)