10
10
11
11
import numpy as np
12
12
import scipy .special
13
- import scipy .stats
14
13
15
14
from pytensor .configdefaults import config
16
15
from pytensor .gradient import grad_not_implemented , grad_undefined
@@ -261,12 +260,8 @@ def c_code(self, node, name, inp, out, sub):
261
260
class Owens_t (BinaryScalarOp ):
262
261
nfunc_spec = ("scipy.special.owens_t" , 2 , 1 )
263
262
264
- @staticmethod
265
- def st_impl (h , a ):
266
- return scipy .special .owens_t (h , a )
267
-
268
263
def impl (self , h , a ):
269
- return Owens_t . st_impl (h , a )
264
+ return scipy . special . owens_t (h , a )
270
265
271
266
def grad (self , inputs , grads ):
272
267
(h , a ) = inputs
@@ -290,12 +285,8 @@ def c_code(self, *args, **kwargs):
290
285
class Gamma (UnaryScalarOp ):
291
286
nfunc_spec = ("scipy.special.gamma" , 1 , 1 )
292
287
293
- @staticmethod
294
- def st_impl (x ):
295
- return scipy .special .gamma (x )
296
-
297
288
def impl (self , x ):
298
- return Gamma . st_impl (x )
289
+ return scipy . special . gamma (x )
299
290
300
291
def L_op (self , inputs , outputs , gout ):
301
292
(x ,) = inputs
@@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp):
329
320
330
321
nfunc_spec = ("scipy.special.gammaln" , 1 , 1 )
331
322
332
- @staticmethod
333
- def st_impl (x ):
334
- return scipy .special .gammaln (x )
335
-
336
323
def impl (self , x ):
337
- return GammaLn . st_impl (x )
324
+ return scipy . special . gammaln (x )
338
325
339
326
def L_op (self , inputs , outputs , grads ):
340
327
(x ,) = inputs
@@ -373,12 +360,8 @@ class Psi(UnaryScalarOp):
373
360
374
361
nfunc_spec = ("scipy.special.psi" , 1 , 1 )
375
362
376
- @staticmethod
377
- def st_impl (x ):
378
- return scipy .special .psi (x )
379
-
380
363
def impl (self , x ):
381
- return Psi . st_impl (x )
364
+ return scipy . special . psi (x )
382
365
383
366
def L_op (self , inputs , outputs , grads ):
384
367
(x ,) = inputs
@@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp):
464
447
465
448
"""
466
449
467
- @staticmethod
468
- def st_impl (x ):
469
- return scipy .special .polygamma (1 , x )
470
-
471
450
def impl (self , x ):
472
- return TriGamma . st_impl ( x )
451
+ return scipy . special . polygamma ( 1 , x )
473
452
474
453
def L_op (self , inputs , outputs , outputs_gradients ):
475
454
(x ,) = inputs
@@ -567,12 +546,8 @@ def output_types_preference(n_type, x_type):
567
546
# Scipy doesn't support it
568
547
return upgrade_to_float_no_complex (x_type )
569
548
570
- @staticmethod
571
- def st_impl (n , x ):
572
- return scipy .special .polygamma (n , x )
573
-
574
549
def impl (self , n , x ):
575
- return PolyGamma . st_impl (n , x )
550
+ return scipy . special . polygamma (n , x )
576
551
577
552
def L_op (self , inputs , outputs , output_gradients ):
578
553
(n , x ) = inputs
@@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp):
598
573
599
574
nfunc_spec = ("scipy.special.gammainc" , 2 , 1 )
600
575
601
- @staticmethod
602
- def st_impl (k , x ):
603
- return scipy .special .gammainc (k , x )
604
-
605
576
def impl (self , k , x ):
606
- return GammaInc . st_impl (k , x )
577
+ return scipy . special . gammainc (k , x )
607
578
608
579
def grad (self , inputs , grads ):
609
580
(k , x ) = inputs
@@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp):
649
620
650
621
nfunc_spec = ("scipy.special.gammaincc" , 2 , 1 )
651
622
652
- @staticmethod
653
- def st_impl (k , x ):
654
- return scipy .special .gammaincc (k , x )
655
-
656
623
def impl (self , k , x ):
657
- return GammaIncC . st_impl (k , x )
624
+ return scipy . special . gammaincc (k , x )
658
625
659
626
def grad (self , inputs , grads ):
660
627
(k , x ) = inputs
@@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp):
700
667
701
668
nfunc_spec = ("scipy.special.gammaincinv" , 2 , 1 )
702
669
703
- @staticmethod
704
- def st_impl (k , x ):
705
- return scipy .special .gammaincinv (k , x )
706
-
707
670
def impl (self , k , x ):
708
- return GammaIncInv . st_impl (k , x )
671
+ return scipy . special . gammaincinv (k , x )
709
672
710
673
def grad (self , inputs , grads ):
711
674
(k , x ) = inputs
@@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp):
729
692
730
693
nfunc_spec = ("scipy.special.gammainccinv" , 2 , 1 )
731
694
732
- @staticmethod
733
- def st_impl (k , x ):
734
- return scipy .special .gammainccinv (k , x )
735
-
736
695
def impl (self , k , x ):
737
- return GammaIncCInv . st_impl (k , x )
696
+ return scipy . special . gammainccinv (k , x )
738
697
739
698
def grad (self , inputs , grads ):
740
699
(k , x ) = inputs
@@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp):
968
927
969
928
# Note there is no basic SciPy version so no nfunc_spec.
970
929
971
- @staticmethod
972
- def st_impl (k , x ):
973
- return scipy .special .gammaincc (k , x ) * scipy .special .gamma (k )
974
-
975
930
def impl (self , k , x ):
976
- return GammaU . st_impl (k , x )
931
+ return scipy . special . gammaincc (k , x ) * scipy . special . gamma ( k )
977
932
978
933
def c_support_code (self , ** kwargs ):
979
934
return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp):
1004
959
1005
960
# Note there is no basic SciPy version so no nfunc_spec.
1006
961
1007
- @staticmethod
1008
- def st_impl (k , x ):
1009
- return scipy .special .gammainc (k , x ) * scipy .special .gamma (k )
1010
-
1011
962
def impl (self , k , x ):
1012
- return GammaL . st_impl (k , x )
963
+ return scipy . special . gammainc (k , x ) * scipy . special . gamma ( k )
1013
964
1014
965
def c_support_code (self , ** kwargs ):
1015
966
return (C_CODE_PATH / "gamma.c" ).read_text (encoding = "utf-8" )
@@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp):
1040
991
1041
992
nfunc_spec = ("scipy.special.jv" , 2 , 1 )
1042
993
1043
- @staticmethod
1044
- def st_impl (v , x ):
1045
- return scipy .special .jv (v , x )
1046
-
1047
994
def impl (self , v , x ):
1048
- return self . st_impl (v , x )
995
+ return scipy . special . jv (v , x )
1049
996
1050
997
def grad (self , inputs , grads ):
1051
998
v , x = inputs
@@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp):
1069
1016
1070
1017
nfunc_spec = ("scipy.special.j1" , 1 , 1 )
1071
1018
1072
- @staticmethod
1073
- def st_impl (x ):
1074
- return scipy .special .j1 (x )
1075
-
1076
1019
def impl (self , x ):
1077
- return self . st_impl (x )
1020
+ return scipy . special . j1 (x )
1078
1021
1079
1022
def grad (self , inputs , grads ):
1080
1023
(x ,) = inputs
@@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp):
1100
1043
1101
1044
nfunc_spec = ("scipy.special.j0" , 1 , 1 )
1102
1045
1103
- @staticmethod
1104
- def st_impl (x ):
1105
- return scipy .special .j0 (x )
1106
-
1107
1046
def impl (self , x ):
1108
- return self . st_impl (x )
1047
+ return scipy . special . j0 (x )
1109
1048
1110
1049
def grad (self , inp , grads ):
1111
1050
(x ,) = inp
@@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp):
1131
1070
1132
1071
nfunc_spec = ("scipy.special.iv" , 2 , 1 )
1133
1072
1134
- @staticmethod
1135
- def st_impl (v , x ):
1136
- return scipy .special .iv (v , x )
1137
-
1138
1073
def impl (self , v , x ):
1139
- return self . st_impl (v , x )
1074
+ return scipy . special . iv (v , x )
1140
1075
1141
1076
def grad (self , inputs , grads ):
1142
1077
v , x = inputs
@@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp):
1160
1095
1161
1096
nfunc_spec = ("scipy.special.i1" , 1 , 1 )
1162
1097
1163
- @staticmethod
1164
- def st_impl (x ):
1165
- return scipy .special .i1 (x )
1166
-
1167
1098
def impl (self , x ):
1168
- return self . st_impl (x )
1099
+ return scipy . special . i1 (x )
1169
1100
1170
1101
def grad (self , inputs , grads ):
1171
1102
(x ,) = inputs
@@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp):
1186
1117
1187
1118
nfunc_spec = ("scipy.special.i0" , 1 , 1 )
1188
1119
1189
- @staticmethod
1190
- def st_impl (x ):
1191
- return scipy .special .i0 (x )
1192
-
1193
1120
def impl (self , x ):
1194
- return self . st_impl (x )
1121
+ return scipy . special . i0 (x )
1195
1122
1196
1123
def grad (self , inp , grads ):
1197
1124
(x ,) = inp
@@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp):
1212
1139
1213
1140
nfunc_spec = ("scipy.special.ive" , 2 , 1 )
1214
1141
1215
- @staticmethod
1216
- def st_impl (v , x ):
1217
- return scipy .special .ive (v , x )
1218
-
1219
1142
def impl (self , v , x ):
1220
- return self . st_impl (v , x )
1143
+ return scipy . special . ive (v , x )
1221
1144
1222
1145
def grad (self , inputs , grads ):
1223
1146
v , x = inputs
@@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp):
1241
1164
1242
1165
nfunc_spec = ("scipy.special.kve" , 2 , 1 )
1243
1166
1244
- @staticmethod
1245
- def st_impl (v , x ):
1246
- return scipy .special .kve (v , x )
1247
-
1248
1167
def impl (self , v , x ):
1249
- return self . st_impl (v , x )
1168
+ return scipy . special . kve (v , x )
1250
1169
1251
1170
def L_op (self , inputs , outputs , output_grads ):
1252
1171
v , x = inputs
@@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp):
1327
1246
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
1328
1247
"""
1329
1248
1330
- @staticmethod
1331
- def static_impl (x ):
1249
+ def impl (self , x ):
1332
1250
# If x is an int8 or uint8, numpy.exp will compute the result in
1333
1251
# half-precision (float16), where we want float32.
1334
1252
not_int8 = str (getattr (x , "dtype" , "" )) not in ("int8" , "uint8" )
@@ -1343,9 +1261,6 @@ def static_impl(x):
1343
1261
else :
1344
1262
return x
1345
1263
1346
- def impl (self , x ):
1347
- return Softplus .static_impl (x )
1348
-
1349
1264
def grad (self , inp , grads ):
1350
1265
(x ,) = inp
1351
1266
(gz ,) = grads
@@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp):
1408
1323
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
1409
1324
"""
1410
1325
1411
- @staticmethod
1412
- def static_impl (x ):
1326
+ def impl (self , x ):
1413
1327
if x < np .log (0.5 ):
1414
1328
return np .log1p (- np .exp (x ))
1415
1329
else :
1416
1330
return np .log (- np .expm1 (x ))
1417
1331
1418
- def impl (self , x ):
1419
- return Log1mexp .static_impl (x )
1420
-
1421
1332
def grad (self , inp , grads ):
1422
1333
(x ,) = inp
1423
1334
(gz ,) = grads
@@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp):
1749
1660
nin = 4
1750
1661
nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
1751
1662
1752
- @staticmethod
1753
- def st_impl (a , b , c , z ):
1754
- return scipy .special .hyp2f1 (a , b , c , z )
1755
-
1756
1663
def impl (self , a , b , c , z ):
1757
- return Hyp2F1 . st_impl (a , b , c , z )
1664
+ return scipy . special . hyp2f1 (a , b , c , z )
1758
1665
1759
1666
def grad (self , inputs , grads ):
1760
1667
a , b , c , z = inputs
0 commit comments