@@ -1481,3 +1481,170 @@ def c_code(self, *args, **kwargs):
1481
1481
1482
1482
1483
1483
betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1484
+
1485
+
1486
+ class Hyp2F1 (ScalarOp ):
1487
+ """
1488
+ Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1489
+
1490
+ """
1491
+
1492
+ nin = 4
1493
+ nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
1494
+
1495
+ @staticmethod
1496
+ def st_impl (a , b , c , z ):
1497
+ return scipy .special .hyp2f1 (a , b , c , z )
1498
+
1499
+ def impl (self , a , b , c , z ):
1500
+ return Hyp2F1 .st_impl (a , b , c , z )
1501
+
1502
+ def grad (self , inputs , grads ):
1503
+ a , b , c , z = inputs
1504
+ (gz ,) = grads
1505
+ return [
1506
+ gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1507
+ gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1508
+ gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1509
+ gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1510
+ ]
1511
+
1512
+ def c_code (self , * args , ** kwargs ):
1513
+ raise NotImplementedError ()
1514
+
1515
+
1516
+ hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1517
+
1518
+
1519
+ class Hyp2F1Der (ScalarOp ):
1520
+ """
1521
+ Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
1522
+
1523
+ """
1524
+
1525
+ nin = 5
1526
+
1527
+ def impl (self , a , b , c , z , wrt ):
1528
+ """Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp"""
1529
+
1530
+ def check_2f1_converges (a , b , c , z ) -> bool :
1531
+ num_terms = 0
1532
+ is_polynomial = False
1533
+
1534
+ def is_nonpositive_integer (x ):
1535
+ return x <= 0 and x .is_integer ()
1536
+
1537
+ if is_nonpositive_integer (a ) and abs (a ) >= num_terms :
1538
+ is_polynomial = True
1539
+ num_terms = int (np .floor (abs (a )))
1540
+ if is_nonpositive_integer (b ) and abs (b ) >= num_terms :
1541
+ is_polynomial = True
1542
+ num_terms = int (np .floor (abs (b )))
1543
+
1544
+ is_undefined = is_nonpositive_integer (c ) and abs (c ) <= num_terms
1545
+
1546
+ return not is_undefined and (
1547
+ is_polynomial or np .abs (z ) < 1 or (np .abs (z ) == 1 and c > (a + b ))
1548
+ )
1549
+
1550
+ def compute_grad_2f1 (a , b , c , z , wrt ):
1551
+ """
1552
+ Notes
1553
+ -----
1554
+ The algorithm can be derived by looking at the ratio of two successive terms in the series
1555
+ β_{k+1}/β_{k} = A(k)/B(k)
1556
+ β_{k+1} = A(k)/B(k) * β_{k}
1557
+ d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1558
+
1559
+ In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1560
+
1561
+ The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1562
+ by dropping the respective term
1563
+ d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1564
+ d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1565
+ d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1566
+
1567
+ The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1568
+ tracking the sign of the terms involved.
1569
+ """
1570
+
1571
+ wrt_a = wrt_b = False
1572
+ if wrt == 0 :
1573
+ wrt_a = True
1574
+ elif wrt == 1 :
1575
+ wrt_b = True
1576
+ elif wrt != 2 :
1577
+ raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1578
+
1579
+ min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1580
+ max_steps = int (1e6 )
1581
+ precision = 1e-14
1582
+
1583
+ res = 0
1584
+
1585
+ if z == 0 :
1586
+ return res
1587
+
1588
+ log_g_old = - np .inf
1589
+ log_t_old = 0.0
1590
+ log_t_new = 0.0
1591
+ sign_z = np .sign (z )
1592
+ log_z = np .log (np .abs (z ))
1593
+
1594
+ log_g_old_sign = 1
1595
+ log_t_old_sign = 1
1596
+ log_t_new_sign = 1
1597
+ sign_zk = sign_z
1598
+
1599
+ for k in range (max_steps ):
1600
+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1601
+ if p == 0 :
1602
+ return res
1603
+ log_t_new += np .log (np .abs (p )) + log_z
1604
+ log_t_new_sign = np .sign (p ) * log_t_new_sign
1605
+
1606
+ term = log_g_old_sign * log_t_old_sign * np .exp (log_g_old - log_t_old )
1607
+ if wrt_a :
1608
+ term += np .reciprocal (a + k )
1609
+ elif wrt_b :
1610
+ term += np .reciprocal (b + k )
1611
+ else :
1612
+ term -= np .reciprocal (c + k )
1613
+
1614
+ log_g_old = log_t_new + np .log (np .abs (term ))
1615
+ log_g_old_sign = np .sign (term ) * log_t_new_sign
1616
+ g_current = log_g_old_sign * np .exp (log_g_old ) * sign_zk
1617
+ res += g_current
1618
+
1619
+ log_t_old = log_t_new
1620
+ log_t_old_sign = log_t_new_sign
1621
+ sign_zk *= sign_z
1622
+
1623
+ if k >= min_steps and np .abs (g_current ) <= precision :
1624
+ return res
1625
+
1626
+ warnings .warn (
1627
+ f"hyp2f1_der did not converge after { k } iterations" ,
1628
+ RuntimeWarning ,
1629
+ )
1630
+ return np .nan
1631
+
1632
+ # TODO: We could implement the Euler transform to expand supported domain, as Stan does
1633
+ if not check_2f1_converges (a , b , c , z ):
1634
+ warnings .warn (
1635
+ f"Hyp2F1 does not meet convergence conditions with given arguments a={ a } , b={ b } , c={ c } , z={ z } " ,
1636
+ RuntimeWarning ,
1637
+ )
1638
+ return np .nan
1639
+
1640
+ return compute_grad_2f1 (a , b , c , z , wrt = wrt )
1641
+
1642
+ def __call__ (self , a , b , c , z , wrt ):
1643
+ # This allows wrt to be a keyword argument
1644
+ return super ().__call__ (a , b , c , z , wrt )
1645
+
1646
+ def c_code (self , * args , ** kwargs ):
1647
+ raise NotImplementedError ()
1648
+
1649
+
1650
+ hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
0 commit comments