
Rectify return type hints in logprob module rewrites #7125


Merged 2 commits on Feb 6, 2024
pymc/logprob/binary.py (5 changes: 3 additions & 2 deletions)

@@ -20,6 +20,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar.basic import GE, GT, LE, LT, Invert
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.math import ge, gt, invert, le, lt

 from pymc.logprob.abstract import (
@@ -41,7 +42,7 @@ class MeasurableComparison(MeasurableElemwise):
 @node_rewriter(tracks=[gt, lt, ge, le])
 def find_measurable_comparisons(
     fgraph: FunctionGraph, node: Node
-) -> Optional[list[MeasurableComparison]]:
+) -> Optional[list[TensorVariable]]:
     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
@@ -133,7 +134,7 @@ class MeasurableBitwise(MeasurableElemwise):


 @node_rewriter(tracks=[invert])
-def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableBitwise]]:
+def find_measurable_bitwise(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
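For context on why `TensorVariable` is the right element type: a `node_rewriter` function returns either `None` (when the node is not handled) or the list of replacement variables for `node.outputs`; the `Measurable*` classes name the Ops that produce those variables, not the values the rewrite returns. A minimal, hypothetical sketch of that contract (the `exp` tracking and the `find_measurable_exps` name are illustrative only, not part of this PR):

```python
from typing import Optional

from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor import TensorVariable
from pytensor.tensor.math import exp


@node_rewriter(tracks=[exp])
def find_measurable_exps(fgraph: FunctionGraph, node) -> Optional[list[TensorVariable]]:
    # Returning None signals "this node is not handled"; otherwise the
    # rewriter expects the TensorVariables that replace node.outputs.
    (base_var,) = node.inputs
    if base_var.owner is None:
        return None
    # The real rewrites would wrap base_var in a Measurable* Op here;
    # rebuilding the output is enough to show that the replacements
    # returned are TensorVariables, not Op instances.
    new_out = exp(base_var)
    return [new_out]
```

The same contract applies to every rewrite touched in this diff, which is why all the hints collapse to `Optional[list[TensorVariable]]`.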
pymc/logprob/censoring.py (5 changes: 3 additions & 2 deletions)

@@ -44,6 +44,7 @@
 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.scalar.basic import Ceil, Clip, Floor, RoundHalfToEven
 from pytensor.scalar.basic import clip as scalar_clip
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.math import ceil, clip, floor, round_half_to_even
 from pytensor.tensor.variable import TensorConstant

@@ -62,7 +63,7 @@ class MeasurableClip(MeasurableElemwise):


 @node_rewriter(tracks=[clip])
-def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableClip]]:
+def find_measurable_clips(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
     # TODO: Canonicalize x[x>ub] = ub -> clip(x, x, ub)

     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
@@ -157,7 +158,7 @@ class MeasurableRound(MeasurableElemwise):


 @node_rewriter(tracks=[ceil, floor, round_half_to_even])
-def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[MeasurableRound]]:
+def find_measurable_roundings(fgraph: FunctionGraph, node: Node) -> Optional[list[TensorVariable]]:
     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
     if rv_map_feature is None:
         return None  # pragma: no cover
pymc/logprob/checks.py (5 changes: 3 additions & 2 deletions)

@@ -40,6 +40,7 @@

 from pytensor.graph.rewriting.basic import node_rewriter
 from pytensor.raise_op import CheckAndRaise
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.shape import SpecifyShape

 from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
@@ -63,7 +64,7 @@ def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):


 @node_rewriter([SpecifyShape])
-def find_measurable_specify_shapes(fgraph, node) -> Optional[list[MeasurableSpecifyShape]]:
+def find_measurable_specify_shapes(fgraph, node) -> Optional[list[TensorVariable]]:
     r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed."""

     if isinstance(node.op, MeasurableSpecifyShape):
@@ -116,7 +117,7 @@ def logprob_check_and_raise(op, values, inner_rv, *assertions, **kwargs):


 @node_rewriter([CheckAndRaise])
-def find_measurable_check_and_raise(fgraph, node) -> Optional[list[MeasurableCheckAndRaise]]:
+def find_measurable_check_and_raise(fgraph, node) -> Optional[list[TensorVariable]]:
     r"""Finds `AssertOp`\s for which a `logprob` can be computed."""

     if isinstance(node.op, MeasurableCheckAndRaise):
pymc/logprob/cumsum.py (3 changes: 2 additions & 1 deletion)

@@ -39,6 +39,7 @@
 import pytensor.tensor as pt

 from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.extra_ops import CumOp

 from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
@@ -77,7 +78,7 @@ def logprob_cumsum(op, values, base_rv, **kwargs):


 @node_rewriter([CumOp])
-def find_measurable_cumsums(fgraph, node) -> Optional[list[MeasurableCumsum]]:
+def find_measurable_cumsums(fgraph, node) -> Optional[list[TensorVariable]]:
     r"""Finds `Cumsums`\s for which a `logprob` can be computed."""

     if not (isinstance(node.op, CumOp) and node.op.mode == "add"):
pymc/logprob/tensor.py (7 changes: 3 additions & 4 deletions)

@@ -41,6 +41,7 @@
 from pytensor import tensor as pt
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import node_rewriter
+from pytensor.tensor import TensorVariable
 from pytensor.tensor.basic import Alloc, Join, MakeVector
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.random.op import RandomVariable
@@ -197,9 +198,7 @@ def logprob_join(op, values, axis, *base_rvs, **kwargs):


 @node_rewriter([MakeVector, Join])
-def find_measurable_stacks(
-    fgraph, node
-) -> Optional[list[Union[MeasurableMakeVector, MeasurableJoin]]]:
+def find_measurable_stacks(fgraph, node) -> Optional[list[TensorVariable]]:
     r"""Finds `Joins`\s and `MakeVector`\s for which a `logprob` can be computed."""

     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
@@ -273,7 +272,7 @@ def logprob_dimshuffle(op, values, base_var, **kwargs):


 @node_rewriter([DimShuffle])
-def find_measurable_dimshuffles(fgraph, node) -> Optional[list[MeasurableDimShuffle]]:
+def find_measurable_dimshuffles(fgraph, node) -> Optional[list[TensorVariable]]:
     r"""Finds `Dimshuffle`\s for which a `logprob` can be computed."""

     rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
requirements-dev.txt (2 changes: 1 addition & 1 deletion)

@@ -30,4 +30,4 @@ sphinx>=1.5
 sphinxext-rediraffe
 types-cachetools
 typing-extensions>=3.7.4
-watermark
+watermark