Skip to content

Commit c92a9a9

Browse files
committed
Allow truncation of self-contained SymbolicRandomVariables
1 parent c8afedb commit c92a9a9

File tree

2 files changed

+32
-9
lines changed

2 files changed

+32
-9
lines changed

pymc/distributions/truncated.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
from pytensor.tensor.random.type import RandomType
2929

3030
from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
31-
from pymc.distributions.custom import CustomSymbolicDistRV
3231
from pymc.distributions.dist_math import check_parameters
3332
from pymc.distributions.distribution import (
3433
Distribution,
@@ -302,17 +301,24 @@ class Truncated(Distribution):
302301
def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
303302
if not (
304303
isinstance(dist, TensorVariable)
305-
and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
304+
and dist.owner is not None
305+
and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable)
306306
):
307-
if isinstance(dist.owner.op, SymbolicRandomVariable):
308-
raise NotImplementedError(
309-
f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
310-
f"You can try wrapping the distribution inside a CustomDist instead."
311-
)
312307
raise ValueError(
313308
f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
314309
)
315310

311+
if (
312+
isinstance(dist.owner.op, SymbolicRandomVariable)
313+
and "[size]" not in dist.owner.op.extended_signature
314+
):
315+
# Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole
316+
# random graph and as such we don't know where the actual inputs begin. This happens mostly for
317+
# distribution factories like `Censored` and `Mixture` which would have a very complex signature if they
318+
# encapsulated the random components instead of taking them as inputs like they do now.
319+
# SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter.
320+
raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}")
321+
316322
if dist.owner.op.ndim_supp > 0:
317323
raise NotImplementedError("Truncation not implemented for multivariate distributions")
318324

tests/distributions/test_truncated.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from pytensor.tensor.random.basic import GeometricRV, NormalRV
2222
from pytensor.tensor.random.type import RandomType
2323

24-
from pymc import Model, draw, find_MAP
24+
from pymc import ExGaussian, Model, Normal, draw, find_MAP
2525
from pymc.distributions import (
2626
Censored,
2727
ChiSquared,
@@ -342,7 +342,7 @@ def test_truncation_exceptions():
342342
# Truncation does not work with SymbolicRV inputs
343343
with pytest.raises(
344344
NotImplementedError,
345-
match="Truncation not implemented for SymbolicRandomVariable CensoredRV",
345+
match="Truncation not implemented for CensoredRV",
346346
):
347347
Truncated.dist(Censored.dist(pt.random.normal(), lower=-1, upper=1), -1, 1)
348348

@@ -599,3 +599,20 @@ def dist(scale, size):
599599
rv_out = Truncated.dist(latent, upper=7)
600600

601601
assert np.ptp(draw(rv_out, draws=100)) < 7
602+
603+
604+
@pytest.mark.parametrize(
605+
"dist_fn",
606+
[
607+
lambda: ExGaussian.dist(nu=3),
608+
pytest.param(
609+
lambda: Censored.dist(Normal.dist(), lower=1),
610+
marks=pytest.mark.xfail(raises=NotImplementedError),
611+
),
612+
],
613+
)
614+
def test_truncated_symbolic_rv(dist_fn):
615+
dist = dist_fn()
616+
trunc_dist = Truncated.dist(dist, lower=1, upper=3)
617+
assert 1 <= draw(trunc_dist) <= 3
618+
assert (logp(trunc_dist, 2.5) > logp(dist, 2.5)).eval()

0 commit comments

Comments
 (0)