@@ -4014,6 +4014,56 @@ def test_local_sumsqr2dot():
4014
4014
)
4015
4015
4016
4016
4017
+ def test_local_mulexp2expadd ():
4018
+ # e^x * e^y = e^(x+y)
4019
+ # test simple scalars first
4020
+ x = scalar ("x" )
4021
+ y = scalar ("y" )
4022
+ expx = exp (x )
4023
+ expy = exp (y )
4024
+ expx_expy = expx * expy
4025
+ f = function ([x , y ], expx_expy )
4026
+ utt .assert_allclose (f (3 , 4 ), np .exp (3 + 4 ))
4027
+ graph = f .maker .fgraph .toposort ()
4028
+ assert isinstance (graph [0 ].op , Elemwise )
4029
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4030
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4031
+
4032
+ # expect same for matrices as well
4033
+ mx = matrix ("mx" )
4034
+ my = matrix ("my" )
4035
+ f = function ([mx , my ], exp (mx ) * exp (my ))
4036
+ M1 = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
4037
+ M2 = np .array ([[5.0 , 6.0 ], [7.0 , 8.0 ]])
4038
+ utt .assert_allclose (f (M1 , M2 ), np .exp (M1 + M2 ))
4039
+ graph = f .maker .fgraph .toposort ()
4040
+ assert isinstance (graph [0 ].op , Elemwise )
4041
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4042
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4043
+
4044
+ # checking whether further rewrites can proceed after this one as one would expect
4045
+ # e^x * e^(-x) = e^(x-x) = e^0 = 1
4046
+ f = function ([x ], expx * exp (neg (x )))
4047
+ graph = f .maker .fgraph .toposort ()
4048
+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4049
+ utt .assert_allclose (f (42 ), 1 )
4050
+
4051
+ # e^x / e^y = e^(x-y)
4052
+ expx_div_expy = expx / expy
4053
+ f = function ([x , y ], expx_div_expy )
4054
+ utt .assert_allclose (f (5 , 3 ), np .exp (5 - 3 ))
4055
+ graph = f .maker .fgraph .toposort ()
4056
+ assert isinstance (graph [0 ].op , Elemwise )
4057
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4058
+ assert any (isinstance (n .op , aes .Sub ) for n in inner_graph )
4059
+
4060
+ # e^x / e^x = e^(x-x) = e^0 = 1
4061
+ f = function ([x ], expx / expx )
4062
+ graph = f .maker .fgraph .toposort ()
4063
+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4064
+ utt .assert_allclose (f (42 ), 1 )
4065
+
4066
+
4017
4067
def test_local_expm1 ():
4018
4068
x = matrix ("x" )
4019
4069
u = scalar ("u" )
0 commit comments