Skip to content

Commit cab2655

Browse files
ricardoV94ColtAllen
andcommitted
Implement Hyp2F1 and gradients
Co-authored-by: ColtAllen <[email protected]>
1 parent 211e0cb commit cab2655

File tree

4 files changed

+367
-0
lines changed

4 files changed

+367
-0
lines changed

pytensor/scalar/math.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,3 +1481,170 @@ def c_code(self, *args, **kwargs):
14811481

14821482

14831483
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
1484+
1485+
1486+
class Hyp2F1(ScalarOp):
1487+
"""
1488+
Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1489+
1490+
"""
1491+
1492+
nin = 4
1493+
nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
1494+
1495+
@staticmethod
1496+
def st_impl(a, b, c, z):
1497+
return scipy.special.hyp2f1(a, b, c, z)
1498+
1499+
def impl(self, a, b, c, z):
1500+
return Hyp2F1.st_impl(a, b, c, z)
1501+
1502+
def grad(self, inputs, grads):
1503+
a, b, c, z = inputs
1504+
(gz,) = grads
1505+
return [
1506+
gz * hyp2f1_der(a, b, c, z, wrt=0),
1507+
gz * hyp2f1_der(a, b, c, z, wrt=1),
1508+
gz * hyp2f1_der(a, b, c, z, wrt=2),
1509+
gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
1510+
]
1511+
1512+
def c_code(self, *args, **kwargs):
1513+
raise NotImplementedError()
1514+
1515+
1516+
hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
1517+
1518+
1519+
class Hyp2F1Der(ScalarOp):
1520+
"""
1521+
Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1522+
1523+
"""
1524+
1525+
nin = 5
1526+
1527+
def impl(self, a, b, c, z, wrt):
1528+
"""Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp"""
1529+
1530+
def check_2f1_converges(a, b, c, z) -> bool:
1531+
num_terms = 0
1532+
is_polynomial = False
1533+
1534+
def is_nonpositive_integer(x):
1535+
return x <= 0 and x.is_integer()
1536+
1537+
if is_nonpositive_integer(a) and abs(a) >= num_terms:
1538+
is_polynomial = True
1539+
num_terms = int(np.floor(abs(a)))
1540+
if is_nonpositive_integer(b) and abs(b) >= num_terms:
1541+
is_polynomial = True
1542+
num_terms = int(np.floor(abs(b)))
1543+
1544+
is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
1545+
1546+
return not is_undefined and (
1547+
is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
1548+
)
1549+
1550+
def compute_grad_2f1(a, b, c, z, wrt):
1551+
"""
1552+
Notes
1553+
-----
1554+
The algorithm can be derived by looking at the ratio of two successive terms in the series
1555+
β_{k+1}/β_{k} = A(k)/B(k)
1556+
β_{k+1} = A(k)/B(k) * β_{k}
1557+
d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1558+
1559+
In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1560+
1561+
The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1562+
by dropping the respective term
1563+
d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1564+
d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1565+
d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1566+
1567+
The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1568+
tracking the sign of the terms involved.
1569+
"""
1570+
1571+
wrt_a = wrt_b = False
1572+
if wrt == 0:
1573+
wrt_a = True
1574+
elif wrt == 1:
1575+
wrt_b = True
1576+
elif wrt != 2:
1577+
raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
1578+
1579+
min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1580+
max_steps = int(1e6)
1581+
precision = 1e-14
1582+
1583+
res = 0
1584+
1585+
if z == 0:
1586+
return res
1587+
1588+
log_g_old = -np.inf
1589+
log_t_old = 0.0
1590+
log_t_new = 0.0
1591+
sign_z = np.sign(z)
1592+
log_z = np.log(np.abs(z))
1593+
1594+
log_g_old_sign = 1
1595+
log_t_old_sign = 1
1596+
log_t_new_sign = 1
1597+
sign_zk = sign_z
1598+
1599+
for k in range(max_steps):
1600+
p = (a + k) * (b + k) / ((c + k) * (k + 1))
1601+
if p == 0:
1602+
return res
1603+
log_t_new += np.log(np.abs(p)) + log_z
1604+
log_t_new_sign = np.sign(p) * log_t_new_sign
1605+
1606+
term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
1607+
if wrt_a:
1608+
term += np.reciprocal(a + k)
1609+
elif wrt_b:
1610+
term += np.reciprocal(b + k)
1611+
else:
1612+
term -= np.reciprocal(c + k)
1613+
1614+
log_g_old = log_t_new + np.log(np.abs(term))
1615+
log_g_old_sign = np.sign(term) * log_t_new_sign
1616+
g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
1617+
res += g_current
1618+
1619+
log_t_old = log_t_new
1620+
log_t_old_sign = log_t_new_sign
1621+
sign_zk *= sign_z
1622+
1623+
if k >= min_steps and np.abs(g_current) <= precision:
1624+
return res
1625+
1626+
warnings.warn(
1627+
f"hyp2f1_der did not converge after {k} iterations",
1628+
RuntimeWarning,
1629+
)
1630+
return np.nan
1631+
1632+
# TODO: We could implement the Euler transform to expand supported domain, as Stan does
1633+
if not check_2f1_converges(a, b, c, z):
1634+
warnings.warn(
1635+
f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
1636+
RuntimeWarning,
1637+
)
1638+
return np.nan
1639+
1640+
return compute_grad_2f1(a, b, c, z, wrt=wrt)
1641+
1642+
def __call__(self, a, b, c, z, wrt):
1643+
# This allows wrt to be a keyword argument
1644+
return super().__call__(a, b, c, z, wrt)
1645+
1646+
def c_code(self, *args, **kwargs):
1647+
raise NotImplementedError()
1648+
1649+
1650+
hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def conj_inplace(a):
392392
"""elementwise conjugate (inplace on `a`)"""
393393

394394

395+
@scalar_elemwise
396+
def hyp2f1_inplace(a, b, c, z):
397+
"""gaussian hypergeometric function"""
398+
399+
395400
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
396401
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
397402
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,11 @@ def gammal(k, x):
13841384
"""Lower incomplete gamma function."""
13851385

13861386

1387+
@scalar_elemwise
1388+
def hyp2f1(a, b, c, z):
1389+
"""Gaussian hypergeometric function."""
1390+
1391+
13871392
@scalar_elemwise
13881393
def j0(x):
13891394
"""Bessel function of the first kind of order 0."""
@@ -3132,4 +3137,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31323137
"power",
31333138
"logaddexp",
31343139
"logsumexp",
3140+
"hyp2f1",
31353141
]

0 commit comments

Comments
 (0)