Skip to content

Commit 3ad0c72

Browse files
ColtAllentwiecki
authored andcommitted
Staging commit for hyp2f1 gradient tests
1 parent 3976cdf commit 3ad0c72

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,6 +1389,11 @@ def hyp2f1(a, b, c, z):
13891389
"""Gaussian hypergeometric function."""
13901390

13911391

1392+
@scalar_elemwise
1393+
def hyp2f1_der(a, b, c, z):
1394+
"""Derivatives for Gaussian hypergeometric function."""
1395+
1396+
13921397
@scalar_elemwise
13931398
def j0(x):
13941399
"""Bessel function of the first kind of order 0."""
@@ -3138,6 +3143,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31383143
"logaddexp",
31393144
"logsumexp",
31403145
"hyp2f1",
3146+
"hyp2f1_der",
31413147
]
31423148

31433149
DEPRECATED_NAMES = [

tests/tensor/test_math_scipy.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -763,12 +763,22 @@ def test_deprecated_module():
763763
),
764764
)
765765

766+
_good_broadcast_pentanary_hyp2f1_der = dict(
767+
normal=(
768+
random_ranged(0, 1000, (2, 3)),
769+
random_ranged(0, 1000, (2, 3)),
770+
random_ranged(0, 1000, (2, 3)),
771+
random_ranged(0, 0.5, (2, 3)),
772+
integers_ranged(-1, 3, (2, 3)),
773+
),
774+
)
775+
766776
TestHyp2F1Broadcast = makeBroadcastTester(
767777
op=at.hyp2f1,
768778
expected=expected_hyp2f1,
769779
good=_good_broadcast_quaternary_hyp2f1,
780+
grad=_good_broadcast_quaternary_hyp2f1,
770781
eps=2e-10,
771-
mode=mode_no_scipy,
772782
)
773783

774784
TestHyp2F1InplaceBroadcast = makeBroadcastTester(
@@ -802,10 +812,15 @@ def test_deprecated_module():
802812
expected=expected_hyp2f1,
803813
good=_good_broadcast_quaternary_hyp2f1,
804814
eps=2e-10,
805-
mode=mode_no_scipy,
806815
inplace=True,
807816
)
808817

818+
TestHyp2F1DerBroadcast = makeBroadcastTester(
819+
op=at.hyp2f1_der,
820+
expected=expected_hyp2f1,
821+
good=_good_broadcast_pentanary_hyp2f1_der,
822+
eps=2e-10,
823+
)
809824

810825
class TestBetaIncGrad:
811826
def test_stan_grad_partial(self):

0 commit comments

Comments
 (0)