@@ -1481,3 +1481,172 @@ def c_code(self, *args, **kwargs):
1481
1481
1482
1482
1483
1483
betainc_der = BetaIncDer (upgrade_to_float_no_complex , name = "betainc_der" )
1484
+
1485
+
1486
+ class Hyp2F1 (ScalarOp ):
1487
+ """
1488
+ Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1489
+
1490
+ """
1491
+
1492
+ nin = 4
1493
+ nfunc_spec = ("scipy.special.hyp2f1" , 4 , 1 )
1494
+
1495
+ @staticmethod
1496
+ def st_impl (a , b , c , z ):
1497
+ return scipy .special .hyp2f1 (a , b , c , z )
1498
+
1499
+ def impl (self , a , b , c , z ):
1500
+ return Hyp2F1 .st_impl (a , b , c , z )
1501
+
1502
+ def grad (self , inputs , grads ):
1503
+ a , b , c , z = inputs
1504
+ (gz ,) = grads
1505
+ return [
1506
+ gz * hyp2f1_der (a , b , c , z , wrt = 0 ),
1507
+ gz * hyp2f1_der (a , b , c , z , wrt = 1 ),
1508
+ gz * hyp2f1_der (a , b , c , z , wrt = 2 ),
1509
+ # NOTE: Stan has a specialized implementation that users Euler's transform
1510
+ # https://github.com/stan-dev/math/blob/95abd90d38259f27c7a6013610fbc7348f2fab4b/stan/math/prim/fun/grad_2F1.hpp#L185-L198
1511
+ gz * ((a * b ) / c ) * hyp2f1 (a + 1 , b + 1 , c + 1 , z ),
1512
+ ]
1513
+
1514
+ def c_code (self , * args , ** kwargs ):
1515
+ raise NotImplementedError ()
1516
+
1517
+
1518
+ hyp2f1 = Hyp2F1 (upgrade_to_float , name = "hyp2f1" )
1519
+
1520
+
1521
+ class Hyp2F1Der (ScalarOp ):
1522
+ """
1523
+ Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
1524
+
1525
+ """
1526
+
1527
+ nin = 5
1528
+
1529
+ def impl (self , a , b , c , z , wrt ):
1530
+ """Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp"""
1531
+
1532
+ def check_2f1_converges (a , b , c , z ) -> bool :
1533
+ num_terms = 0
1534
+ is_polynomial = False
1535
+
1536
+ def is_nonpositive_integer (x ):
1537
+ return x <= 0 and x .is_integer ()
1538
+
1539
+ if is_nonpositive_integer (a ) and abs (a ) >= num_terms :
1540
+ is_polynomial = True
1541
+ num_terms = int (np .floor (abs (a )))
1542
+ if is_nonpositive_integer (b ) and abs (b ) >= num_terms :
1543
+ is_polynomial = True
1544
+ num_terms = int (np .floor (abs (b )))
1545
+
1546
+ is_undefined = is_nonpositive_integer (c ) and abs (c ) <= num_terms
1547
+
1548
+ return not is_undefined and (
1549
+ is_polynomial or np .abs (z ) < 1 or (np .abs (z ) == 1 and c > (a + b ))
1550
+ )
1551
+
1552
+ def compute_grad_2f1 (a , b , c , z , wrt ):
1553
+ """
1554
+ Notes
1555
+ -----
1556
+ The algorithm can be derived by looking at the ratio of two successive terms in the series
1557
+ β_{k+1}/β_{k} = A(k)/B(k)
1558
+ β_{k+1} = A(k)/B(k) * β_{k}
1559
+ d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
1560
+
1561
+ In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z
1562
+
1563
+ The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k),
1564
+ by dropping the respective term
1565
+ d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
1566
+ d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
1567
+ d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k)
1568
+
1569
+ The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
1570
+ tracking the sign of the terms involved.
1571
+ """
1572
+
1573
+ wrt_a = wrt_b = False
1574
+ if wrt == 0 :
1575
+ wrt_a = True
1576
+ elif wrt == 1 :
1577
+ wrt_b = True
1578
+ elif wrt != 2 :
1579
+ raise ValueError (f"wrt must be 0, 1, or 2, got { wrt } " )
1580
+
1581
+ min_steps = 10 # https://github.com/stan-dev/math/issues/2857
1582
+ max_steps = int (1e6 )
1583
+ precision = 1e-14
1584
+
1585
+ res = 0
1586
+
1587
+ if z == 0 :
1588
+ return res
1589
+
1590
+ log_g_old = - np .inf
1591
+ log_t_old = 0.0
1592
+ log_t_new = 0.0
1593
+ sign_z = np .sign (z )
1594
+ log_z = np .log (np .abs (z ))
1595
+
1596
+ log_g_old_sign = 1
1597
+ log_t_old_sign = 1
1598
+ log_t_new_sign = 1
1599
+ sign_zk = sign_z
1600
+
1601
+ for k in range (max_steps ):
1602
+ p = (a + k ) * (b + k ) / ((c + k ) * (k + 1 ))
1603
+ if p == 0 :
1604
+ return res
1605
+ log_t_new += np .log (np .abs (p )) + log_z
1606
+ log_t_new_sign = np .sign (p ) * log_t_new_sign
1607
+
1608
+ term = log_g_old_sign * log_t_old_sign * np .exp (log_g_old - log_t_old )
1609
+ if wrt_a :
1610
+ term += np .reciprocal (a + k )
1611
+ elif wrt_b :
1612
+ term += np .reciprocal (b + k )
1613
+ else :
1614
+ term -= np .reciprocal (c + k )
1615
+
1616
+ log_g_old = log_t_new + np .log (np .abs (term ))
1617
+ log_g_old_sign = np .sign (term ) * log_t_new_sign
1618
+ g_current = log_g_old_sign * np .exp (log_g_old ) * sign_zk
1619
+ res += g_current
1620
+
1621
+ log_t_old = log_t_new
1622
+ log_t_old_sign = log_t_new_sign
1623
+ sign_zk *= sign_z
1624
+
1625
+ if k >= min_steps and np .abs (g_current ) <= precision :
1626
+ return res
1627
+
1628
+ warnings .warn (
1629
+ f"hyp2f1_der did not converge after { k } iterations" ,
1630
+ RuntimeWarning ,
1631
+ )
1632
+ return np .nan
1633
+
1634
+ # TODO: We could implement the Euler transform to expand supported domain, as Stan does
1635
+ if not check_2f1_converges (a , b , c , z ):
1636
+ warnings .warn (
1637
+ f"Hyp2F1 does not meet convergence conditions with given arguments a={ a } , b={ b } , c={ c } , z={ z } " ,
1638
+ RuntimeWarning ,
1639
+ )
1640
+ return np .nan
1641
+
1642
+ return compute_grad_2f1 (a , b , c , z , wrt = wrt )
1643
+
1644
+ def __call__ (self , a , b , c , z , wrt ):
1645
+ # This allows wrt to be a keyword argument
1646
+ return super ().__call__ (a , b , c , z , wrt )
1647
+
1648
+ def c_code (self , * args , ** kwargs ):
1649
+ raise NotImplementedError ()
1650
+
1651
+
1652
+ hyp2f1_der = Hyp2F1Der (upgrade_to_float , name = "hyp2f1_der" )
0 commit comments