|
11 | 11 | import tests.unittest_tools as utt
|
12 | 12 | from aesara.compile.mode import Mode
|
13 | 13 | from aesara.configdefaults import config
|
14 |
| -from aesara.graph.basic import Variable |
| 14 | +from aesara.graph.basic import Apply, Variable |
15 | 15 | from aesara.graph.fg import FunctionGraph
|
16 | 16 | from aesara.link.basic import PerformLinker
|
17 | 17 | from aesara.link.c.basic import CLinker, OpWiseCLinker
|
18 | 18 | from aesara.tensor import as_tensor_variable
|
19 | 19 | from aesara.tensor.basic import second
|
| 20 | +from aesara.tensor.basic_opt import ShapeError |
20 | 21 | from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
|
21 | 22 | from aesara.tensor.math import all as at_all
|
22 | 23 | from aesara.tensor.math import any as at_any
|
@@ -800,6 +801,46 @@ def test_str(self):
|
800 | 801 | op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
|
801 | 802 | assert str(op) == "my_op"
|
802 | 803 |
|
| 804 | + def test_partial_static_shape_info(self): |
| 805 | + """Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting.""" |
| 806 | + |
| 807 | + x = TensorType("floatX", shape=(None, None))() |
| 808 | + z = Elemwise(aes.add)(x, x) |
| 809 | + |
| 810 | + x_inferred_shape = (aes.constant(1), aes.constant(1)) |
| 811 | + |
| 812 | + res_shape = z.owner.op.infer_shape( |
| 813 | + None, z.owner, [x_inferred_shape, x_inferred_shape] |
| 814 | + ) |
| 815 | + |
| 816 | + assert len(res_shape) == 1 |
| 817 | + assert len(res_shape[0]) == 2 |
| 818 | + assert res_shape[0][0].data == 1 |
| 819 | + assert res_shape[0][1].data == 1 |
| 820 | + |
| 821 | + def test_multi_output(self): |
| 822 | + class CustomElemwise(Elemwise): |
| 823 | + def make_node(self, *args): |
| 824 | + res = super().make_node(*args) |
| 825 | + return Apply( |
| 826 | + self, |
| 827 | + res.inputs, |
| 828 | + # Return two outputs |
| 829 | + [ |
| 830 | + TensorType(dtype="float64", shape=(None, None))() |
| 831 | + for i in range(2) |
| 832 | + ], |
| 833 | + ) |
| 834 | + |
| 835 | + z_1, z_2 = CustomElemwise(aes.add)( |
| 836 | + as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1)) |
| 837 | + ) |
| 838 | + |
| 839 | + in_1_shape = (aes.constant(1), aes.constant(1)) |
| 840 | + |
| 841 | + with pytest.raises(ShapeError): |
| 842 | + z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) |
| 843 | + |
803 | 844 |
|
804 | 845 | def test_not_implemented_elemwise_grad():
|
805 | 846 | # Regression test for unimplemented gradient in an Elemwise Op.
|
|
0 commit comments