Skip to content

Commit c3dfc5a

Browse files
committed
Refactor helper to create safe gufunc signature
1 parent 16a4f3b commit c3dfc5a

File tree

2 files changed

+28
-20
lines changed

2 files changed

+28
-20
lines changed

pytensor/tensor/blockwise.py

Lines changed: 6 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from pytensor import config
88
from pytensor.gradient import DisconnectedType
9-
from pytensor.graph.basic import Apply, Constant, Variable
9+
from pytensor.graph.basic import Apply, Constant
1010
from pytensor.graph.null_type import NullType
1111
from pytensor.graph.op import Op
1212
from pytensor.graph.replace import (
@@ -22,27 +22,11 @@
2222
_parse_gufunc_signature,
2323
broadcast_static_dim_lengths,
2424
import_func_from_string,
25+
safe_signature,
2526
)
2627
from pytensor.tensor.variable import TensorVariable
2728

2829

29-
def safe_signature(
30-
core_inputs: Sequence[Variable],
31-
core_outputs: Sequence[Variable],
32-
) -> str:
33-
def operand_sig(operand: Variable, prefix: str) -> str:
34-
operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim))
35-
return f"({operands})"
36-
37-
inputs_sig = ",".join(
38-
operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs)
39-
)
40-
outputs_sig = ",".join(
41-
operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs)
42-
)
43-
return f"{inputs_sig}->{outputs_sig}"
44-
45-
4630
class Blockwise(Op):
4731
"""Generalizes a core `Op` to work with batched dimensions.
4832
@@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
385369
else:
386370
# TODO: This is pretty bad for shape inference and merge optimization!
387371
# Should get better as we add signatures to our Ops
388-
signature = safe_signature(node.inputs, node.outputs)
372+
signature = safe_signature(
373+
[inp.type.ndim for inp in node.inputs],
374+
[out.type.ndim for out in node.outputs],
375+
)
389376
return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
390377

391378

pytensor/tensor/utils.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,11 @@ def broadcast_static_dim_lengths(
172172
_SIGNATURE = f"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
173173

174174

175-
def _parse_gufunc_signature(signature):
175+
def _parse_gufunc_signature(
176+
signature,
177+
) -> tuple[
178+
list[tuple[str, ...]], ...
179+
]: # mypy doesn't know it's alwayl a length two tuple
176180
"""
177181
Parse string signatures for a generalized universal function.
178182
@@ -198,3 +202,20 @@ def _parse_gufunc_signature(signature):
198202
]
199203
for arg_list in signature.split("->")
200204
)
205+
206+
207+
def safe_signature(
208+
core_inputs_ndim: Sequence[int],
209+
core_outputs_ndim: Sequence[int],
210+
) -> str:
211+
def operand_sig(operand_ndim: int, prefix: str) -> str:
212+
operands = ",".join(f"{prefix}{i}" for i in range(operand_ndim))
213+
return f"({operands})"
214+
215+
inputs_sig = ",".join(
216+
operand_sig(ndim, prefix=f"i{n}") for n, ndim in enumerate(core_inputs_ndim)
217+
)
218+
outputs_sig = ",".join(
219+
operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
220+
)
221+
return f"{inputs_sig}->{outputs_sig}"

0 commit comments

Comments
 (0)