Skip to content

Commit 10d4d94

Browse files
ricardoV94 and ColtAllen committed
Implement Hyp2F1 and gradients
Co-authored-by: ColtAllen <[email protected]>
1 parent 4278ce4 commit 10d4d94

File tree

4 files changed

+366
-0
lines changed

4 files changed

+366
-0
lines changed

pytensor/scalar/math.py

Lines changed: 166 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1481,3 +1481,169 @@ def c_code(self, *args, **kwargs):
14811481

14821482

14831483
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
1484+
1485+
1486+
class Hyp2F1(ScalarOp):
    """
    Scalar Op for the Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    Numerical evaluation is delegated to ``scipy.special.hyp2f1``.
    """

    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
    nin = 4

    @staticmethod
    def st_impl(a, b, c, z):
        """Evaluate ``2F1(a, b; c; z)`` with SciPy."""
        value = scipy.special.hyp2f1(a, b, c, z)
        return value

    def impl(self, a, b, c, z):
        """Evaluate the Op numerically by forwarding to ``Hyp2F1.st_impl``."""
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        """Return the gradient expressions for ``a``, ``b``, ``c`` and ``z``, in that order."""
        a, b, c, z = inputs
        (output_grad,) = grads

        # Derivatives with respect to the three parameters are delegated to
        # ``hyp2f1_der``, one call per ``wrt`` index (0, 1 and 2).
        param_grads = [
            output_grad * hyp2f1_der(a, b, c, z, wrt=position)
            for position in range(3)
        ]

        # d/dz 2F1(a, b; c; z) = (a * b / c) * 2F1(a + 1, b + 1; c + 1; z)
        z_grad = output_grad * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z)

        return [*param_grads, z_grad]

    def c_code(self, *args, **kwargs):
        """No C implementation is provided for this Op."""
        raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
1517+
1518+
1519+
class Hyp2F1Der(ScalarOp):
    """
    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.

    The last input, ``wrt``, selects the input to differentiate against:
    ``0`` for ``a``, ``1`` for ``b`` and ``2`` for ``c``.

    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
    """

    nin = 5

    def impl(self, a, b, c, z, wrt):
        """
        Evaluate the requested derivative numerically for scalar inputs.

        Returns ``np.nan`` (after emitting a ``RuntimeWarning``) when the
        arguments fail the convergence check, or when the series does not
        converge within the maximum number of steps.

        Raises
        ------
        ValueError
            If ``wrt`` is not 0, 1 or 2.
        """

        def check_2f1_converges(a, b, c, z) -> bool:
            """Return whether the power series of ``2F1`` is defined and converges."""
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                # ``float(x)`` keeps this working for integer inputs, which do
                # not provide ``is_integer`` on every supported Python / NumPy.
                return x <= 0 and float(x).is_integer()

            # A nonpositive integer ``a`` or ``b`` truncates the series into a
            # polynomial; ``num_terms`` tracks the largest such magnitude.
            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            # With a nonpositive integer ``c`` the factor ``c + k`` vanishes at
            # ``k = |c|``. An infinite series always reaches that term, whereas
            # a polynomial only does so when it has at least ``|c|`` terms.
            is_undefined = is_nonpositive_integer(c) and (
                not is_polynomial or abs(c) <= num_terms
            )

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )

        def compute_grad_2f1(a, b, c, z, wrt):
            """
            Accumulate the term-by-term derivative of the ``2F1`` power series.

            Notes
            -----
            The algorithm can be derived by looking at the ratio of two successive terms in the series
            β_{k+1}/β_{k} = A(k)/B(k)
            β_{k+1} = A(k)/B(k) * β_{k}
            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

            In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z

            The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
            by dividing by the respective term (with a negative sign for c, which sits in the denominator)
            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
            d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
            tracking their signs.
            """

            wrt_a = wrt_b = False
            if wrt == 0:
                wrt_a = True
            elif wrt == 1:
                wrt_b = True
            elif wrt != 2:
                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14

            res = 0

            if z == 0:
                return res

            # ``log_t_*`` hold the log-magnitude of the series terms and
            # ``log_g_old`` that of their derivative; the ``*_sign`` variables
            # carry the matching signs. Powers of ``z`` contribute ``log|z|``
            # to the magnitudes and their sign is tracked in ``sign_zk``.
            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z

            for k in range(max_steps):
                p = (a + k) * (b + k) / ((c + k) * (k + 1))
                # NOTE(review): stopping here drops the derivative of the
                # remaining terms with respect to the parameter that made
                # ``p`` vanish -- confirm this truncation is intended.
                if p == 0:
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                # Ratio (previous derivative) / (previous term) ...
                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
                # ... plus the contribution of the factor being differentiated.
                # Plain float division is used because ``np.reciprocal``
                # performs integer arithmetic on integer arguments.
                if wrt_a:
                    term += 1.0 / (a + k)
                elif wrt_b:
                    term += 1.0 / (b + k)
                else:
                    term -= 1.0 / (c + k)

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            # Report the number of iterations actually performed, not the last index
            warnings.warn(
                f"hyp2f1_der did not converge after {max_steps} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        """No C implementation is provided for this Op."""
        raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,11 @@ def conj_inplace(a):
392392
"""elementwise conjugate (inplace on `a`)"""
393393

394394

395+
@scalar_elemwise
396+
def hyp2f1_inplace(a, b, c, z):
397+
"""Gaussian hypergeometric function ``2F1(a, b; c; z)`` (inplace variant)."""
398+
399+
395400
pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
396401
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
397402
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1384,6 +1384,11 @@ def gammal(k, x):
13841384
"""Lower incomplete gamma function."""
13851385

13861386

1387+
@scalar_elemwise
1388+
def hyp2f1(a, b, c, z):
1389+
"""Gaussian hypergeometric function ``2F1(a, b; c; z)``."""
1390+
1391+
13871392
@scalar_elemwise
13881393
def j0(x):
13891394
"""Bessel function of the first kind of order 0."""
@@ -3132,4 +3137,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
31323137
"power",
31333138
"logaddexp",
31343139
"logsumexp",
3140+
"hyp2f1",
31353141
]

0 commit comments

Comments
 (0)