Skip to content

Commit aa679f3

Browse files
committed
Add print_name to Truncated and CustomDists
1 parent 30d00fe commit aa679f3

File tree

3 files changed

+31
-1
lines changed

3 files changed

+31
-1
lines changed

pymc/distributions/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,7 @@ def rv_op(
663663
ndim_supp=ndim_supp,
664664
ndims_params=ndims_params,
665665
dtype=dtype,
666+
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
666667
# Specific to CustomDist
667668
_random_fn=random,
668669
),
@@ -802,6 +803,7 @@ def rv_op(
802803
# If logp is not provided, we try to infer it from the dist graph
803804
dict(
804805
inline_logprob=logp is None,
806+
_print_name=(class_name, f"\\operatorname{{{class_name}}}"),
805807
),
806808
)
807809

pymc/distributions/truncated.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ class TruncatedRV(SymbolicRandomVariable):
5656
def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
5757
self.base_rv_op = base_rv_op
5858
self.max_n_steps = max_n_steps
59+
self._print_name = (
60+
f"Truncated{self.base_rv_op._print_name[0]}",
61+
f"\\operatorname{{{self.base_rv_op._print_name[1]}}}",
62+
)
5963
super().__init__(*args, **kwargs)
6064

6165
def update(self, node: Node):

tests/test_printing.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
from pytensor.tensor.random import normal
1717

18-
from pymc import Bernoulli, Censored, HalfCauchy, Mixture, StudentT
18+
from pymc import Bernoulli, Censored, CustomDist, Gamma, HalfCauchy, Mixture, StudentT, Truncated
1919
from pymc.distributions import (
2020
Dirichlet,
2121
DirichletMultinomial,
@@ -285,3 +285,27 @@ def test_model_repr_variables_without_monkey_patched_repr():
285285

286286
str_repr = model.str_repr()
287287
assert str_repr == "x ~ Normal(0, 1)"
288+
289+
290+
def test_truncated_repr():
291+
with Model() as model:
292+
x = Truncated("x", Gamma.dist(1, 1), lower=0, upper=20)
293+
294+
str_repr = model.str_repr(include_params=False)
295+
assert str_repr == "x ~ TruncatedGamma"
296+
297+
298+
def test_custom_dist_repr():
299+
with Model() as model:
300+
301+
def dist(mu, size):
302+
return Normal.dist(mu, 1, size=size)
303+
304+
def random(rng, mu, size):
305+
return rng.normal(mu, size=size)
306+
307+
x = CustomDist("x", 0, dist=dist, class_name="CustomDistNormal")
308+
x = CustomDist("y", 0, random=random, class_name="CustomRandomNormal")
309+
310+
str_repr = model.str_repr(include_params=False)
311+
assert str_repr == "\n".join(["x ~ CustomDistNormal", "y ~ CustomRandomNormal"])

0 commit comments

Comments
 (0)