Skip to content

Commit 064e72f

Browse files
Only use input shapes to compute output shape in Elemwise.infer_shape
1 parent 22416ba commit 064e72f

File tree

2 files changed

+53
-33
lines changed

2 files changed

+53
-33
lines changed

aesara/tensor/elemwise.py

Lines changed: 11 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from aesara.misc.safe_asarray import _asarray
1717
from aesara.printing import FunctionPrinter, Printer, pprint
1818
from aesara.scalar import get_scalar_type
19-
from aesara.scalar.basic import ScalarType
2019
from aesara.scalar.basic import bool as scalar_bool
2120
from aesara.scalar.basic import identity as scalar_identity
2221
from aesara.scalar.basic import transfer_type, upcast
@@ -804,37 +803,17 @@ def perform(self, node, inputs, output_storage):
804803
storage[0] = variable
805804

806805
def infer_shape(self, fgraph, node, i_shapes):
807-
rval = []
808-
for o in node.outputs:
809-
oshp = []
810-
for dim, b in enumerate(o.type.broadcastable):
811-
b_dim = None
812-
if b:
813-
# this is broadcastable
814-
b_dim = 1
815-
else:
816-
# there must be some input that is not broadcastable in
817-
# dimension 'dim'
818-
for ishp, i in zip(i_shapes, node.inputs):
819-
if isinstance(i.type, ScalarType):
820-
continue # we skip scalar
821-
if not i.type.broadcastable[dim]:
822-
# input i is not broadcastable in position dim
823-
# therefore if its shape is known, we can use it
824-
# as the output shape
825-
if ishp[dim]:
826-
b_dim = ishp[dim]
827-
break
828-
829-
# b_dim might still be None, if every input's shape was unknown
830-
# in dimension 'dim'
831-
oshp.append(b_dim)
832-
# TODO: it would be interesting to return the constraining
833-
# information that if one of the inputs shape[dim] is known
834-
# and another input's shape[dim] is not, that we can now assume
835-
# that the other input's shape[dim] is the same as the first.
836-
rval.append(tuple(oshp))
837-
return rval
806+
807+
if len(node.outputs) > 1:
808+
from aesara.tensor.basic_opt import ShapeError
809+
810+
raise ShapeError(
811+
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
812+
)
813+
814+
out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
815+
816+
return [out_shape]
838817

839818
def _c_all(self, node, nodename, inames, onames, sub):
840819
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`

tests/tensor/test_elemwise.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@
1111
import tests.unittest_tools as utt
1212
from aesara.compile.mode import Mode
1313
from aesara.configdefaults import config
14-
from aesara.graph.basic import Variable
14+
from aesara.graph.basic import Apply, Variable
1515
from aesara.graph.fg import FunctionGraph
1616
from aesara.link.basic import PerformLinker
1717
from aesara.link.c.basic import CLinker, OpWiseCLinker
1818
from aesara.tensor import as_tensor_variable
1919
from aesara.tensor.basic import second
20+
from aesara.tensor.basic_opt import ShapeError
2021
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
2122
from aesara.tensor.math import all as at_all
2223
from aesara.tensor.math import any as at_any
@@ -800,6 +801,46 @@ def test_str(self):
800801
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
801802
assert str(op) == "my_op"
802803

804+
def test_partial_static_shape_info(self):
805+
"""Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting."""
806+
807+
x = TensorType("floatX", shape=(None, None))()
808+
z = Elemwise(aes.add)(x, x)
809+
810+
x_inferred_shape = (aes.constant(1), aes.constant(1))
811+
812+
res_shape = z.owner.op.infer_shape(
813+
None, z.owner, [x_inferred_shape, x_inferred_shape]
814+
)
815+
816+
assert len(res_shape) == 1
817+
assert len(res_shape[0]) == 2
818+
assert res_shape[0][0].data == 1
819+
assert res_shape[0][1].data == 1
820+
821+
def test_multi_output(self):
822+
class CustomElemwise(Elemwise):
823+
def make_node(self, *args):
824+
res = super().make_node(*args)
825+
return Apply(
826+
self,
827+
res.inputs,
828+
# Return two outputs
829+
[
830+
TensorType(dtype="float64", shape=(None, None))()
831+
for i in range(2)
832+
],
833+
)
834+
835+
z_1, z_2 = CustomElemwise(aes.add)(
836+
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
837+
)
838+
839+
in_1_shape = (aes.constant(1), aes.constant(1))
840+
841+
with pytest.raises(ShapeError):
842+
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
843+
803844

804845
def test_not_implemented_elemwise_grad():
805846
# Regression test for unimplemented gradient in an Elemwise Op.

0 commit comments

Comments
 (0)