@@ -1481,3 +1481,169 @@ def c_code(self, *args, **kwargs):
betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
+
+
+class Hyp2F1(ScalarOp):
+    """
+    Gaussian hypergeometric function ``2F1(a, b; c; z)``.
+
+    """
+
+    nin = 4
+    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
+
+    @staticmethod
+    def st_impl(a, b, c, z):
+        return scipy.special.hyp2f1(a, b, c, z)
+
+    def impl(self, a, b, c, z):
+        return Hyp2F1.st_impl(a, b, c, z)
+
+    def grad(self, inputs, grads):
+        a, b, c, z = inputs
+        (gz,) = grads
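+        # Gradients w.r.t. a, b, and c go through the series-based helper Op below;
+        # for z the standard identity d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z) is used.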
+        return [
+            gz * hyp2f1_der(a, b, c, z, wrt=0),
+            gz * hyp2f1_der(a, b, c, z, wrt=1),
+            gz * hyp2f1_der(a, b, c, z, wrt=2),
+            gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
+
+
+class Hyp2F1Der(ScalarOp):
+    """
+    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.
+
+    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
+    """
+
+    nin = 5
+
+    def impl(self, a, b, c, z, wrt):
+        def check_2f1_converges(a, b, c, z) -> bool:
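+            # The series terminates (is a polynomial) when a or b is a non-positive
+            # integer, is undefined when c is a non-positive integer reached first,
+            # and otherwise converges for |z| < 1, or for |z| == 1 when c > a + b.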
+            num_terms = 0
+            is_polynomial = False
+
+            def is_nonpositive_integer(x):
+                return x <= 0 and x.is_integer()
+
+            if is_nonpositive_integer(a) and abs(a) >= num_terms:
+                is_polynomial = True
+                num_terms = int(np.floor(abs(a)))
+            if is_nonpositive_integer(b) and abs(b) >= num_terms:
+                is_polynomial = True
+                num_terms = int(np.floor(abs(b)))
+
+            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms
+
+            return not is_undefined and (
+                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
+            )
+
+        def compute_grad_2f1(a, b, c, z, wrt):
+            """
+            Notes
+            -----
+            The algorithm can be derived by looking at the ratio of two successive terms in the series
+            β_{k+1}/β_{k} = A(k)/B(k)
+            β_{k+1} = A(k)/B(k) * β_{k}
+            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule
+
+            In the 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k) / ((c + k) * (1 + k))) * z
+
+            The partials d[A(k)/B(k)] with respect to the first 3 inputs can be obtained from the ratio A(k)/B(k)
+            by dropping the respective term
+            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
+            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
+            d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)
+
+            The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and
+            tracking their signs.
+            """
+
+            wrt_a = wrt_b = False
+            if wrt == 0:
+                wrt_a = True
+            elif wrt == 1:
+                wrt_b = True
+            elif wrt != 2:
+                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
+
+            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
+            max_steps = int(1e6)
+            precision = 1e-14
+
+            res = 0
+
+            if z == 0:
+                return res
+
+            log_g_old = -np.inf
+            log_t_old = 0.0
+            log_t_new = 0.0
+            sign_z = np.sign(z)
+            log_z = np.log(np.abs(z))
+
+            log_g_old_sign = 1
+            log_t_old_sign = 1
+            log_t_new_sign = 1
+            sign_zk = sign_z
+
+            for k in range(max_steps):
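+                # p is the term ratio A(k)/B(k) without the z factor; z is folded in
+                # below via log_z so the series term is tracked in log space.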
+                p = (a + k) * (b + k) / ((c + k) * (k + 1))
+                if p == 0:
+                    return res
+                log_t_new += np.log(np.abs(p)) + log_z
+                log_t_new_sign = np.sign(p) * log_t_new_sign
+
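+                # Product rule from the Notes above: d[β_{k+1}] = d[A/B] * β_k + A/B * d[β_k],
+                # with d[A/B] expressed through the reciprocal of the selected parameter.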
+                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
+                if wrt_a:
+                    term += np.reciprocal(a + k)
+                elif wrt_b:
+                    term += np.reciprocal(b + k)
+                else:
+                    term -= np.reciprocal(c + k)
+
+                log_g_old = log_t_new + np.log(np.abs(term))
+                log_g_old_sign = np.sign(term) * log_t_new_sign
+                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
+                res += g_current
+
+                log_t_old = log_t_new
+                log_t_old_sign = log_t_new_sign
+                sign_zk *= sign_z
+
+                if k >= min_steps and np.abs(g_current) <= precision:
+                    return res
+
+            warnings.warn(
+                f"hyp2f1_der did not converge after {k} iterations",
+                RuntimeWarning,
+            )
+            return np.nan
+
+        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
+        if not check_2f1_converges(a, b, c, z):
+            warnings.warn(
+                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
+                RuntimeWarning,
+            )
+            return np.nan
+
+        return compute_grad_2f1(a, b, c, z, wrt=wrt)
+
+    def __call__(self, a, b, c, z, wrt):
+        # This allows wrt to be a keyword argument
+        return super().__call__(a, b, c, z, wrt)
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
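
A quick numeric sanity check for the new Op, as a sketch only: the import path used below is an assumption (adjust it to wherever this module actually lives); the names hyp2f1_der and its impl signature come from the diff itself. It compares the series-based derivative with respect to a against a central finite difference of scipy.special.hyp2f1:

    # Sanity-check sketch: series gradient (wrt=0, i.e. d/da) vs. finite difference.
    import numpy as np
    import scipy.special

    from pytensor.scalar.math import hyp2f1_der  # hypothetical import path

    a, b, c, z = 1.5, 2.0, 3.5, 0.4
    eps = 1e-6

    # Series-based derivative via the Op's pure-Python impl (no graph needed).
    grad_series = hyp2f1_der.impl(a, b, c, z, 0)

    # Central finite difference in `a` as an independent reference.
    grad_fd = (
        scipy.special.hyp2f1(a + eps, b, c, z) - scipy.special.hyp2f1(a - eps, b, c, z)
    ) / (2 * eps)

    print(grad_series, grad_fd)
    assert np.isclose(grad_series, grad_fd, rtol=1e-4)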