
Commit 12985af

ricardoV94 and ColtAllen committed
Implement Hyp2F1 and gradients

Co-authored-by: ColtAllen <[email protected]>
1 parent 3a7815e commit 12985af

File tree

    pytensor/scalar/math.py
    pytensor/tensor/inplace.py
    pytensor/tensor/math.py
    tests/tensor/test_math_scipy.py

4 files changed (+359, -0 lines)


pytensor/scalar/math.py

Lines changed: 159 additions & 0 deletions
@@ -1481,3 +1481,162 @@ def c_code(self, *args, **kwargs):


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")


class Hyp2F1(ScalarOp):
    """
    Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    """

    nin = 4
    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

    @staticmethod
    def st_impl(a, b, c, z):
        return scipy.special.hyp2f1(a, b, c, z)

    def impl(self, a, b, c, z):
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        a, b, c, z = inputs
        (gz,) = grads
        return [
            gz * hyp2f1_der(a, b, c, z, wrt=0),
            gz * hyp2f1_der(a, b, c, z, wrt=1),
            gz * hyp2f1_der(a, b, c, z, wrt=2),
            # NOTE: Stan has a specialized implementation that uses Euler's transform
            # https://github.com/stan-dev/math/blob/95abd90d38259f27c7a6013610fbc7348f2fab4b/stan/math/prim/fun/grad_2F1.hpp#L185-L198
            gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
        ]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")


class Hyp2F1Der(ScalarOp):
    """
    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    """

    nin = 5

    def impl(self, a, b, c, z, wrt):
        """Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp"""

        def check_2f1_converges(a, b, c, z) -> bool:
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                return x <= 0 and x.is_integer()

            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )

        def compute_grad_2f1(a, b, c, z, wrt):
            # Note: The Stan implementation computes multiple terms at once. For simplicity we compute only one at a time.
            # If we were to implement this operator symbolically, we could probably rely on a Scan rewrite to merge them.
            # See: https://github.com/pymc-devs/pytensor/issues/83

            wrt_a = wrt_b = False
            if wrt == 0:
                wrt_a = True
            elif wrt == 1:
                wrt_b = True
            elif wrt != 2:
                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14

            res = 0

            if z == 0:
                return res

            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z

            for k in range(max_steps):
                p = (a + k) * (b + k) / ((c + k) * (1 + k))
                if p == 0:
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                if wrt_a:
                    term = log_g_old_sign * log_t_old_sign * np.exp(
                        log_g_old - log_t_old
                    ) + np.reciprocal(a + k)
                elif wrt_b:
                    term = log_g_old_sign * log_t_old_sign * np.exp(
                        log_g_old - log_t_old
                    ) + np.reciprocal(b + k)
                else:
                    # wrt_c
                    term = log_g_old_sign * log_t_old_sign * np.exp(
                        log_g_old - log_t_old
                    ) - np.reciprocal(c + k)

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            warnings.warn(
                f"hyp2f1_der did not converge after {k} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand the supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
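For reference, the gradient with respect to z above uses the standard contiguous-parameter identity d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z), while the a, b, and c gradients go through the log-space series in Hyp2F1Der. The following is a small sanity-check sketch, not part of the diff, using only NumPy and SciPy, that cross-checks both against central finite differences of scipy.special.hyp2f1:

# Illustrative sanity check, not part of the commit: compare the dz identity used in
# Hyp2F1.grad and a finite-difference estimate of d/da against scipy.special.hyp2f1.
import numpy as np
from scipy.special import hyp2f1

a, b, c, z = 3.70975, 1.0, 2.70975, -0.2  # first Stan-derived case from the tests below
eps = 1e-6

# d/dz via the identity used in Hyp2F1.grad
ddz_identity = (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)
ddz_fd = (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)
np.testing.assert_allclose(ddz_identity, ddz_fd, rtol=1e-6)  # ~0.8652952472723672

# d/da via central differences, which Hyp2F1Der(wrt=0) computes analytically
dda_fd = (hyp2f1(a + eps, b, c, z) - hyp2f1(a - eps, b, c, z)) / (2 * eps)
print(dda_fd)  # ~ -0.0488658806159776 per the expected values in the test file below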

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
@@ -392,6 +392,11 @@ def conj_inplace(a):
    """elementwise conjugate (inplace on `a`)"""


@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
    """Gaussian hypergeometric function"""


pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
@@ -1384,6 +1384,11 @@ def gammal(k, x):
    """Lower incomplete gamma function."""


@scalar_elemwise
def hyp2f1(a, b, c, z):
    """Gaussian hypergeometric function."""


@scalar_elemwise
def j0(x):
    """Bessel function of the first kind of order 0."""

@@ -3132,4 +3137,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
    "power",
    "logaddexp",
    "logsumexp",
    "hyp2f1",
]
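With the elemwise wrapper exported in __all__, the op is available as at.hyp2f1 and is differentiable like any other scalar elemwise op. A minimal usage sketch, not part of the commit, assuming the same imports the test module relies on (pytensor.tensor as at and pytensor.function):

# Minimal usage sketch, not part of the diff: build a graph with at.hyp2f1 and
# take its gradient with respect to all four inputs.
import pytensor.tensor as at
from pytensor import function

a, b, c, z = at.scalars("a", "b", "c", "z")
out = at.hyp2f1(a, b, c, z)
grads = at.grad(out, [a, b, c, z])

f = function([a, b, c, z], [out, *grads])
print(f(3.70975, 1.0, 2.70975, -0.2))  # value followed by d/da, d/db, d/dc, d/dz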

tests/tensor/test_math_scipy.py

Lines changed: 189 additions & 0 deletions
@@ -1,3 +1,5 @@
from contextlib import ExitStack as does_not_warn

import numpy as np
import pytest


@@ -71,6 +73,7 @@ def scipy_special_gammal(k, x):
expected_iv = scipy.special.iv
expected_erfcx = scipy.special.erfcx
expected_sigmoid = scipy.special.expit
expected_hyp2f1 = scipy.special.hyp2f1

TestErfBroadcast = makeBroadcastTester(
    op=at.erf,

@@ -820,3 +823,189 @@ def test_beta_inc_stan_grad_combined(self):
        np.testing.assert_allclose(
            f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb]
        )


_good_broadcast_quaternary_hyp2f1 = dict(
    normal=(
        random_ranged(0, 20, (2, 3)),
        random_ranged(0, 20, (2, 3)),
        random_ranged(0, 20, (2, 3)),
        random_ranged(-0.9, 0.9, (2, 3)),
    ),
)

TestHyp2F1Broadcast = makeBroadcastTester(
    op=at.hyp2f1,
    expected=expected_hyp2f1,
    good=_good_broadcast_quaternary_hyp2f1,
    grad=_good_broadcast_quaternary_hyp2f1,
)

TestHyp2F1InplaceBroadcast = makeBroadcastTester(
    op=inplace.hyp2f1_inplace,
    expected=expected_hyp2f1,
    good=_good_broadcast_quaternary_hyp2f1,
    inplace=True,
)


def test_hyp2f1_grad_stan_cases():
    """This test reuses the same test cases as in:
    https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp
    https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp

    Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests
    """
    a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z")
    betainc_out = at.hyp2f1(a1, a2, b1, z)
    betainc_grad = at.grad(betainc_out, [a1, a2, b1, z])
    f_grad = function([a1, a2, b1, z], betainc_grad)

    rtol = 1e-9 if config.floatX == "float64" else 1e-3

    for (
        test_a1,
        test_a2,
        test_b1,
        test_z,
        expected_dda1,
        expected_dda2,
        expected_ddb1,
        expected_ddz,
    ) in (
        (
            3.70975,
            1.0,
            2.70975,
            -0.2,
            -0.0488658806159776,
            -0.193844936204681,
            0.0677809985598383,
            0.8652952472723672,
        ),
        (3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313),
        (
            1.0,
            1.0,
            1.0,
            0.6,
            2.290726829685388,
            2.290726829685388,
            -2.290726829685388,
            6.25,
        ),
        (
            1.0,
            31.0,
            41.0,
            1.0,
            6.825270649241036,
            0.4938271604938271,
            -0.382716049382716,
            17.22222222222223,
        ),
        (
            1.0,
            -2.1,
            41.0,
            1.0,
            -0.04921317604093563,
            0.02256814168279349,
            0.00118482743834665,
            -0.04854621426218426,
        ),
        (
            1.0,
            -0.5,
            10.6,
            0.3,
            -0.01443822031245647,
            0.02829710651967078,
            0.00136986255602642,
            -0.04846036062115473,
        ),
        (
            1.0,
            -0.5,
            10.0,
            0.3,
            -0.0153218866216130,
            0.02999436412836072,
            0.0015413242328729,
            -0.05144686244336445,
        ),
        (
            -0.5,
            -4.5,
            11.0,
            0.3,
            -0.1227022810085707,
            -0.01298849638043795,
            -0.0053540982315572,
            0.1959735211840362,
        ),
        (
            -0.5,
            -4.5,
            -3.2,
            0.9,
            0.85880025358111,
            0.4677704416159314,
            -4.19010422485256,
            -2.959196647856408,
        ),
        (
            3.70975,
            1.0,
            2.70975,
            -0.2,
            -0.0488658806159776,
            -0.193844936204681,
            0.0677809985598383,
            0.865295247272367,
        ),
        (
            2.0,
            1.0,
            2.0,
            0.4,
            0.4617734323582945,
            0.851376039609984,
            -0.4617734323582945,
            2.777777777777778,
        ),
        (
            3.70975,
            1.0,
            2.70975,
            0.999696,
            29369830.002773938200417693317785,
            36347869.41885337,
            -30843032.10697079073015067426929807,
            26278034019.28811,
        ),
        # Cases where series does not converge
        (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf),
        (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf),
        # Case where series converges under Euler transform (not implemented!)
        # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889),
        (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889),
    ):

        expectation = (
            pytest.warns(
                RuntimeWarning, match="Hyp2F1 does not meet convergence conditions"
            )
            if np.any(
                np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz])
            )
            else does_not_warn()
        )
        with expectation:
            result = np.array(f_grad(test_a1, test_a2, test_b1, test_z))

            np.testing.assert_allclose(
                result,
                np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]),
                rtol=rtol,
            )
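The np.nan cases at the end of the parameter list are the ones that fail check_2f1_converges: the direct series converges for |z| < 1, and at |z| == 1 it additionally requires c > a + b (with a separate carve-out when the series terminates as a polynomial). A small illustration, not part of the commit and ignoring the polynomial special case for brevity, of why (1.0, 12.0, 10.0, 1.0) warns and returns nan:

# Illustrative check, not part of the diff: the convergence rule applied by Hyp2F1Der
# (the terminating-polynomial special case is omitted for brevity).
def series_converges(a, b, c, z):
    return abs(z) < 1 or (abs(z) == 1 and c > a + b)

print(series_converges(1.0, 12.0, 10.0, 1.0))  # False: c = 10 is not > a + b = 13
print(series_converges(1.0, 31.0, 41.0, 1.0))  # True: c = 41 > a + b = 32, gradients are finite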

0 commit comments