Commit f25a624

jessegrabowski, aseyboldt, ricardoV94, and zaxtax committed:

Implement Einsum

Co-authored-by: Adrian Seyboldt <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Rob Zinkov <[email protected]>

1 parent 23427a0 · commit f25a624

20 files changed: +1569 −151 lines changed
pytensor/link/jax/dispatch/__init__.py (+1)

@@ -4,6 +4,7 @@
 # Load dispatch specializations
 import pytensor.link.jax.dispatch.blas
 import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.einsum
 import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
 import pytensor.link.jax.dispatch.pad

pytensor/link/jax/dispatch/einsum.py (new file, +20)

@@ -0,0 +1,20 @@
+import jax.numpy as jnp
+
+from pytensor.link.jax.dispatch import jax_funcify
+from pytensor.tensor.einsum import Einsum
+
+
+@jax_funcify.register(Einsum)
+def jax_funcify_Einsum(op, **kwargs):
+    """Dispatch einsum to JAX.
+
+    This dispatch is triggered only when we couldn't optimize einsum at the PyTensor level.
+    This happens when some of the dimension lengths are unknown. This is never a problem in JAX,
+    as it always compiles a function per runtime input shape.
+    """
+    subscripts = op.subscripts
+
+    def einsum(*operands):
+        return jnp.einsum(subscripts, *operands, optimize="optimal")
+
+    return einsum
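A rough sketch of how this dispatch can be exercised (the graph below is illustrative and not part of the commit): with symbolic inputs whose dimension lengths are unknown, the Einsum node survives PyTensor-level rewrites and the JAX backend falls back to jnp.einsum as registered above.

    import pytensor
    import pytensor.tensor as pt

    # Matrices with unknown static shapes, so the einsum cannot be
    # decomposed at the PyTensor level and the dispatch above is used.
    A = pt.matrix("A")
    B = pt.matrix("B")
    out = pt.einsum("ij,jk->ik", A, B)

    f = pytensor.function([A, B], out, mode="JAX")  # requires jax to be installed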

pytensor/tensor/__init__.py (+1)

@@ -151,6 +151,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 
 
 # isort: off
+from pytensor.tensor.einsum import einsum
 from pytensor.tensor.functional import vectorize
 # isort: on
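With this import, einsum becomes available directly from pytensor.tensor. A small usage sketch (illustrative only; the variable names are made up):

    import numpy as np
    import pytensor.tensor as pt

    x = pt.tensor("x", shape=(3, 4))
    y = pt.tensor("y", shape=(4, 5))
    z = pt.einsum("ij,jk->ik", x, y)  # matrix multiplication written as an einsum contraction

    z.eval({x: np.ones((3, 4)), y: np.ones((4, 5))})  # (3, 5) array filled with 4.0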

pytensor/tensor/basic.py (+21 −13)
@@ -1700,21 +1700,22 @@ def do_constant_folding(self, fgraph, node):
             return False
 
         for client, idx in clients:
-            if isinstance(client.op, Output):
+            client_op = client.op
+            if isinstance(client_op, Output):
                 # If the output is a constant, it will have to be deepcopied
                 # each time the function is called. So we do not fold.
                 return False
-            # Allow alloc to be lifted out of Elemwise before constant folding it
-            elif isinstance(client.op, Elemwise):
-                return None
+            # Op's through which Alloc can be lifted
+            elif isinstance(client_op, Elemwise | DimShuffle | Alloc | Join):
+                return False
             # Same for Blockwise, unless it has no batch_dims
-            elif isinstance(client.op, Blockwise) and client.op.batch_ndim(client):
-                return None
+            elif isinstance(client_op, Blockwise) and client.op.batch_ndim(client):
+                return False
             elif (
                 # The following ops work inplace of their input id 0.
                 idx == 0
                 and isinstance(
-                    client.op,
+                    client_op,
                     pytensor.tensor.subtensor.IncSubtensor
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor1
                     | pytensor.tensor.subtensor.AdvancedIncSubtensor
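Context for this hunk: the set of client ops through which an Alloc may be lifted, and therefore is not constant-folded, grows from Elemwise alone to Elemwise, DimShuffle, Alloc, and Join. A rough sketch of the kind of graph this targets (illustrative and not part of the commit):

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x", shape=(512,))
    filled = pt.alloc(0.0, 512, 128)   # Alloc with only constant inputs
    out = filled + x[:, None]          # Elemwise client: folding is skipped so the Alloc can be lifted

    f = pytensor.function([x], out)
    pytensor.dprint(f)                 # inspect how the Alloc ends up in the compiled graph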
@@ -2035,10 +2036,15 @@ def transpose(x, axes=None):
     _x = as_tensor_variable(x)
 
     if axes is None:
-        axes = list(range((_x.type.ndim - 1), -1, -1))
+        axes = tuple(range((_x.type.ndim - 1), -1, -1))
+
+    if tuple(axes) == tuple(range(len(axes))):
+        # No-op
+        return _x
+
     ret = DimShuffle(tuple(s == 1 for s in _x.type.shape), axes)(_x)
 
-    if _x.name and axes == list(range((_x.type.ndim - 1), -1, -1)):
+    if _x.name and axes == tuple(range((_x.type.ndim - 1), -1, -1)):
         ret.name = _x.name + ".T"
 
     return ret
@@ -3950,6 +3956,10 @@ def moveaxis(
     source = normalize_axis_tuple(source, a.ndim, "source")
     destination = normalize_axis_tuple(destination, a.ndim, "destination")
 
+    if source == destination:
+        # It's a no-op
+        return a
+
     if len(source) != len(destination):
         raise ValueError(
             "`source` and `destination` arguments must have the same number of elements"
@@ -4260,9 +4270,7 @@ def atleast_Nd(
 atleast_3d = partial(atleast_Nd, n=3)
 
 
-def expand_dims(
-    a: np.ndarray | TensorVariable, axis: tuple[int, ...]
-) -> TensorVariable:
+def expand_dims(a: np.ndarray | TensorVariable, axis: Sequence[int]) -> TensorVariable:
     """Expand the shape of an array.
 
     Insert a new axis that will appear at the `axis` position in the expanded
@@ -4281,7 +4289,7 @@ def expand_dims(
     """
     a = as_tensor(a)
 
-    if not isinstance(axis, tuple | list):
+    if not isinstance(axis, Sequence):
         axis = (axis,)
 
     out_ndim = len(axis) + a.ndim
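With the Sequence-based check, axis may be any sequence of ints (a bare int is still wrapped into a one-element tuple). A small illustrative sketch:

    import pytensor.tensor as pt

    v = pt.vector("v")
    m = pt.expand_dims(v, axis=[0, 2])  # insert new axes at positions 0 and 2
    print(m.type.ndim)                  # 3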
