@@ -4014,6 +4014,161 @@ def test_local_sumsqr2dot():
4014
4014
)
4015
4015
4016
4016
4017
+ def test_local_mul_exp_to_exp_add ():
4018
+ # Default and FAST_RUN modes put a Composite op into the final graph,
4019
+ # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4020
+ # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4021
+ mode = get_default_mode ().excluding ("fusion" ).including ("local_mul_exp_to_exp_add" )
4022
+
4023
+ x = scalar ("x" )
4024
+ y = scalar ("y" )
4025
+ z = scalar ("z" )
4026
+ w = scalar ("w" )
4027
+ expx = exp (x )
4028
+ expy = exp (y )
4029
+ expz = exp (z )
4030
+ expw = exp (w )
4031
+
4032
+ # e^x * e^y * e^z * e^w = e^(x+y+z+w)
4033
+ op = expx * expy * expz * expw
4034
+ f = function ([x , y , z , w ], op , mode )
4035
+ pytensor .dprint (f )
4036
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 + 6 ))
4037
+ graph = f .maker .fgraph .toposort ()
4038
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4039
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4040
+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4041
+
4042
+ # e^x * e^y * e^z / e^w = e^(x+y+z-w)
4043
+ op = expx * expy * expz / expw
4044
+ f = function ([x , y , z , w ], op , mode )
4045
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 + 5 - 6 ))
4046
+ graph = f .maker .fgraph .toposort ()
4047
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4048
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4049
+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4050
+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4051
+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4052
+
4053
+ # e^x * e^y / e^z * e^w = e^(x+y-z+w)
4054
+ op = expx * expy / expz * expw
4055
+ f = function ([x , y , z , w ], op , mode )
4056
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 4 - 5 + 6 ))
4057
+ graph = f .maker .fgraph .toposort ()
4058
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4059
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4060
+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4061
+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4062
+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4063
+
4064
+ # e^x / e^y / e^z = (e^x / e^y) / e^z = e^(x-y-z)
4065
+ op = expx / expy / expz
4066
+ f = function ([x , y , z ], op , mode )
4067
+ utt .assert_allclose (f (3 , 4 , 5 ), np .exp (3 - 4 - 5 ))
4068
+ graph = f .maker .fgraph .toposort ()
4069
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4070
+ assert any (isinstance (n .op .scalar_op , aes .Sub ) for n in graph )
4071
+ assert not any (isinstance (n .op .scalar_op , aes .TrueDiv ) for n in graph )
4072
+
4073
+ # e^x * y * e^z * w = e^(x+z) * y * w
4074
+ op = expx * y * expz * w
4075
+ f = function ([x , y , z , w ], op , mode )
4076
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), np .exp (3 + 5 ) * 4 * 6 )
4077
+ graph = f .maker .fgraph .toposort ()
4078
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4079
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4080
+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4081
+
4082
+ # expect same for matrices as well
4083
+ mx = matrix ("mx" )
4084
+ my = matrix ("my" )
4085
+ f = function ([mx , my ], exp (mx ) * exp (my ), mode , allow_input_downcast = True )
4086
+ M1 = np .array ([[1.0 , 2.0 ], [3.0 , 4.0 ]])
4087
+ M2 = np .array ([[5.0 , 6.0 ], [7.0 , 8.0 ]])
4088
+ utt .assert_allclose (f (M1 , M2 ), np .exp (M1 + M2 ))
4089
+ graph = f .maker .fgraph .toposort ()
4090
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4091
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4092
+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4093
+
4094
+ # checking whether further rewrites can proceed after this one as one would expect
4095
+ # e^x * e^(-x) = e^(x-x) = e^0 = 1
4096
+ f = function ([x ], expx * exp (neg (x )), mode )
4097
+ utt .assert_allclose (f (42 ), 1 )
4098
+ graph = f .maker .fgraph .toposort ()
4099
+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4100
+
4101
+ # e^x / e^x = e^(x-x) = e^0 = 1
4102
+ f = function ([x ], expx / expx , mode )
4103
+ utt .assert_allclose (f (42 ), 1 )
4104
+ graph = f .maker .fgraph .toposort ()
4105
+ assert isinstance (graph [0 ].inputs [0 ], TensorConstant )
4106
+
4107
+
4108
+ def test_local_mul_pow_to_pow_add ():
4109
+ # Default and FAST_RUN modes put a Composite op into the final graph,
4110
+ # whereas FAST_COMPILE doesn't. To unify the graph the test cases analyze across runs,
4111
+ # we'll avoid the insertion of Composite ops in each mode by skipping Fusion rewrites
4112
+ mode = (
4113
+ get_default_mode ()
4114
+ .excluding ("fusion" )
4115
+ .including ("local_mul_exp_to_exp_add" )
4116
+ .including ("local_mul_pow_to_pow_add" )
4117
+ )
4118
+
4119
+ x = scalar ("x" )
4120
+ y = scalar ("y" )
4121
+ z = scalar ("z" )
4122
+ w = scalar ("w" )
4123
+ v = scalar ("v" )
4124
+ u = scalar ("u" )
4125
+ t = scalar ("t" )
4126
+ s = scalar ("s" )
4127
+ a = scalar ("a" )
4128
+ b = scalar ("b" )
4129
+ c = scalar ("c" )
4130
+
4131
+ # 2^x * 2^y * 2^z * 2^w = 2^(x+y+z+w)
4132
+ op = 2 ** x * 2 ** y * 2 ** z * 2 ** w
4133
+ f = function ([x , y , z , w ], op , mode )
4134
+ utt .assert_allclose (f (3 , 4 , 5 , 6 ), 2 ** (3 + 4 + 5 + 6 ))
4135
+ graph = f .maker .fgraph .toposort ()
4136
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4137
+ assert any (isinstance (n .op .scalar_op , aes .Add ) for n in graph )
4138
+ assert not any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4139
+
4140
+ # 2^x * a^y * 2^z * b^w * c^v * a^u * s * b^t = 2^(x+z) * a^(y+u) * b^(w+t) * c^v * s
4141
+ op = 2 ** x * a ** y * 2 ** z * b ** w * c ** v * a ** u * s * b ** t
4142
+ f = function ([x , y , z , w , v , u , t , s , a , b , c ], op , mode )
4143
+ utt .assert_allclose (
4144
+ f (4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 2.5 , 3 , 3.5 ),
4145
+ 2 ** (4 + 6 ) * 2.5 ** (5 + 9 ) * 3 ** (7 + 10 ) * 3.5 ** 8 * 11 ,
4146
+ )
4147
+ graph = f .maker .fgraph .toposort ()
4148
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4149
+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Add )]) == 3
4150
+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Pow )]) == 4
4151
+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4152
+
4153
+ # (2^x / 2^y) * (a^z / a^w) = 2^(x-y) * a^(z-w)
4154
+ op = 2 ** x / 2 ** y * (a ** z / a ** w )
4155
+ f = function ([x , y , z , w , a ], op , mode )
4156
+ utt .assert_allclose (f (3 , 5 , 6 , 4 , 7 ), 2 ** (3 - 5 ) * 7 ** (6 - 4 ))
4157
+ graph = f .maker .fgraph .toposort ()
4158
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4159
+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Sub )]) == 2
4160
+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4161
+
4162
+ # a^x * a^y * exp(z) * exp(w) = a^(x+y) * exp(z+w)
4163
+ op = a ** x * a ** y * exp (z ) * exp (w )
4164
+ f = function ([x , y , z , w , a ], op , mode )
4165
+ utt .assert_allclose (f (3 , 4 , 5 , 6 , 2 ), 2 ** (3 + 4 ) * np .exp (5 + 6 ))
4166
+ graph = f .maker .fgraph .toposort ()
4167
+ assert all (isinstance (n .op , Elemwise ) for n in graph )
4168
+ assert len ([True for n in graph if isinstance (n .op .scalar_op , aes .Add )]) == 2
4169
+ assert any (isinstance (n .op .scalar_op , aes .Mul ) for n in graph )
4170
+
4171
+
4017
4172
def test_local_expm1 ():
4018
4173
x = matrix ("x" )
4019
4174
u = scalar ("u" )
0 commit comments