Skip to content

Commit 211e0cb

Browse files
ColtAllen authored and ricardoV94 committed
Add factorial and poch helpers
1 parent db18c97 commit 211e0cb

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

pytensor/tensor/special.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from pytensor.graph.basic import Apply
88
from pytensor.link.c.op import COp
99
from pytensor.tensor.basic import as_tensor_variable
10-
from pytensor.tensor.math import neg, sum
10+
from pytensor.tensor.math import gamma, neg, sum
1111

1212

1313
class SoftmaxGrad(COp):
@@ -768,7 +768,25 @@ def log_softmax(c, axis=UNSET_AXIS):
768768
return LogSoftmax(axis=axis)(c)
769769

770770

771+
def poch(z, m):
    """
    Pochhammer symbol (rising factorial) function.

    Evaluated through the gamma function as ``gamma(z + m) / gamma(z)``,
    which agrees with ``scipy.special.poch`` for positive arguments.
    """
    rising = gamma(z + m)
    base = gamma(z)
    return rising / base
777+
778+
779+
def factorial(n):
    """
    Factorial function of a scalar or array of numbers.

    Uses the identity ``n! = gamma(n + 1)``, so it extends elementwise to
    arrays and to non-integer inputs.
    """
    shifted = n + 1
    return gamma(shifted)
785+
786+
771787
__all__ = [
772788
"softmax",
773789
"log_softmax",
790+
"poch",
791+
"factorial",
774792
]

tests/tensor/test_special.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
22
import pytest
3+
from scipy.special import factorial as scipy_factorial
34
from scipy.special import log_softmax as scipy_log_softmax
5+
from scipy.special import poch as scipy_poch
46
from scipy.special import softmax as scipy_softmax
57

68
from pytensor.compile.function import function
@@ -9,11 +11,14 @@
911
LogSoftmax,
1012
Softmax,
1113
SoftmaxGrad,
14+
factorial,
1215
log_softmax,
16+
poch,
1317
softmax,
1418
)
15-
from pytensor.tensor.type import matrix, tensor3, tensor4, vector
19+
from pytensor.tensor.type import matrix, tensor3, tensor4, vector, vectors
1620
from tests import unittest_tools as utt
21+
from tests.tensor.utils import random_ranged
1722

1823

1924
class TestSoftmax(utt.InferShapeTester):
@@ -140,3 +145,29 @@ def test_valid_axis(self):
140145

141146
with pytest.raises(ValueError):
142147
SoftmaxGrad(-4)(*x)
148+
149+
150+
def test_poch():
    """Check that ``poch`` matches ``scipy.special.poch`` on random inputs."""
    z_sym, m_sym = vectors("z", "m")
    poch_fn = function([z_sym, m_sym], poch(z_sym, m_sym))

    z_vals = random_ranged(0, 5, (2,))
    m_vals = random_ranged(0, 5, (2,))
    # Looser tolerance when running in float32.
    tolerance = 1e-7 if config.floatX == "float64" else 1e-5
    np.testing.assert_allclose(
        poch_fn(z_vals, m_vals), scipy_poch(z_vals, m_vals), rtol=tolerance
    )
161+
162+
163+
@pytest.mark.parametrize("n", random_ranged(0, 5, (1,)))
def test_factorial(n):
    """
    Compare ``factorial`` against ``scipy.special.factorial``.

    The parametrized scalar ``n`` is broadcast to a small vector and fed
    through a compiled function.
    """
    _n = vector("n")
    actual_fn = function([_n], factorial(_n))

    # BUG FIX: the original body shadowed the parametrized ``n`` with a fresh
    # random draw, silently discarding the parametrized value. Use it instead,
    # cast to the configured float dtype so the compiled function accepts it.
    n_vec = np.array([n, n], dtype=config.floatX)
    actual = actual_fn(n_vec)
    expected = scipy_factorial(n_vec)
    np.testing.assert_allclose(
        actual, expected, rtol=1e-7 if config.floatX == "float64" else 1e-5
    )

0 commit comments

Comments
 (0)