Skip to content

Commit 0b07727

Browse files
committed
Remove unused ScalarOp.st_impl
1 parent 60c2d92 commit 0b07727

File tree

2 files changed

+25
-126
lines changed

2 files changed

+25
-126
lines changed

pytensor/scalar/math.py

Lines changed: 23 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
import numpy as np
1212
import scipy.special
13-
import scipy.stats
1413

1514
from pytensor.configdefaults import config
1615
from pytensor.gradient import grad_not_implemented, grad_undefined
@@ -261,12 +260,8 @@ def c_code(self, node, name, inp, out, sub):
261260
class Owens_t(BinaryScalarOp):
262261
nfunc_spec = ("scipy.special.owens_t", 2, 1)
263262

264-
@staticmethod
265-
def st_impl(h, a):
266-
return scipy.special.owens_t(h, a)
267-
268263
def impl(self, h, a):
269-
return Owens_t.st_impl(h, a)
264+
return scipy.special.owens_t(h, a)
270265

271266
def grad(self, inputs, grads):
272267
(h, a) = inputs
@@ -290,12 +285,8 @@ def c_code(self, *args, **kwargs):
290285
class Gamma(UnaryScalarOp):
291286
nfunc_spec = ("scipy.special.gamma", 1, 1)
292287

293-
@staticmethod
294-
def st_impl(x):
295-
return scipy.special.gamma(x)
296-
297288
def impl(self, x):
298-
return Gamma.st_impl(x)
289+
return scipy.special.gamma(x)
299290

300291
def L_op(self, inputs, outputs, gout):
301292
(x,) = inputs
@@ -329,12 +320,8 @@ class GammaLn(UnaryScalarOp):
329320

330321
nfunc_spec = ("scipy.special.gammaln", 1, 1)
331322

332-
@staticmethod
333-
def st_impl(x):
334-
return scipy.special.gammaln(x)
335-
336323
def impl(self, x):
337-
return GammaLn.st_impl(x)
324+
return scipy.special.gammaln(x)
338325

339326
def L_op(self, inputs, outputs, grads):
340327
(x,) = inputs
@@ -373,12 +360,8 @@ class Psi(UnaryScalarOp):
373360

374361
nfunc_spec = ("scipy.special.psi", 1, 1)
375362

376-
@staticmethod
377-
def st_impl(x):
378-
return scipy.special.psi(x)
379-
380363
def impl(self, x):
381-
return Psi.st_impl(x)
364+
return scipy.special.psi(x)
382365

383366
def L_op(self, inputs, outputs, grads):
384367
(x,) = inputs
@@ -464,12 +447,8 @@ class TriGamma(UnaryScalarOp):
464447
465448
"""
466449

467-
@staticmethod
468-
def st_impl(x):
469-
return scipy.special.polygamma(1, x)
470-
471450
def impl(self, x):
472-
return TriGamma.st_impl(x)
451+
return scipy.special.polygamma(1, x)
473452

474453
def L_op(self, inputs, outputs, outputs_gradients):
475454
(x,) = inputs
@@ -567,12 +546,8 @@ def output_types_preference(n_type, x_type):
567546
# Scipy doesn't support it
568547
return upgrade_to_float_no_complex(x_type)
569548

570-
@staticmethod
571-
def st_impl(n, x):
572-
return scipy.special.polygamma(n, x)
573-
574549
def impl(self, n, x):
575-
return PolyGamma.st_impl(n, x)
550+
return scipy.special.polygamma(n, x)
576551

577552
def L_op(self, inputs, outputs, output_gradients):
578553
(n, x) = inputs
@@ -598,12 +573,8 @@ class GammaInc(BinaryScalarOp):
598573

599574
nfunc_spec = ("scipy.special.gammainc", 2, 1)
600575

601-
@staticmethod
602-
def st_impl(k, x):
603-
return scipy.special.gammainc(k, x)
604-
605576
def impl(self, k, x):
606-
return GammaInc.st_impl(k, x)
577+
return scipy.special.gammainc(k, x)
607578

608579
def grad(self, inputs, grads):
609580
(k, x) = inputs
@@ -649,12 +620,8 @@ class GammaIncC(BinaryScalarOp):
649620

650621
nfunc_spec = ("scipy.special.gammaincc", 2, 1)
651622

652-
@staticmethod
653-
def st_impl(k, x):
654-
return scipy.special.gammaincc(k, x)
655-
656623
def impl(self, k, x):
657-
return GammaIncC.st_impl(k, x)
624+
return scipy.special.gammaincc(k, x)
658625

659626
def grad(self, inputs, grads):
660627
(k, x) = inputs
@@ -700,12 +667,8 @@ class GammaIncInv(BinaryScalarOp):
700667

701668
nfunc_spec = ("scipy.special.gammaincinv", 2, 1)
702669

703-
@staticmethod
704-
def st_impl(k, x):
705-
return scipy.special.gammaincinv(k, x)
706-
707670
def impl(self, k, x):
708-
return GammaIncInv.st_impl(k, x)
671+
return scipy.special.gammaincinv(k, x)
709672

710673
def grad(self, inputs, grads):
711674
(k, x) = inputs
@@ -729,12 +692,8 @@ class GammaIncCInv(BinaryScalarOp):
729692

730693
nfunc_spec = ("scipy.special.gammainccinv", 2, 1)
731694

732-
@staticmethod
733-
def st_impl(k, x):
734-
return scipy.special.gammainccinv(k, x)
735-
736695
def impl(self, k, x):
737-
return GammaIncCInv.st_impl(k, x)
696+
return scipy.special.gammainccinv(k, x)
738697

739698
def grad(self, inputs, grads):
740699
(k, x) = inputs
@@ -968,12 +927,8 @@ class GammaU(BinaryScalarOp):
968927

969928
# Note there is no basic SciPy version so no nfunc_spec.
970929

971-
@staticmethod
972-
def st_impl(k, x):
973-
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
974-
975930
def impl(self, k, x):
976-
return GammaU.st_impl(k, x)
931+
return scipy.special.gammaincc(k, x) * scipy.special.gamma(k)
977932

978933
def c_support_code(self, **kwargs):
979934
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1004,12 +959,8 @@ class GammaL(BinaryScalarOp):
1004959

1005960
# Note there is no basic SciPy version so no nfunc_spec.
1006961

1007-
@staticmethod
1008-
def st_impl(k, x):
1009-
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
1010-
1011962
def impl(self, k, x):
1012-
return GammaL.st_impl(k, x)
963+
return scipy.special.gammainc(k, x) * scipy.special.gamma(k)
1013964

1014965
def c_support_code(self, **kwargs):
1015966
return (C_CODE_PATH / "gamma.c").read_text(encoding="utf-8")
@@ -1040,12 +991,8 @@ class Jv(BinaryScalarOp):
1040991

1041992
nfunc_spec = ("scipy.special.jv", 2, 1)
1042993

1043-
@staticmethod
1044-
def st_impl(v, x):
1045-
return scipy.special.jv(v, x)
1046-
1047994
def impl(self, v, x):
1048-
return self.st_impl(v, x)
995+
return scipy.special.jv(v, x)
1049996

1050997
def grad(self, inputs, grads):
1051998
v, x = inputs
@@ -1069,12 +1016,8 @@ class J1(UnaryScalarOp):
10691016

10701017
nfunc_spec = ("scipy.special.j1", 1, 1)
10711018

1072-
@staticmethod
1073-
def st_impl(x):
1074-
return scipy.special.j1(x)
1075-
10761019
def impl(self, x):
1077-
return self.st_impl(x)
1020+
return scipy.special.j1(x)
10781021

10791022
def grad(self, inputs, grads):
10801023
(x,) = inputs
@@ -1100,12 +1043,8 @@ class J0(UnaryScalarOp):
11001043

11011044
nfunc_spec = ("scipy.special.j0", 1, 1)
11021045

1103-
@staticmethod
1104-
def st_impl(x):
1105-
return scipy.special.j0(x)
1106-
11071046
def impl(self, x):
1108-
return self.st_impl(x)
1047+
return scipy.special.j0(x)
11091048

11101049
def grad(self, inp, grads):
11111050
(x,) = inp
@@ -1131,12 +1070,8 @@ class Iv(BinaryScalarOp):
11311070

11321071
nfunc_spec = ("scipy.special.iv", 2, 1)
11331072

1134-
@staticmethod
1135-
def st_impl(v, x):
1136-
return scipy.special.iv(v, x)
1137-
11381073
def impl(self, v, x):
1139-
return self.st_impl(v, x)
1074+
return scipy.special.iv(v, x)
11401075

11411076
def grad(self, inputs, grads):
11421077
v, x = inputs
@@ -1160,12 +1095,8 @@ class I1(UnaryScalarOp):
11601095

11611096
nfunc_spec = ("scipy.special.i1", 1, 1)
11621097

1163-
@staticmethod
1164-
def st_impl(x):
1165-
return scipy.special.i1(x)
1166-
11671098
def impl(self, x):
1168-
return self.st_impl(x)
1099+
return scipy.special.i1(x)
11691100

11701101
def grad(self, inputs, grads):
11711102
(x,) = inputs
@@ -1186,12 +1117,8 @@ class I0(UnaryScalarOp):
11861117

11871118
nfunc_spec = ("scipy.special.i0", 1, 1)
11881119

1189-
@staticmethod
1190-
def st_impl(x):
1191-
return scipy.special.i0(x)
1192-
11931120
def impl(self, x):
1194-
return self.st_impl(x)
1121+
return scipy.special.i0(x)
11951122

11961123
def grad(self, inp, grads):
11971124
(x,) = inp
@@ -1212,12 +1139,8 @@ class Ive(BinaryScalarOp):
12121139

12131140
nfunc_spec = ("scipy.special.ive", 2, 1)
12141141

1215-
@staticmethod
1216-
def st_impl(v, x):
1217-
return scipy.special.ive(v, x)
1218-
12191142
def impl(self, v, x):
1220-
return self.st_impl(v, x)
1143+
return scipy.special.ive(v, x)
12211144

12221145
def grad(self, inputs, grads):
12231146
v, x = inputs
@@ -1241,12 +1164,8 @@ class Kve(BinaryScalarOp):
12411164

12421165
nfunc_spec = ("scipy.special.kve", 2, 1)
12431166

1244-
@staticmethod
1245-
def st_impl(v, x):
1246-
return scipy.special.kve(v, x)
1247-
12481167
def impl(self, v, x):
1249-
return self.st_impl(v, x)
1168+
return scipy.special.kve(v, x)
12501169

12511170
def L_op(self, inputs, outputs, output_grads):
12521171
v, x = inputs
@@ -1327,8 +1246,7 @@ class Softplus(UnaryScalarOp):
13271246
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
13281247
"""
13291248

1330-
@staticmethod
1331-
def static_impl(x):
1249+
def impl(self, x):
13321250
# If x is an int8 or uint8, numpy.exp will compute the result in
13331251
# half-precision (float16), where we want float32.
13341252
not_int8 = str(getattr(x, "dtype", "")) not in ("int8", "uint8")
@@ -1343,9 +1261,6 @@ def static_impl(x):
13431261
else:
13441262
return x
13451263

1346-
def impl(self, x):
1347-
return Softplus.static_impl(x)
1348-
13491264
def grad(self, inp, grads):
13501265
(x,) = inp
13511266
(gz,) = grads
@@ -1408,16 +1323,12 @@ class Log1mexp(UnaryScalarOp):
14081323
"Accurately computing `\log(1-\exp(- \mid a \mid))` Assessed by the Rmpfr package"
14091324
"""
14101325

1411-
@staticmethod
1412-
def static_impl(x):
1326+
def impl(self, x):
14131327
if x < np.log(0.5):
14141328
return np.log1p(-np.exp(x))
14151329
else:
14161330
return np.log(-np.expm1(x))
14171331

1418-
def impl(self, x):
1419-
return Log1mexp.static_impl(x)
1420-
14211332
def grad(self, inp, grads):
14221333
(x,) = inp
14231334
(gz,) = grads
@@ -1749,12 +1660,8 @@ class Hyp2F1(ScalarOp):
17491660
nin = 4
17501661
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
17511662

1752-
@staticmethod
1753-
def st_impl(a, b, c, z):
1754-
return scipy.special.hyp2f1(a, b, c, z)
1755-
17561663
def impl(self, a, b, c, z):
1757-
return Hyp2F1.st_impl(a, b, c, z)
1664+
return scipy.special.hyp2f1(a, b, c, z)
17581665

17591666
def grad(self, inputs, grads):
17601667
a, b, c, z = inputs

pytensor/tensor/xlogx.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,11 @@ class XlogX(ps.UnaryScalarOp):
1010
1111
"""
1212

13-
@staticmethod
14-
def st_impl(x):
13+
def impl(self, x):
1514
if x == 0.0:
1615
return 0.0
1716
return x * np.log(x)
1817

19-
def impl(self, x):
20-
return XlogX.st_impl(x)
21-
2218
def grad(self, inputs, grads):
2319
(x,) = inputs
2420
(gz,) = grads
@@ -45,15 +41,11 @@ class XlogY0(ps.BinaryScalarOp):
4541
4642
"""
4743

48-
@staticmethod
49-
def st_impl(x, y):
44+
def impl(self, x, y):
5045
if x == 0.0:
5146
return 0.0
5247
return x * np.log(y)
5348

54-
def impl(self, x, y):
55-
return XlogY0.st_impl(x, y)
56-
5749
def grad(self, inputs, grads):
5850
x, y = inputs
5951
(gz,) = grads

0 commit comments

Comments
 (0)