
Add rewrite for log(gamma) -> gammaln #1181

Open
@ricardoV94

Description

We're missing this simple rewrite:

import pytensor.tensor as pt
from pytensor.graph import rewrite_graph

x = pt.scalar("x")
out = pt.log(pt.gamma(x))
new_out = rewrite_graph(out, include=("canonicalize", "stabilize", "specialize"))
new_out.dprint()
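To see why this rewrite matters, here is a dependency-free sketch. It uses Python's standard library (math.gamma and math.lgamma) as a stand-in for the PyTensor ops, which is an assumption made only for illustration. gamma(x) leaves the float64 range a little past x = 171, while its log stays small, so log(gamma(x)) fails where gammaln(x) does not.

```python
import math

def naive_log_gamma(x: float) -> float:
    # log(gamma(x)): gamma(x) itself exceeds the float64 range for x above ~171.6
    return math.log(math.gamma(x))

def stable_log_gamma(x: float) -> float:
    # gammaln(x): evaluated directly in log space, so no intermediate overflow
    return math.lgamma(x)

for x in (10.0, 200.0):
    try:
        naive = naive_log_gamma(x)
    except OverflowError:
        # math.gamma raises here; NumPy/SciPy would typically return inf instead
        naive = math.inf
    print(f"x={x}: log(gamma(x))={naive}, gammaln(x)={stable_log_gamma(x)}")
```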

This can be done easily with a PatternNodeRewriter, as in the existing polygamma-to-trigamma rewrite:

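from pytensor.graph.rewriting.basic import PatternNodeRewriter
from pytensor.tensor.math import polygamma, tri_gamma

# Rewrites polygamma(1, x) into the dedicated tri_gamma(x) Op
local_polygamma_to_tri_gamma = PatternNodeRewriter(
    (polygamma, 1, "x"),
    (tri_gamma, "x"),
    allow_multiple_clients=True,
    name="local_polygamma_to_tri_gamma",
)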

We could also add rewrites for common combinatorics expressions, such as the log of the beta function:

naive_betaln = pt.log((pt.gamma(x) * pt.gamma(y)) / pt.gamma(x + y))
betaln = pt.gammaln(x) + pt.gammaln(y) - pt.gammaln(x + y)

def betaln(a, b):
    """
    Log beta function.
    """
    return gammaln(a) + gammaln(b) - gammaln(a + b)
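The naive and log-space forms of betaln are equal wherever the naive one can be evaluated, but only the gammaln form survives large arguments. The quick check below uses Python's math.lgamma as a stand-in for gammaln. This substitution is an assumption made only to keep the sketch dependency-free.

```python
import math

def naive_betaln(a: float, b: float) -> float:
    # log((gamma(a) * gamma(b)) / gamma(a + b)): each gamma can overflow on its own
    return math.log((math.gamma(a) * math.gamma(b)) / math.gamma(a + b))

def betaln(a: float, b: float) -> float:
    # gammaln(a) + gammaln(b) - gammaln(a + b), with math.lgamma as gammaln
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)

# Small arguments: both forms agree (B(2, 3) = 1/12)
print(naive_betaln(2.0, 3.0), betaln(2.0, 3.0))

# Large arguments: gamma(300) overflows, the log-space form does not
try:
    print(naive_betaln(150.0, 150.0))
except OverflowError:
    print("naive form overflowed")
print(betaln(150.0, 150.0))
```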

Or for log(poch):

def poch(z, m):
    """
    Pochhammer symbol (rising factorial) function.
    """
    return gamma(z + m) / gamma(z)
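Taking the log of poch gives gammaln(z + m) - gammaln(z). For integer m, this can be cross-checked against the product definition of the rising factorial. The helper names below are illustrative only, and math.lgamma stands in for gammaln.

```python
import math

def log_poch(z: float, m: float) -> float:
    # log(gamma(z + m) / gamma(z)) written as gammaln(z + m) - gammaln(z)
    return math.lgamma(z + m) - math.lgamma(z)

def log_poch_by_product(z: float, m: int) -> float:
    # For integer m >= 0 and z > 0, poch(z, m) = z * (z + 1) * ... * (z + m - 1),
    # so its log is a sum of m logs
    return sum(math.log(z + k) for k in range(m))

print(log_poch(3.5, 4), log_poch_by_product(3.5, 4))
print(log_poch(500.0, 300.0))  # finite, although gamma(800.0) would overflow
```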

For these more general cases, we can probably use something more flexible than PatternNodeRewriter. We want the rewrite to apply whenever we know that all the terms inside the log are factorials, gammas or exps (positive quantities that easily overflow). This is a narrower, easier subset of #177.
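One reading of "more flexible than PatternNodeRewriter" is the following. Instead of matching one fixed shape, the rewriter walks the argument of a log. If that argument is a product or quotient built only from gammas, exps and factorials, it pushes the log through term by term. The sketch below shows this idea on a hypothetical toy expression tree made of Python tuples. It is not PyTensor's graph API, and the helper names (is_positive_product, log_of, expand_log) are made up for illustration.

```python
# Hypothetical toy IR: a symbol is a str, an operation is a tuple (op, *args).

# Heads whose outputs are treated as positive, following the issue text.
# Note: gamma is only truly positive for positive arguments; this toy does not check that.
POSITIVE_HEADS = {"gamma", "exp", "factorial"}

def is_positive_product(e) -> bool:
    """True if e is a product/quotient built only from 'positive' terms."""
    if not isinstance(e, tuple):
        return False
    head, *args = e
    if head in POSITIVE_HEADS:
        return True
    if head in ("mul", "div"):
        return all(is_positive_product(a) for a in args)
    return False

def log_of(e):
    """Push a log through a positive product/quotient, term by term."""
    head, *args = e
    if head == "mul":
        return ("add", *(log_of(a) for a in args))
    if head == "div":
        num, den = args
        return ("sub", log_of(num), log_of(den))
    if head == "gamma":
        return ("gammaln", args[0])
    if head == "exp":
        return args[0]
    if head == "factorial":
        # log(n!) = gammaln(n + 1)
        return ("gammaln", ("add", args[0], 1))
    raise ValueError(f"unexpected head {head!r}")

def expand_log(e):
    """Recursively rewrite log(positive product) into a sum/difference of logs."""
    if not isinstance(e, tuple):
        return e
    head, *args = e
    args = [expand_log(a) for a in args]
    if head == "log" and is_positive_product(args[0]):
        return log_of(args[0])
    return (head, *args)

naive_betaln = ("log", ("div", ("mul", ("gamma", "x"), ("gamma", "y")), ("gamma", ("add", "x", "y"))))
print(expand_log(naive_betaln))
```

A real implementation would walk PyTensor Apply nodes rather than tuples. However, the positivity check and the term-by-term expansion would follow the same shape.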
