Skip to content

Commit 8938851

Browse files
committed
Overload return type of compile_fn
1 parent bd74474 commit 8938851

File tree

1 file changed

+27
-4
lines changed

1 file changed

+27
-4
lines changed

pymc/model/core.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import types
1919
import warnings
2020

21-
from collections.abc import Callable, Iterable, Sequence
21+
from collections.abc import Iterable, Sequence
2222
from sys import modules
2323
from typing import (
2424
TYPE_CHECKING,
@@ -27,6 +27,7 @@
2727
Optional,
2828
TypeVar,
2929
cast,
30+
overload,
3031
)
3132

3233
import numpy as np
@@ -35,7 +36,7 @@
3536
import pytensor.tensor as pt
3637
import scipy.sparse as sps
3738

38-
from pytensor.compile import DeepCopyOp, get_mode
39+
from pytensor.compile import DeepCopyOp, Function, get_mode
3940
from pytensor.compile.sharedvalue import SharedVariable
4041
from pytensor.graph.basic import Constant, Variable, graph_inputs
4142
from pytensor.scalar import Cast
@@ -1524,6 +1525,28 @@ def replace_rvs_by_values(
15241525
rvs_to_transforms=self.rvs_to_transforms,
15251526
)
15261527

1528+
@overload
1529+
def compile_fn(
1530+
self,
1531+
outs: Variable | Sequence[Variable],
1532+
*,
1533+
inputs: Sequence[Variable] | None = None,
1534+
mode=None,
1535+
point_fn: Literal[True] = True,
1536+
**kwargs,
1537+
) -> PointFunc: ...
1538+
1539+
@overload
1540+
def compile_fn(
1541+
self,
1542+
outs: Variable | Sequence[Variable],
1543+
*,
1544+
inputs: Sequence[Variable] | None = None,
1545+
mode=None,
1546+
point_fn: Literal[False],
1547+
**kwargs,
1548+
) -> Function: ...
1549+
15271550
def compile_fn(
15281551
self,
15291552
outs: Variable | Sequence[Variable],
@@ -1532,7 +1555,7 @@ def compile_fn(
15321555
mode=None,
15331556
point_fn: bool = True,
15341557
**kwargs,
1535-
) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
1558+
) -> PointFunc | Function:
15361559
"""Compiles an PyTensor function
15371560
15381561
Parameters
@@ -2044,7 +2067,7 @@ def compile_fn(
20442067
point_fn: bool = True,
20452068
model: Model | None = None,
20462069
**kwargs,
2047-
) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
2070+
) -> PointFunc | Function:
20482071
"""Compiles an PyTensor function
20492072
20502073
Parameters

0 commit comments

Comments
 (0)