Skip to content

Commit 14273e9

Browse files
committed
Rewrote hyp2f1 derivatives with scipy functions
1 parent 753627a commit 14273e9

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

pytensor/scalar/math.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1556,15 +1556,15 @@ def _hyp2f1_da(a, b, c, z):
15561556
else:
15571557
term1 = _infinisum(
15581558
lambda k: (
1559-
(gamma(a + k) / gamma(a))
1560-
* (gamma(b + k) / gamma(b))
1561-
* psi(a + k)
1559+
(scipy.special.gamma(a + k) / scipy.special.gamma(a))
1560+
* (scipy.special.gamma(b + k) / scipy.special.gamma(b))
1561+
* scipy.special.psi(a + k)
15621562
* (z**k)
15631563
)
1564-
/ (gamma(c + k) / gamma(c))
1565-
* gamma(k + 1)
1564+
/ (scipy.special.gamma(c + k) / scipy.special.gamma(c))
1565+
* scipy.special.gamma(k + 1)
15661566
)
1567-
term2 = psi(a) * hyp2f1(a, b, c, z)
1567+
term2 = scipy.special.psi(a) * scipy.special.hyp2f1(a, b, c, z)
15681568

15691569
return term1 - term2
15701570

@@ -1579,15 +1579,15 @@ def _hyp2f1_db(a, b, c, z):
15791579
else:
15801580
term1 = _infinisum(
15811581
lambda k: (
1582-
(gamma(a + k) / gamma(a))
1583-
* (gamma(b + k) / gamma(b))
1584-
* psi(b + k)
1582+
(scipy.special.gamma(a + k) / scipy.special.gamma(a))
1583+
* (scipy.special.gamma(b + k) / scipy.special.gamma(b))
1584+
* scipy.special.psi(b + k)
15851585
* (z**k)
15861586
)
1587-
/ (gamma(c + k) / gamma(c))
1588-
* gamma(k + 1)
1587+
/ (scipy.special.gamma(c + k) / scipy.special.gamma(c))
1588+
* scipy.special.gamma(k + 1)
15891589
)
1590-
term2 = psi(b) * hyp2f1(a, b, c, z)
1590+
term2 = scipy.special.psi(b) * scipy.special.hyp2f1(a, b, c, z)
15911591

15921592
return term1 - term2
15931593

@@ -1599,16 +1599,16 @@ def _hyp2f1_dc(a, b, c, z):
15991599
raise NotImplementedError("Gradient not supported for |z| >= 1")
16001600

16011601
else:
1602-
term1 = psi(c) * hyp2f1(a, b, c, z)
1602+
term1 = scipy.special.psi(c) * scipy.special.hyp2f1(a, b, c, z)
16031603
term2 = _infinisum(
16041604
lambda k: (
1605-
(gamma(a + k) / gamma(a))
1606-
* (gamma(b + k) / gamma(b))
1607-
* psi(c + k)
1605+
(scipy.special.gamma(a + k) / scipy.special.gamma(a))
1606+
* (scipy.special.gamma(b + k) / scipy.special.gamma(b))
1607+
* scipy.special.psi(c + k)
16081608
* (z**k)
16091609
)
1610-
/ (gamma(c + k) / gamma(c))
1611-
* gamma(k + 1)
1610+
/ (scipy.special.gamma(c + k) / scipy.special.gamma(c))
1611+
* scipy.special.gamma(k + 1)
16121612
)
16131613
return term1 - term2
16141614

@@ -1617,7 +1617,7 @@ def _hyp2f1_dz(a, b, c, z):
16171617
Derivative of hyp2f1 wrt z
16181618
"""
16191619

1620-
return ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)
1620+
return ((a * b) / c) * scipy.special.hyp2f1(a + 1, b + 1, c + 1, z)
16211621

16221622
if wrt == 0:
16231623
return _hyp2f1_da(a, b, c, z)

0 commit comments

Comments
 (0)