@@ -4015,19 +4015,68 @@ def test_local_sumsqr2dot():
4015
4015
4016
4016
4017
4017
def test_local_mulexp2expadd ():
4018
- # e^x * e^y = e^(x+y)
4019
- # test simple scalars first
4020
4018
x = scalar ("x" )
4021
4019
y = scalar ("y" )
4020
+ z = scalar ("z" )
4021
+ w = scalar ("w" )
4022
4022
expx = exp (x )
4023
4023
expy = exp (y )
4024
- expx_expy = expx * expy
4025
- f = function ([x , y ], expx_expy )
4026
- utt .assert_allclose (f (3 , 4 ), np .exp (3 + 4 ))
4024
+ expz = exp (z )
4025
+ expw = exp (w )
4026
+
4027
+ # e^x * e^y * e^z * e^w = e^(x+y+z+w)
4028
+ op = expx * expy * expz * expw
4029
+ f = function ([x , y , z , w ], op )
4030
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 + 6 ))
4031
+ graph = f .maker .fgraph .toposort ()
4032
+ assert isinstance (graph [0 ].op , Elemwise )
4033
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4034
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4035
+ assert not any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4036
+
4037
+ # e^x * e^y * e^z / e^w = e^(x+y+z-w)
4038
+ op = expx * expy * expz / expw
4039
+ f = function ([x , y , z , w ], op )
4040
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 - 6 ))
4041
+ graph = f .maker .fgraph .toposort ()
4042
+ assert isinstance (graph [0 ].op , Elemwise )
4043
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4044
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4045
+ assert any (isinstance (n .op , aes .Sub ) for n in inner_graph )
4046
+ assert not any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4047
+ assert not any (isinstance (n .op , aes .TrueDiv ) for n in inner_graph )
4048
+
4049
+ # e^x * e^y / e^z * e^w = e^(x+y-z+w)
4050
+ op = expx * expy / expz * expw
4051
+ f = function ([x , y , z , w ], op )
4052
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 - 5 + 6 ))
4053
+ graph = f .maker .fgraph .toposort ()
4054
+ assert isinstance (graph [0 ].op , Elemwise )
4055
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4056
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4057
+ assert any (isinstance (n .op , aes .Sub ) for n in inner_graph )
4058
+ assert not any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4059
+ assert not any (isinstance (n .op , aes .TrueDiv ) for n in inner_graph )
4060
+
4061
+ # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
4062
+ op = expx / expy / expz
4063
+ f = function ([x , y , z ], op )
4064
+ utt .assert_allclose (f (3 , 4 , 5 ), np .exp (3 - 4 - 5 ))
4065
+ graph = f .maker .fgraph .toposort ()
4066
+ assert isinstance (graph [0 ].op , Elemwise )
4067
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4068
+ assert any (isinstance (n .op , aes .Sub ) for n in inner_graph )
4069
+ assert not any (isinstance (n .op , aes .TrueDiv ) for n in inner_graph )
4070
+
4071
+ # e^x * y * e^z * w = e^(x+z) * y * w
4072
+ op = expx * y * expz * w
4073
+ f = function ([x , y , z , w ], op )
4074
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 5 ) * 4 * 6 )
4027
4075
graph = f .maker .fgraph .toposort ()
4028
4076
assert isinstance (graph [0 ].op , Elemwise )
4029
4077
inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4030
4078
assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4079
+ assert any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4031
4080
4032
4081
# expect same for matrices as well
4033
4082
mx = matrix ("mx" )
@@ -4040,28 +4089,20 @@ def test_local_mulexp2expadd():
4040
4089
assert isinstance (graph [0 ].op , Elemwise )
4041
4090
inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4042
4091
assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4092
+ assert not any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4043
4093
4044
4094
# checking whether further rewrites can proceed after this one as one would expect
4045
4095
# e^x * e^(-x) = e^(x-x) = e^0 = 1
4046
4096
f = function ([x ], expx * exp (neg (x )))
4047
- graph = f .maker .fgraph .toposort ()
4048
- assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4049
4097
utt .assert_allclose (f (42 ), 1 )
4050
-
4051
- # e^x / e^y = e^(x-y)
4052
- expx_div_expy = expx / expy
4053
- f = function ([x , y ], expx_div_expy )
4054
- utt .assert_allclose (f (5 , 3 ), np .exp (5 - 3 ))
4055
4098
graph = f .maker .fgraph .toposort ()
4056
- assert isinstance (graph [0 ].op , Elemwise )
4057
- inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4058
- assert any (isinstance (n .op , aes .Sub ) for n in inner_graph )
4099
+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4059
4100
4060
4101
# e^x / e^x = e^(x-x) = e^0 = 1
4061
4102
f = function ([x ], expx / expx )
4103
+ utt .assert_allclose (f (42 ), 1 )
4062
4104
graph = f .maker .fgraph .toposort ()
4063
4105
assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4064
- utt .assert_allclose (f (42 ), 1 )
4065
4106
4066
4107
4067
4108
def test_local_expm1 ():
0 commit comments