22 | 22 | from pytensor.tensor import elemwise_cgen as cgen
23 | 23 | from pytensor.tensor import get_vector_length
24 | 24 | from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
| 25 | +from pytensor.tensor.blockwise import _vectorize_node, vectorize_not_needed
25 | 26 | from pytensor.tensor.type import (
26 | 27 |     TensorType,
27 | 28 |     continuous_dtypes,
28 | 29 |     discrete_dtypes,
29 | 30 |     float_dtypes,
30 | 31 |     lvector,
31 | 32 | )
| 33 | +from pytensor.tensor.utils import broadcast_static_dim_lengths, import_func_from_string
32 | 34 | from pytensor.tensor.variable import TensorVariable
33 | 35 | from pytensor.utils import uniq
34 | 36 |
@@ -232,7 +234,7 @@ def __str__(self):
232 | 234 |             return f"Transpose{{axes={self.shuffle}}}"
233 | 235 |         return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
234 | 236 |
235 | | -    def perform(self, node, inp, out, params):
| 237 | +    def perform(self, node, inp, out, params=None):
236 | 238 |         (res,) = inp
237 | 239 |         (storage,) = out
238 | 240 |
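Note: with the default value, `DimShuffle.perform` no longer requires a `params` argument from the caller (the Python implementation only reads attributes on `self`). A minimal sketch of a direct call, reusing the positional constructor arguments seen later in this diff; the values and variable names are illustrative only:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.elemwise import DimShuffle

# Hypothetical example: transpose a 2x3 matrix by calling perform() directly,
# without the previously required `params` argument.
op = DimShuffle((False, False), (1, 0))
node = op.make_node(pt.matrix("x"))
out_storage = [None]
op.perform(node, [np.arange(6.0).reshape(2, 3)], [out_storage])
print(out_storage[0].shape)  # (3, 2)
```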
@@ -429,28 +431,12 @@ def get_output_info(self, dim_shuffle, *inputs):
429 | 431 |         # of all inputs in parallel... the all() gives us each output
430 | 432 |         # broadcastable bit in turn.
431 | 433 |
432 | | -        def get_most_specialized_shape(shapes):
433 | | -            shapes = set(shapes)
434 | | -            # All shapes are the same
435 | | -            if len(shapes) == 1:
436 | | -                return tuple(shapes)[0]
437 | | -
438 | | -            # Only valid indeterminate case
439 | | -            if shapes == {None, 1}:
440 | | -                return None
441 | | -
442 | | -            shapes.discard(1)
443 | | -            shapes.discard(None)
444 | | -            if len(shapes) > 1:
445 | | -                raise ValueError
446 | | -            return tuple(shapes)[0]
447 | | -
448 | 434 |         # it is multiplied by nout because Elemwise supports multiple outputs
449 | 435 |         # (nout of them)
450 | 436 |         try:
451 | 437 |             out_shapes = [
452 | 438 |                 [
453 | | -                    get_most_specialized_shape(shape)
| 439 | +                    broadcast_static_dim_lengths(shape)
454 | 440 |                     for shape in zip(*[inp.type.shape for inp in inputs])
455 | 441 |                 ]
456 | 442 |             ] * shadow.nout
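The inlined `get_most_specialized_shape` helper is replaced by `broadcast_static_dim_lengths` from `pytensor.tensor.utils`. Judging from the removed code, it combines the static lengths of one dimension across all inputs under broadcasting rules; a rough sketch of the expected behaviour:

```python
from pytensor.tensor.utils import broadcast_static_dim_lengths

broadcast_static_dim_lengths((5, 5, 1))     # -> 5 (a length of 1 broadcasts to 5)
broadcast_static_dim_lengths((None, 1))     # -> None (only valid indeterminate case)
broadcast_static_dim_lengths((5, None, 1))  # -> 5 (unknown lengths are assumed to match)
broadcast_static_dim_lengths((5, 3))        # raises ValueError: lengths cannot broadcast
```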
@@ -665,22 +651,7 @@ def prepare_node(self, node, storage_map, compute_map, impl):
665 | 651 |             impl = "c"
666 | 652 |
667 | 653 |         if getattr(self, "nfunc_spec", None) and impl != "c":
668 | | -            self.nfunc = getattr(np, self.nfunc_spec[0], None)
669 | | -            if self.nfunc is None:
670 | | -                # Not inside NumPy. So probably another package like scipy.
671 | | -                symb = self.nfunc_spec[0].split(".")
672 | | -                for idx in range(1, len(self.nfunc_spec[0])):
673 | | -                    try:
674 | | -                        module = __import__(".".join(symb[:idx]))
675 | | -                    except ImportError:
676 | | -                        break
677 | | -                    for sub in symb[1:]:
678 | | -                        try:
679 | | -                            module = getattr(module, sub)
680 | | -                        except AttributeError:
681 | | -                            module = None
682 | | -                            break
683 | | -                self.nfunc = module
| 654 | +            self.nfunc = import_func_from_string(self.nfunc_spec[0])
684 | 655 |
685 | 656 |         if (
686 | 657 |             (len(node.inputs) + len(node.outputs)) <= 32
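The removed NumPy/SciPy lookup loop is factored out into `import_func_from_string` in `pytensor.tensor.utils`. Based on the code it replaces, it resolves an `nfunc_spec` name first as a plain NumPy attribute and otherwise as a dotted import path; a rough sketch of the expected behaviour:

```python
from pytensor.tensor.utils import import_func_from_string

import_func_from_string("exp")                # -> np.exp (plain NumPy name)
import_func_from_string("scipy.special.erf")  # -> scipy.special.erf, if SciPy is installed
```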
@@ -1768,3 +1739,37 @@ def _get_vector_length_Elemwise(op, var):
1768 | 1739 |         return get_vector_length(var.owner.inputs[0])
1769 | 1740 |
1770 | 1741 |     raise ValueError(f"Length of {var} cannot be determined")
| 1742 | +
| 1743 | +
| 1744 | +_vectorize_node.register(Elemwise, vectorize_not_needed)
| 1745 | +
| 1746 | +
| 1747 | +@_vectorize_node.register(DimShuffle)
| 1748 | +def vectorize_dimshuffle(op: DimShuffle, node: Apply, x: TensorVariable) -> Apply:
| 1749 | +    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
| 1750 | +    if not batched_ndims:
| 1751 | +        return node.op.make_node(x)
| 1752 | +    input_broadcastable = x.type.broadcastable[:batched_ndims] + op.input_broadcastable
| 1753 | +    # e.g., ds(matrix, order=(1, "x", 0)) -> ds(tensor4, order=(0, 1, 3, "x", 2))
| 1754 | +    # e.g., ds(row, order=(1, "x")) -> ds(tensor4, order=(0, 1, 3, "x"))
| 1755 | +    new_order = list(range(batched_ndims)) + [
| 1756 | +        "x" if (o == "x") else (o + batched_ndims) for o in op.new_order
| 1757 | +    ]
| 1758 | +    return DimShuffle(input_broadcastable, new_order).make_node(x)
| 1759 | +
| 1760 | +
| 1761 | +@_vectorize_node.register(CAReduce)
| 1762 | +def vectorize_careduce(op: CAReduce, node: Apply, x: TensorVariable) -> Apply:
| 1763 | +    batched_ndims = x.type.ndim - node.inputs[0].type.ndim
| 1764 | +    if not batched_ndims:
| 1765 | +        return node.op.make_node(x)
| 1766 | +    axes = op.axis
| 1767 | +    # e.g., sum(matrix, axis=None) -> sum(tensor4, axis=(2, 3))
| 1768 | +    # e.g., sum(matrix, axis=0) -> sum(tensor4, axis=(2,))
| 1769 | +    if axes is None:
| 1770 | +        axes = list(range(node.inputs[0].type.ndim))
| 1771 | +    else:
| 1772 | +        axes = list(axes)
| 1773 | +    new_axes = [axis + batched_ndims for axis in axes]
| 1774 | +    new_op = op.clone(axis=new_axes)
| 1775 | +    return new_op.make_node(x)
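For illustration, a rough sketch of what these registrations do once a batched input is substituted in, using the `_vectorize_node` dispatcher imported at the top of this diff (variable names are illustrative only):

```python
import pytensor.tensor as pt
from pytensor.tensor.blockwise import _vectorize_node

x = pt.matrix("x")
batched_x = pt.tensor4("batched_x")  # two extra leading batch dimensions

# DimShuffle: batch dimensions pass through untouched and the original
# order is shifted to the right.
ds_node = x.dimshuffle((1, "x", 0)).owner
new_ds = _vectorize_node(ds_node.op, ds_node, batched_x)
print(new_ds.op.new_order)  # (0, 1, 3, 'x', 2)

# CAReduce: reduction axes are shifted past the new batch dimensions.
sum_node = x.sum(axis=0).owner
new_sum = _vectorize_node(sum_node.op, sum_node, batched_x)
print(new_sum.op.axis)  # (2,)
```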