
Add Ops for Gaussian Hypergeometric Function, Pochhammer Symbol, and Factorials #90


Merged · 2 commits · Jan 5, 2023
166 changes: 166 additions & 0 deletions pytensor/scalar/math.py
@@ -1481,3 +1481,169 @@ def c_code(self, *args, **kwargs):


betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der")


class Hyp2F1(ScalarOp):
    """
    Gaussian hypergeometric function ``2F1(a, b; c; z)``.

    """

    nin = 4
    nfunc_spec = ("scipy.special.hyp2f1", 4, 1)

    @staticmethod
    def st_impl(a, b, c, z):
        return scipy.special.hyp2f1(a, b, c, z)

    def impl(self, a, b, c, z):
        return Hyp2F1.st_impl(a, b, c, z)

    def grad(self, inputs, grads):
        a, b, c, z = inputs
        (gz,) = grads
        return [
            gz * hyp2f1_der(a, b, c, z, wrt=0),
            gz * hyp2f1_der(a, b, c, z, wrt=1),
            gz * hyp2f1_der(a, b, c, z, wrt=2),
            gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z),
        ]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1")


class Hyp2F1Der(ScalarOp):
    """
    Derivatives of the Gaussian hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs.

    Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp
    """

    nin = 5

    def impl(self, a, b, c, z, wrt):
        def check_2f1_converges(a, b, c, z) -> bool:
            num_terms = 0
            is_polynomial = False

            def is_nonpositive_integer(x):
                return x <= 0 and x.is_integer()

            if is_nonpositive_integer(a) and abs(a) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(a)))
            if is_nonpositive_integer(b) and abs(b) >= num_terms:
                is_polynomial = True
                num_terms = int(np.floor(abs(b)))

            is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms

            return not is_undefined and (
                is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b))
            )
Comment on lines +1529 to +1547

Contributor (Author):
Just idle curiosity, but what is the significance of the num_terms variable?

@ricardoV94 (Member), Jan 3, 2023:
I think, when the series has a negative integer in the denominator it will eventually have a divide by zero and blow up... unless it has a negative integer in the numerator that will arrive at zero first, truncating the series before that happens.

        def compute_grad_2f1(a, b, c, z, wrt):
            """
            Notes
            -----
            The algorithm can be derived by looking at the ratio of two successive terms in the series
            β_{k+1}/β_{k} = A(k)/B(k)
            β_{k+1} = A(k)/B(k) * β_{k}
            d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule

            In 2F1, A(k)/B(k) corresponds to ((a + k) * (b + k)) / ((c + k) * (1 + k)) * z

            The partial derivatives of A(k)/B(k) with respect to the first 3 inputs can be obtained
            from the ratio A(k)/B(k) by dividing out the respective term
            d/da[A(k)/B(k)] = A(k)/B(k) / (a + k)
            d/db[A(k)/B(k)] = A(k)/B(k) / (b + k)
            d/dc[A(k)/B(k)] = -A(k)/B(k) / (c + k)

            The algorithm is implemented in the log scale, which adds the complexity of working with
            absolute terms and tracking their signs.
            """

            wrt_a = wrt_b = False
            if wrt == 0:
                wrt_a = True
            elif wrt == 1:
                wrt_b = True
            elif wrt != 2:
                raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")

            min_steps = 10  # https://github.com/stan-dev/math/issues/2857
            max_steps = int(1e6)
            precision = 1e-14

            res = 0

            if z == 0:
                return res

            log_g_old = -np.inf
            log_t_old = 0.0
            log_t_new = 0.0
            sign_z = np.sign(z)
            log_z = np.log(np.abs(z))

            log_g_old_sign = 1
            log_t_old_sign = 1
            log_t_new_sign = 1
            sign_zk = sign_z

            for k in range(max_steps):
                p = (a + k) * (b + k) / ((c + k) * (k + 1))
                if p == 0:
                    return res
                log_t_new += np.log(np.abs(p)) + log_z
                log_t_new_sign = np.sign(p) * log_t_new_sign

                term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
                if wrt_a:
                    term += np.reciprocal(a + k)
                elif wrt_b:
                    term += np.reciprocal(b + k)
                else:
                    term -= np.reciprocal(c + k)

                log_g_old = log_t_new + np.log(np.abs(term))
                log_g_old_sign = np.sign(term) * log_t_new_sign
                g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk
                res += g_current

                log_t_old = log_t_new
                log_t_old_sign = log_t_new_sign
                sign_zk *= sign_z

                if k >= min_steps and np.abs(g_current) <= precision:
                    return res

            warnings.warn(
                f"hyp2f1_der did not converge after {k} iterations",
                RuntimeWarning,
            )
            return np.nan

        # TODO: We could implement the Euler transform to expand supported domain, as Stan does
        if not check_2f1_converges(a, b, c, z):
            warnings.warn(
                f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}",
                RuntimeWarning,
            )
            return np.nan

        return compute_grad_2f1(a, b, c, z, wrt=wrt)

    def __call__(self, a, b, c, z, wrt):
        # This allows wrt to be a keyword argument
        return super().__call__(a, b, c, z, wrt)

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der")
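For reference, the closed-form d/dz rule used in `Hyp2F1.grad`, namely d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z), can be sanity-checked against a finite difference of `scipy.special.hyp2f1`. A minimal sketch using SciPy only; the parameter values are arbitrary:

```python
# Sanity check of the derivative-in-z identity used by Hyp2F1.grad (SciPy only).
import numpy as np
from scipy.special import hyp2f1

a, b, c, z = 0.5, 1.2, 2.3, 0.4
eps = 1e-6

analytic = (a * b / c) * hyp2f1(a + 1, b + 1, c + 1, z)
numeric = (hyp2f1(a, b, c, z + eps) - hyp2f1(a, b, c, z - eps)) / (2 * eps)
np.testing.assert_allclose(analytic, numeric, rtol=1e-6)  # both approximate d/dz 2F1
```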
5 changes: 5 additions & 0 deletions pytensor/tensor/inplace.py
@@ -392,6 +392,11 @@ def conj_inplace(a):
"""elementwise conjugate (inplace on `a`)"""


@scalar_elemwise
def hyp2f1_inplace(a, b, c, z):
"""gaussian hypergeometric function"""


pprint.assign(add_inplace, printing.OperatorPrinter("+=", -2, "either"))
pprint.assign(mul_inplace, printing.OperatorPrinter("*=", -1, "either"))
pprint.assign(sub_inplace, printing.OperatorPrinter("-=", -2, "left"))
6 changes: 6 additions & 0 deletions pytensor/tensor/math.py
@@ -1384,6 +1384,11 @@ def gammal(k, x):
"""Lower incomplete gamma function."""


@scalar_elemwise
def hyp2f1(a, b, c, z):
"""Gaussian hypergeometric function."""


@scalar_elemwise
def j0(x):
"""Bessel function of the first kind of order 0."""
@@ -3132,4 +3137,5 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
"power",
"logaddexp",
"logsumexp",
"hyp2f1",
]
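A minimal usage sketch of the new tensor-level op, assuming a PyTensor build that includes this PR. It uses the identity 2F1(1, 1; 2; z) = -log(1 - z) / z to check the value, and `pytensor.grad` to exercise the gradient defined above:

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.math import hyp2f1

a, b, c, z = pt.dscalars("a", "b", "c", "z")
out = hyp2f1(a, b, c, z)
dz = pytensor.grad(out, z)  # built from (a*b/c) * hyp2f1(a+1, b+1, c+1, z)

f = pytensor.function([a, b, c, z], [out, dz])
value, grad_z = f(1.0, 1.0, 2.0, 0.5)
print(value)   # ≈ 1.3862944, i.e. -log(1 - 0.5) / 0.5
print(grad_z)  # derivative of 2F1 with respect to z at these arguments
```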
20 changes: 19 additions & 1 deletion pytensor/tensor/special.py
@@ -7,7 +7,7 @@
from pytensor.graph.basic import Apply
from pytensor.link.c.op import COp
from pytensor.tensor.basic import as_tensor_variable
-from pytensor.tensor.math import neg, sum
+from pytensor.tensor.math import gamma, neg, sum


class SoftmaxGrad(COp):
@@ -768,7 +768,25 @@ def log_softmax(c, axis=UNSET_AXIS):
    return LogSoftmax(axis=axis)(c)


def poch(z, m):
    """
    Pochhammer symbol (rising factorial) function.

    """
    return gamma(z + m) / gamma(z)


def factorial(n):
    """
    Factorial function of a scalar or array of numbers.

    """
    return gamma(n + 1)


__all__ = [
    "softmax",
    "log_softmax",
    "poch",
    "factorial",
]
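Both helpers are thin wrappers around the gamma function; their identities can be verified against SciPy directly (a quick check using SciPy only, independent of PyTensor):

```python
import numpy as np
from scipy.special import factorial, gamma, poch

z, m = 3.5, 4.0
np.testing.assert_allclose(gamma(z + m) / gamma(z), poch(z, m))  # rising factorial

n = np.arange(6)
np.testing.assert_allclose(gamma(n + 1), factorial(n))  # n! = gamma(n + 1)
```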