Skip to content

Commit a3dd6b2

Browse files
ricardoV94 and ColtAllen committed
Implement Hyp2F1 and gradients
Co-authored-by: ColtAllen <[email protected]>
1 parent 211e0cb commit a3dd6b2

File tree

4 files changed

+369
-0
lines changed

4 files changed

+369
-0
lines changed

pytensor/scalar/math.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,3 +1481,172 @@ def c_code(self, *args, **kwargs):
14811481

14821482

14831483
# Module-level instance of the ``BetaIncDer`` scalar op, created with the name "betainc_der".
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
1484+
1485+
1486+
class Hyp2F1(ScalarOp):
    """
    Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    Scalar op with four inputs ``(a, b, c, z)``. Numerical evaluation is
    delegated to ``scipy.special.hyp2f1``; no C implementation is provided.
    """

    nin = 4
    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

    @staticmethod
    def st_impl(a, b, c, z):
        """Evaluate ``2F1(a, b; c; z)`` through ``scipy.special.hyp2f1``."""
        return scipy.special.hyp2f1(a, b, c, z)

    def impl(self, a, b, c, z):
        """Python implementation of the op; forwards to ``Hyp2F1.st_impl``."""
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        """Return the gradient expressions w.r.t. ``a``, ``b``, ``c`` and ``z``.

        ``inputs`` is ``(a, b, c, z)`` and ``grads`` holds the single output
        gradient. The result is a list of four expressions, in input order.
        """
        a, b, c, z = inputs
        (gz,) = grads

        # Partials w.r.t. a, b and c (wrt = 0, 1, 2) are provided by ``hyp2f1_der``.
        parameter_grads = [
            gz * hyp2f1_der(a, b, c, z, wrt=index) for index in range(3)
        ]

        # d/dz 2F1(a, b; c; z) = (a * b / c) * 2F1(a + 1, b + 1; c + 1; z)
        # NOTE: Stan has a specialized implementation that uses Euler's transform
        # https://github.com/stan-dev/math/blob/95abd90d38259f27c7a6013610fbc7348f2fab4b/stan/math/prim/fun/grad_2F1.hpp#L185-L198
        z_grad = gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)

        return [*parameter_grads, z_grad]

    def c_code(self, *args, **kwargs):
        """No C implementation is available for this op."""
        raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
1519+
1520+
1521+
class Hyp2F1Der(ScalarOp):
    """
    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    The op takes five inputs ``(a, b, c, z, wrt)``. ``wrt`` selects the
    parameter that is differentiated: ``0`` for ``a``, ``1`` for ``b`` and
    ``2`` for ``c``.
    """

    nin = 5

    def impl(self, a, b, c, z, wrt):
        """Evaluate the partial derivative of ``2F1(a, b; c; z)`` selected by ``wrt``.

        Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp

        Parameters
        ----------
        a, b, c, z
            Real scalar arguments of ``2F1``. They are converted to
            ``numpy.float64`` before any computation.
        wrt
            ``0``, ``1`` or ``2``: differentiate with respect to ``a``, ``b``
            or ``c`` respectively.

        Returns
        -------
        The value of the derivative. ``nan`` is returned, together with a
        ``RuntimeWarning``, when the arguments fail the convergence conditions
        or when the series has not converged after the maximum number of steps.

        Raises
        ------
        ValueError
            If the convergence conditions are met and ``wrt`` is not 0, 1 or 2.
        """
        # Work in double precision. Integer-typed inputs would otherwise make
        # the reciprocals below truncate, and plain Python ints do not provide
        # ``is_integer`` on older Python versions.
        a, b, c, z = (np.float64(value) for value in (a, b, c, z))

        def check_2f1_converges(a, b, c, z) -> bool:
            """Return True when the power series of ``2F1(a, b; c; z)`` can be summed."""
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                return x <= 0 and x.is_integer()

            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            # A non-positive integer ``c`` puts a zero in the denominator at
            # k = -c. That is only harmless when the series is a polynomial
            # that terminates before reaching that term.
            is_undefined = is_nonpositive_integer(c) and (
                not is_polynomial or abs(c) <= num_terms
            )

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )

        def compute_grad_2f1(a, b, c, z, wrt):
            """
            Sum the series of derivatives of the 2F1 terms w.r.t. ``a``, ``b`` or ``c``.

            Notes
            -----
            The algorithm can be derived by looking at the ratio of two successive terms in the series
            β_{k+1}/β_{k} = A(k)/B(k)
            β_{k+1} = A(k)/B(k) * β_{k}
            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

            In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z

            The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
            by dividing by the respective term (with a sign flip for c, which sits in the denominator)
            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
            d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
            tracking the sign of the terms involved.
            """
            # Pick the parameter whose (x + k) factor is differentiated and the
            # sign of its contribution (negative for c, see Notes).
            if wrt == 0:
                param, contribution_sign = a, 1.0
            elif wrt == 1:
                param, contribution_sign = b, 1.0
            elif wrt == 2:
                param, contribution_sign = c, -1.0
            else:
                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14

            res = 0.0

            # 2F1(a, b; c; 0) == 1 for every a, b, c, so the derivatives vanish.
            if z == 0:
                return res

            # log|g_k| and log|t_k| hold the magnitudes of the derivative term
            # and of the 2F1 term; the matching *_sign variables hold their
            # signs. The sign of z**k is tracked separately in ``sign_zk``.
            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z

            for k in range(max_steps):
                p = (a + k) * (b + k) / ((c + k) * (k + 1))
                if p == 0:
                    # The 2F1 series terminates here (a or b is a non-positive integer).
                    # NOTE(review): the derivative terms w.r.t. that same
                    # parameter may not vanish at this point — TODO confirm
                    # against reference values.
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                # term = g_k / t_k + d[log p]/d[param]
                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
                term += contribution_sign / (param + k)

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            warnings.warn(
                f"hyp2f1_der did not converge after {max_steps} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        """No C implementation is available for this op."""
        raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def conj_inplace(a):
392392
"""elementwise conjugate (inplace on `a`)"""
393393

394394

395+
@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)`` (inplace on `a`)"""
398+
399+
395400
# Attach operator printers to the inplace add/mul/sub functions so that this
# and the following two statements print them as "+=", "*=" and "-=".
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
396401
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
397402
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,11 @@ def gammal(k, x):
13841384
"""Lower incomplete gamma function."""
13851385

13861386

1387+
@scalar_elemwise
def hyp2f1(a, b, c, z):
    """Gaussian hypergeometric function ``2F1(a, b; c; z)``."""
1390+
1391+
13871392
@scalar_elemwise
13881393
def j0(x):
13891394
"""Bessel function of the first kind of order 0."""
@@ -3132,4 +3137,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31323137
"power",
31333138
"logaddexp",
31343139
"logsumexp",
3140+
"hyp2f1",
31353141
]

0 commit comments

Comments
 (0)