|
28 | 28 | from pytensor.tensor.random.type import RandomType
|
29 | 29 |
|
30 | 30 | from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
|
31 |
| -from pymc.distributions.custom import CustomSymbolicDistRV |
32 | 31 | from pymc.distributions.dist_math import check_parameters
|
33 | 32 | from pymc.distributions.distribution import (
|
34 | 33 | Distribution,
|
@@ -302,17 +301,24 @@ class Truncated(Distribution):
|
302 | 301 | def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
|
303 | 302 | if not (
|
304 | 303 | isinstance(dist, TensorVariable)
|
305 |
| - and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV) |
| 304 | + and dist.owner is not None |
| 305 | + and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable) |
306 | 306 | ):
|
307 |
| - if isinstance(dist.owner.op, SymbolicRandomVariable): |
308 |
| - raise NotImplementedError( |
309 |
| - f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n" |
310 |
| - f"You can try wrapping the distribution inside a CustomDist instead." |
311 |
| - ) |
312 | 307 | raise ValueError(
|
313 | 308 | f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
|
314 | 309 | )
|
315 | 310 |
|
| 311 | + if ( |
| 312 | + isinstance(dist.owner.op, SymbolicRandomVariable) |
| 313 | + and "[size]" not in dist.owner.op.extended_signature |
| 314 | + ): |
| 315 | + # Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole |
| 316 | + # random graph and as such we don't know where the actual inputs begin. This happens mostly for |
| 317 | + # distribution factories like `Censored` and `Mixture` which would have a very complex signature if they |
| 318 | + # encapsulated the random components instead of taking them as inputs like they do now. |
| 319 | + # SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter. |
| 320 | + raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}") |
| 321 | + |
316 | 322 | if dist.owner.op.ndim_supp > 0:
|
317 | 323 | raise NotImplementedError("Truncation not implemented for multivariate distributions")
|
318 | 324 |
|
|
0 commit comments