Skip to content

Commit 99a3945

Browse files
committed
Continued downstream_1288 in new fork. Added Black formatting.
1 parent f34a885 commit 99a3945

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

pytensor/scalar/math.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1632,4 +1632,4 @@ def c_code(self, *args, **kwargs):
16321632
raise NotImplementedError()
16331633

16341634

1635-
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
1635+
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

tests/tensor/test_math_scipy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ def test_deprecated_module():
822822
eps=2e-10,
823823
)
824824

825+
825826
class TestBetaIncGrad:
826827
def test_stan_grad_partial(self):
827828
# This test combines the following STAN tests:

tests/tensor/test_special.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
from pytensor.compile.function import function
99
from pytensor.configdefaults import config
1010
from pytensor.tensor import scalar, scalars
11-
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax, poch, factorial
11+
from pytensor.tensor.special import (
12+
LogSoftmax,
13+
Softmax,
14+
SoftmaxGrad,
15+
log_softmax,
16+
softmax,
17+
poch,
18+
factorial,
19+
)
1220
from pytensor.tensor.type import matrix, tensor3, tensor4, vector
1321
from tests.tensor.utils import random_ranged
1422
from tests import unittest_tools as utt
@@ -140,9 +148,7 @@ def test_valid_axis(self):
140148
SoftmaxGrad(-4)(*x)
141149

142150

143-
@pytest.mark.parametrize(
144-
"z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))]
145-
)
151+
@pytest.mark.parametrize("z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))])
146152
def test_poch(z, m):
147153

148154
_z, _m = scalars("z", "m")
@@ -162,8 +168,7 @@ def test_factorial(n):
162168

163169
actual_fn = function([_n], factorial(_n))
164170
actual = actual_fn(n)
165-
171+
166172
expected = scipy_factorial(n)
167173

168174
assert np.allclose(actual, expected)
169-

0 commit comments

Comments
 (0)