
Commit 823a702

Initial commit for hyp2f1 op. 4 of 10 tests failing
Revised to ScalarOp n=4 inputs
Added Hyp2F1 derivatives
Added mpmath functions to derivatives and black formatting
Changed grad_not_implemented to NotImplementedError
Added Pochhammer Symbol Op
Added Factorial Op
Replaced mpmath functions with scipy
Changed test inputs from integers to normal for _good_broadcast_unary_factorial
Updated test values in _good_broadcast_quaternary_hyp2f1
Set makeBroadcastTester(grad=None) for all tests
Rewrote st_impl for Poch and Factorial in terms of gamma Op
Refactored Poch and Factorial into helper functions for at.Gamma
Rewrote tests
Refactored hyp2f1_der in terms of gamma and ran black formatting
Moved factorial and poch into tensor.special
Add Hyp2F1, poch, and factorial
Staging commit for hyp2f1 gradient tests
Continued downstream_1288 in new fork
Added Black formatting
Removed duplicate tests
Removed elemwise hyp2f1_der and respective tests
Added hyp2f1 and derivatives, factorial and poch helpers
1 parent a2221ef commit 823a702
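
At a glance, the commit adds one new elemwise op (hyp2f1) plus two gamma-based helpers (poch, factorial). A minimal sketch of the resulting surface, assuming this branch is installed (input values are illustrative, not from the diff):

import pytensor
import pytensor.tensor as at
from pytensor.tensor.special import factorial, poch

a, b, c, z = at.dscalars("a", "b", "c", "z")
fn = pytensor.function(
    [a, b, c, z],
    [
        at.hyp2f1(a, b, c, z),  # Gaussian hypergeometric function, needs |z| < 1
        poch(a, b),             # rising factorial: gamma(a + b) / gamma(a)
        factorial(a),           # gamma(a + 1)
    ],
)
print(fn(1.5, 2.5, 3.5, 0.25))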

File tree

6 files changed: +253 −2 lines changed

pytensor/scalar/math.py

Lines changed: 152 additions & 0 deletions
@@ -1481,3 +1481,155 @@ def c_code(self, *args, **kwargs):
 betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")
+
+
+class Hyp2F1(ScalarOp):
+    """
+    Gaussian hypergeometric function ``2F1(a, b; c; z)``.
+
+    """
+
+    nin = 4
+    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)
+
+    @staticmethod
+    def st_impl(a, b, c, z):
+        if abs(z) >= 1:
+            raise NotImplementedError("hyp2f1 only supported for |z| < 1.")
+        else:
+            return scipy.special.hyp2f1(a, b, c, z)
+
+    def impl(self, a, b, c, z):
+        return Hyp2F1.st_impl(a, b, c, z)
+
+    def grad(self, inputs, grads):
+        a, b, c, z = inputs
+        (gz,) = grads
+        return [
+            gz * hyp2f1_der(a, b, c, z, 0),
+            gz * hyp2f1_der(a, b, c, z, 1),
+            gz * hyp2f1_der(a, b, c, z, 2),
+            gz * hyp2f1_der(a, b, c, z, 3),
+        ]
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")
+
+
+class Hyp2F1Der(ScalarOp):
+    """
+    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)``.
+
+    """
+
+    nin = 5
+
+    def impl(self, a, b, c, z, wrt):
+        def _infinisum(f):
+            """Utility function for infinite summations."""
+            n, res = 0, f(0)
+            while True:
+                term = f(n + 1)
+                # Stop once adding the next term no longer changes the running sum.
+                if (res + term) - res == 0:
+                    break
+                n, res = n + 1, res + term
+            return res
+
+        def _hyp2f1_da(a, b, c, z):
+            """Derivative of hyp2f1 wrt a."""
+            if abs(z) >= 1:
+                raise NotImplementedError("Gradient not supported for |z| >= 1")
+            else:
+                # Sum of (a)_k (b)_k / (c)_k * psi(a + k) * z^k / k!, with the
+                # Pochhammer symbols written as gamma ratios.
+                term1 = _infinisum(
+                    lambda k: (
+                        (scipy.special.gamma(a + k) / scipy.special.gamma(a))
+                        * (scipy.special.gamma(b + k) / scipy.special.gamma(b))
+                        * scipy.special.psi(a + k)
+                        * (z**k)
+                    )
+                    / (scipy.special.gamma(c + k) / scipy.special.gamma(c))
+                    / scipy.special.gamma(k + 1)
+                )
+                term2 = scipy.special.psi(a) * scipy.special.hyp2f1(a, b, c, z)
+
+                return term1 - term2
+
+        def _hyp2f1_db(a, b, c, z):
+            """Derivative of hyp2f1 wrt b."""
+            if abs(z) >= 1:
+                raise NotImplementedError("Gradient not supported for |z| >= 1")
+            else:
+                term1 = _infinisum(
+                    lambda k: (
+                        (scipy.special.gamma(a + k) / scipy.special.gamma(a))
+                        * (scipy.special.gamma(b + k) / scipy.special.gamma(b))
+                        * scipy.special.psi(b + k)
+                        * (z**k)
+                    )
+                    / (scipy.special.gamma(c + k) / scipy.special.gamma(c))
+                    / scipy.special.gamma(k + 1)
+                )
+                term2 = scipy.special.psi(b) * scipy.special.hyp2f1(a, b, c, z)
+
+                return term1 - term2
+
+        def _hyp2f1_dc(a, b, c, z):
+            """Derivative of hyp2f1 wrt c."""
+            if abs(z) >= 1:
+                raise NotImplementedError("Gradient not supported for |z| >= 1")
+            else:
+                term1 = scipy.special.psi(c) * scipy.special.hyp2f1(a, b, c, z)
+                term2 = _infinisum(
+                    lambda k: (
+                        (scipy.special.gamma(a + k) / scipy.special.gamma(a))
+                        * (scipy.special.gamma(b + k) / scipy.special.gamma(b))
+                        * scipy.special.psi(c + k)
+                        * (z**k)
+                    )
+                    / (scipy.special.gamma(c + k) / scipy.special.gamma(c))
+                    / scipy.special.gamma(k + 1)
+                )
+                return term1 - term2
+
+        def _hyp2f1_dz(a, b, c, z):
+            """Derivative of hyp2f1 wrt z."""
+            return ((a * b) / c) * scipy.special.hyp2f1(a + 1, b + 1, c + 1, z)
+
+        if wrt == 0:
+            return _hyp2f1_da(a, b, c, z)
+        elif wrt == 1:
+            return _hyp2f1_db(a, b, c, z)
+        elif wrt == 2:
+            return _hyp2f1_dc(a, b, c, z)
+        elif wrt == 3:
+            return _hyp2f1_dz(a, b, c, z)
+
+    def c_code(self, *args, **kwargs):
+        raise NotImplementedError()
+
+
+hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
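
For reference, term1 - term2 in _hyp2f1_da/_hyp2f1_db and the closed form in _hyp2f1_dz implement the standard identities below, valid inside the |z| < 1 convergence region of the defining series, with (q)_k = \Gamma(q + k)/\Gamma(q) the Pochhammer symbol and \psi the digamma function:

\frac{\partial}{\partial a}\, {}_2F_1(a, b; c; z)
    = \sum_{k=0}^{\infty} \frac{(a)_k (b)_k}{(c)_k}\, \psi(a + k)\, \frac{z^k}{k!}
      \;-\; \psi(a)\, {}_2F_1(a, b; c; z)

\frac{\partial}{\partial z}\, {}_2F_1(a, b; c; z)
    = \frac{a b}{c}\, {}_2F_1(a + 1, b + 1; c + 1; z)

The b case is analogous with \psi(b + k) and \psi(b); for c the two terms swap sign, which is why _hyp2f1_dc puts the \psi(c) term first.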

pytensor/tensor/inplace.py

Lines changed: 5 additions & 0 deletions
@@ -392,6 +392,11 @@ def conj_inplace(a):
 """elementwise conjugate (inplace on `a`)"""


+@scalar_elemwise
+def hyp2f1_inplace(a, b, c, z):
+    """gaussian hypergeometric function"""
+
+
 pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
 pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
 pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))

pytensor/tensor/math.py

Lines changed: 6 additions & 0 deletions
@@ -1384,6 +1384,11 @@ def gammal(k, x):
 """Lower incomplete gamma function."""


+@scalar_elemwise
+def hyp2f1(a, b, c, z):
+    """Gaussian hypergeometric function."""
+
+
 @scalar_elemwise
 def j0(x):
     """Bessel function of the first kind of order 0."""
@@ -3132,6 +3137,7 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
     "power",
     "logaddexp",
     "logsumexp",
+    "hyp2f1",
 ]

 DEPRECATED_NAMES = [
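
Because hyp2f1 is registered via scalar_elemwise with an nfunc_spec pointing at scipy.special.hyp2f1, it broadcasts over array inputs like the other math ops here and dispatches to scipy in python mode. A small sketch (shapes and values are illustrative):

import numpy as np
import scipy.special
import pytensor
import pytensor.tensor as at

a = at.dmatrix("a")
z = at.dscalar("z")  # scalar z broadcasts against the matrix of a values
fn = pytensor.function([a, z], at.hyp2f1(a, 2.5, 3.5, z))

a_val = np.array([[1.0, 1.5], [2.0, 2.5]])
assert np.allclose(fn(a_val, 0.25), scipy.special.hyp2f1(a_val, 2.5, 3.5, 0.25))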

pytensor/tensor/special.py

Lines changed: 24 additions & 1 deletion
@@ -1,13 +1,18 @@
 import warnings
 from textwrap import dedent
+from typing import TYPE_CHECKING

 import numpy as np
 import scipy

 from pytensor.graph.basic import Apply
 from pytensor.link.c.op import COp
 from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.math import neg, sum
+from pytensor.tensor.math import neg, sum, gamma
+
+
+if TYPE_CHECKING:
+    from pytensor.tensor import TensorLike, TensorVariable


 class SoftmaxGrad(COp):
@@ -768,7 +773,25 @@ def log_softmax(c, axis=UNSET_AXIS):
     return LogSoftmax(axis=axis)(c)


+def poch(z: "TensorLike", m: "TensorLike") -> "TensorVariable":
+    """
+    Pochhammer symbol (rising factorial) function.
+
+    """
+    return gamma(z + m) / gamma(z)
+
+
+def factorial(n: "TensorLike") -> "TensorVariable":
+    """
+    Factorial function of a scalar or array of numbers.
+
+    """
+    return gamma(n + 1)
+
+
 __all__ = [
     "softmax",
     "log_softmax",
+    "poch",
+    "factorial",
 ]
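
Both helpers lean on the identities poch(z, m) = Γ(z + m)/Γ(z) and n! = Γ(n + 1), so gradients come for free through the existing gamma op. A quick numpy/scipy check of the identities themselves, independent of pytensor:

import numpy as np
from scipy.special import factorial, gamma, poch

z, m = 2.5, 3.0
assert np.isclose(gamma(z + m) / gamma(z), poch(z, m))

n = 4.0
assert np.isclose(gamma(n + 1), factorial(n))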

tests/tensor/test_math_scipy.py

Lines changed: 27 additions & 0 deletions
@@ -71,6 +71,7 @@ def scipy_special_gammal(k, x):
 expected_iv = scipy.special.iv
 expected_erfcx = scipy.special.erfcx
 expected_sigmoid = scipy.special.expit
+expected_hyp2f1 = scipy.special.hyp2f1

 TestErfBroadcast = makeBroadcastTester(
     op=at.erf,
@@ -753,6 +754,32 @@ def test_deprecated_module():
     inplace=True,
 )

+_good_broadcast_quaternary_hyp2f1 = dict(
+    normal=(
+        random_ranged(0, 1000, (2, 3)),
+        random_ranged(0, 1000, (2, 3)),
+        random_ranged(0, 1000, (2, 3)),
+        random_ranged(0, 0.5, (2, 3)),
+    ),
+)
+
+TestHyp2F1Broadcast = makeBroadcastTester(
+    op=at.hyp2f1,
+    expected=expected_hyp2f1,
+    good=_good_broadcast_quaternary_hyp2f1,
+    grad=_good_broadcast_quaternary_hyp2f1,
+    eps=2e-10,
+)
+
+TestHyp2F1InplaceBroadcast = makeBroadcastTester(
+    op=inplace.hyp2f1_inplace,
+    expected=expected_hyp2f1,
+    good=_good_broadcast_quaternary_hyp2f1,
+    eps=2e-10,
+    mode=mode_no_scipy,
+    inplace=True,
+)
+

 class TestBetaIncGrad:
     def test_stan_grad_partial(self):
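
The broadcast testers above exercise the forward pass against scipy; the gradient path can also be spot-checked numerically. A minimal sketch using pytensor's verify_grad, holding a, b, c fixed so only the closed-form z derivative is exercised (the point 0.25 keeps z well inside |z| < 1; values are illustrative):

import numpy as np
import pytensor.tensor as at
from pytensor.gradient import verify_grad

rng = np.random.RandomState(413)
verify_grad(lambda z: at.hyp2f1(1.5, 2.5, 3.5, z), [np.full((2, 3), 0.25)], rng=rng)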

tests/tensor/test_special.py

Lines changed: 39 additions & 1 deletion
@@ -1,12 +1,24 @@
 import numpy as np
 import pytest
+from scipy.special import factorial as scipy_factorial
 from scipy.special import log_softmax as scipy_log_softmax
+from scipy.special import poch as scipy_poch
 from scipy.special import softmax as scipy_softmax

 from pytensor.compile.function import function
 from pytensor.configdefaults import config
-from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad, log_softmax, softmax
+from pytensor.tensor import scalar, scalars
+from pytensor.tensor.special import (
+    LogSoftmax,
+    Softmax,
+    SoftmaxGrad,
+    log_softmax,
+    softmax,
+    poch,
+    factorial,
+)
 from pytensor.tensor.type import matrix, tensor3, tensor4, vector
+from tests.tensor.utils import random_ranged
 from tests import unittest_tools as utt
@@ -134,3 +146,29 @@ def test_valid_axis(self):

     with pytest.raises(ValueError):
         SoftmaxGrad(-4)(*x)
+
+
+@pytest.mark.parametrize("z, m", [random_ranged(0, 5, (2,)), random_ranged(0, 5, (2,))])
+def test_poch(z, m):
+    _z, _m = scalars("z", "m")
+
+    actual_fn = function([_z, _m], poch(_z, _m))
+    actual = actual_fn(z, m)
+
+    expected = scipy_poch(z, m)
+
+    assert np.allclose(actual, expected)
+
+
+@pytest.mark.parametrize("n", random_ranged(0, 5, (1,)))
+def test_factorial(n):
+    _n = scalar("n")
+
+    actual_fn = function([_n], factorial(_n))
+    actual = actual_fn(n)
+
+    expected = scipy_factorial(n)
+
+    assert np.allclose(actual, expected)
