18
18
import types
19
19
import warnings
20
20
21
- from collections .abc import Callable , Iterable , Sequence
21
+ from collections .abc import Iterable , Sequence
22
22
from sys import modules
23
23
from typing import (
24
24
TYPE_CHECKING ,
27
27
Optional ,
28
28
TypeVar ,
29
29
cast ,
30
+ overload ,
30
31
)
31
32
32
33
import numpy as np
35
36
import pytensor .tensor as pt
36
37
import scipy .sparse as sps
37
38
38
- from pytensor .compile import DeepCopyOp , get_mode
39
+ from pytensor .compile import DeepCopyOp , Function , get_mode
39
40
from pytensor .compile .sharedvalue import SharedVariable
40
41
from pytensor .graph .basic import Constant , Variable , graph_inputs
41
42
from pytensor .scalar import Cast
@@ -1524,6 +1525,28 @@ def replace_rvs_by_values(
1524
1525
rvs_to_transforms = self .rvs_to_transforms ,
1525
1526
)
1526
1527
1528
+ @overload
1529
+ def compile_fn (
1530
+ self ,
1531
+ outs : Variable | Sequence [Variable ],
1532
+ * ,
1533
+ inputs : Sequence [Variable ] | None = None ,
1534
+ mode = None ,
1535
+ point_fn : Literal [True ] = True ,
1536
+ ** kwargs ,
1537
+ ) -> PointFunc : ...
1538
+
1539
+ @overload
1540
+ def compile_fn (
1541
+ self ,
1542
+ outs : Variable | Sequence [Variable ],
1543
+ * ,
1544
+ inputs : Sequence [Variable ] | None = None ,
1545
+ mode = None ,
1546
+ point_fn : Literal [False ],
1547
+ ** kwargs ,
1548
+ ) -> Function : ...
1549
+
1527
1550
def compile_fn (
1528
1551
self ,
1529
1552
outs : Variable | Sequence [Variable ],
@@ -1532,7 +1555,7 @@ def compile_fn(
1532
1555
mode = None ,
1533
1556
point_fn : bool = True ,
1534
1557
** kwargs ,
1535
- ) -> PointFunc | Callable [[ Sequence [ np . ndarray ]], Sequence [ np . ndarray ]] :
1558
+ ) -> PointFunc | Function :
1536
1559
"""Compiles an PyTensor function
1537
1560
1538
1561
Parameters
@@ -2044,7 +2067,7 @@ def compile_fn(
2044
2067
point_fn : bool = True ,
2045
2068
model : Model | None = None ,
2046
2069
** kwargs ,
2047
- ) -> PointFunc | Callable [[ Sequence [ np . ndarray ]], Sequence [ np . ndarray ]] :
2070
+ ) -> PointFunc | Function :
2048
2071
"""Compiles an PyTensor function
2049
2072
2050
2073
Parameters
0 commit comments