@@ -6,7 +6,7 @@
 
 from pytensor import config
 from pytensor.gradient import DisconnectedType
-from pytensor.graph.basic import Apply, Constant, Variable
+from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.null_type import NullType
 from pytensor.graph.op import Op
 from pytensor.graph.replace import (
@@ -22,27 +22,11 @@
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
     import_func_from_string,
+    safe_signature,
 )
 from pytensor.tensor.variable import TensorVariable
 
 
-def safe_signature(
-    core_inputs: Sequence[Variable],
-    core_outputs: Sequence[Variable],
-) -> str:
-    def operand_sig(operand: Variable, prefix: str) -> str:
-        operands = ",".join(f"{prefix}{i}" for i in range(operand.type.ndim))
-        return f"({operands})"
-
-    inputs_sig = ",".join(
-        operand_sig(i, prefix=f"i{n}") for n, i in enumerate(core_inputs)
-    )
-    outputs_sig = ",".join(
-        operand_sig(o, prefix=f"o{n}") for n, o in enumerate(core_outputs)
-    )
-    return f"{inputs_sig}->{outputs_sig}"
-
-
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
 
@@ -385,7 +369,10 @@ def vectorize_node_fallback(op: Op, node: Apply, *bached_inputs) -> Apply:
     else:
         # TODO: This is pretty bad for shape inference and merge optimization!
         # Should get better as we add signatures to our Ops
-        signature = safe_signature(node.inputs, node.outputs)
+        signature = safe_signature(
+            [inp.type.ndim for inp in node.inputs],
+            [out.type.ndim for out in node.outputs],
+        )
     return cast(Apply, Blockwise(op, signature=signature).make_node(*bached_inputs))
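
For reference: `safe_signature` now lives in `pytensor.tensor.utils` (it is imported alongside `_parse_gufunc_signature` above) and, per the updated call site in `vectorize_node_fallback`, takes the core ndims of inputs and outputs as plain integers rather than graph `Variable`s. A minimal sketch of the relocated helper, reconstructed from the removed definition above; the parameter names `core_inputs_ndim` and `core_outputs_ndim` are assumptions, but the signature-building logic mirrors the old body with `Variable.type.ndim` replaced by plain ints:

# Sketch of safe_signature after the move, reconstructed from the removed
# definition. Parameter names are assumptions.
from collections.abc import Sequence


def safe_signature(
    core_inputs_ndim: Sequence[int],
    core_outputs_ndim: Sequence[int],
) -> str:
    # Build a dummy gufunc operand spec, e.g. "(i00,i01)" for a 2d operand
    def operand_sig(operand_ndim: int, prefix: str) -> str:
        operands = ",".join(f"{prefix}{i}" for i in range(operand_ndim))
        return f"({operands})"

    inputs_sig = ",".join(
        operand_sig(ndim, prefix=f"i{n}") for n, ndim in enumerate(core_inputs_ndim)
    )
    outputs_sig = ",".join(
        operand_sig(ndim, prefix=f"o{n}") for n, ndim in enumerate(core_outputs_ndim)
    )
    return f"{inputs_sig}->{outputs_sig}"


# E.g. a matrix input, a vector input, and one matrix output:
assert safe_signature([2, 1], [2]) == "(i00,i01),(i10)->(o00,o01)"

Taking ndims instead of `Variable`s decouples the helper from graph objects entirely, which is presumably what allows it to move out of blockwise.py into the shared utils module.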