@@ -4105,6 +4105,64 @@ def test_local_mulexp2expadd():
4105
4105
assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4106
4106
4107
4107
4108
+ def test_local_mulpow2powadd ():
4109
+ x = scalar ("x" )
4110
+ y = scalar ("y" )
4111
+ z = scalar ("z" )
4112
+ w = scalar ("w" )
4113
+ v = scalar ("v" )
4114
+ u = scalar ("u" )
4115
+ t = scalar ("t" )
4116
+ s = scalar ("s" )
4117
+ a = scalar ("a" )
4118
+ b = scalar ("b" )
4119
+ c = scalar ("c" )
4120
+
4121
+ # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
4122
+ op = 2 ** x * 2 ** y * 2 ** z * 2 ** w
4123
+ f = function ([x , y , z , w ], op )
4124
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), 2 ** (3 + 4 + 5 + 6 ))
4125
+ graph = f .maker .fgraph .toposort ()
4126
+ assert isinstance (graph [0 ].op , Elemwise )
4127
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4128
+ assert any (isinstance (n .op , aes .Add ) for n in inner_graph )
4129
+ assert not any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4130
+
4131
+ # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
4132
+ op = 2 ** x * a ** y * 2 ** z * b ** w * c ** v * a ** u * s * b ** t
4133
+ f = function ([x , y , z , w , v , u , t , s , a , b , c ], op )
4134
+ utt .assert_allclose (
4135
+ f (4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 2.5 , 3 , 3.5 ),
4136
+ 2 ** (4 + 6 ) * 2.5 ** (5 + 9 ) * 3 ** (7 + 10 ) * 3.5 ** 8 * 11 ,
4137
+ )
4138
+ graph = f .maker .fgraph .toposort ()
4139
+ assert isinstance (graph [0 ].op , Elemwise )
4140
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4141
+ assert len ([True for n in inner_graph if isinstance (n .op , aes .Add )]) == 3
4142
+ assert len ([True for n in inner_graph if isinstance (n .op , aes .Pow )]) == 4
4143
+ assert any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4144
+
4145
+ # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
4146
+ op = 2 ** x / 2 ** y * (a ** z / a ** w )
4147
+ f = function ([x , y , z , w , a ], op )
4148
+ utt .assert_allclose (f (3 , 5 , 6 , 4 , 7 ), 2 ** (3 - 5 ) * 7 ** (6 - 4 ))
4149
+ graph = f .maker .fgraph .toposort ()
4150
+ assert isinstance (graph [0 ].op , Elemwise )
4151
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4152
+ assert len ([True for n in inner_graph if isinstance (n .op , aes .Sub )]) == 2
4153
+ assert any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4154
+
4155
+ # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
4156
+ op = a ** x * a ** y * exp (z ) * exp (w )
4157
+ f = function ([x , y , z , w , a ], op )
4158
+ utt .assert_allclose (f (3 , 4 , 5 , 6 , 2 ), 2 ** (3 + 4 ) * np .exp (5 + 6 ))
4159
+ graph = f .maker .fgraph .toposort ()
4160
+ assert isinstance (graph [0 ].op , Elemwise )
4161
+ inner_graph = graph [0 ].op .scalar_op .fgraph .toposort ()
4162
+ assert len ([True for n in inner_graph if isinstance (n .op , aes .Add )]) == 2
4163
+ assert any (isinstance (n .op , aes .Mul ) for n in inner_graph )
4164
+
4165
+
4108
4166
def test_local_expm1 ():
4109
4167
x = matrix ("x" )
4110
4168
u = scalar ("u" )
0 commit comments