
Commit 5f95fc2

Allow Truncation of CustomDists
1 parent 1195261 commit 5f95fc2

File tree

3 files changed: +228 −97 lines

pymc/distributions/truncated.py

+120-81
@@ -17,18 +17,20 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor import scan
+from pytensor import config, graph_replace, scan
 from pytensor.graph import Op
 from pytensor.graph.basic import Node
 from pytensor.raise_op import CheckAndRaise
 from pytensor.scan import until
 from pytensor.tensor import TensorConstant, TensorVariable
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
+    CustomSymbolicDistRV,
     Distribution,
     SymbolicRandomVariable,
     _support_point,
@@ -38,8 +40,9 @@
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import TruncationError
 from pymc.logprob.abstract import _logcdf, _logprob
-from pymc.logprob.basic import icdf, logcdf
+from pymc.logprob.basic import icdf, logcdf, logp
 from pymc.math import logdiffexp
+from pymc.pytensorf import collect_default_updates
 from pymc.util import check_dist_not_registered
 
 
@@ -49,11 +52,17 @@ class TruncatedRV(SymbolicRandomVariable):
     that represents a truncated univariate random variable.
     """
 
-    default_output = 1
-    base_rv_op = None
-    max_n_steps = None
-
-    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+    default_output: int = 0
+    base_rv_op: Op
+    max_n_steps: int
+
+    def __init__(
+        self,
+        *args,
+        base_rv_op: Op,
+        max_n_steps: int,
+        **kwargs,
+    ):
         self.base_rv_op = base_rv_op
         self.max_n_steps = max_n_steps
         self._print_name = (
@@ -63,8 +72,13 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
         super().__init__(*args, **kwargs)
 
     def update(self, node: Node):
-        """Return the update mapping for the internal RNG."""
-        return {node.inputs[-1]: node.outputs[0]}
+        """Return the update mapping for the internal RNGs.
+
+        TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
+        """
+        rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
+        next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
+        return dict(zip(rngs, next_rngs))
 
 
 @singledispatch
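Worth noting: the reworked `update` method no longer assumes a single RNG appended as the last input. It now pairs every `RandomType` input with the matching `RandomType` output, which is what lets a `TruncatedRV` carry the multiple RNGs a CustomDist graph may contain.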
@@ -141,10 +155,14 @@ class Truncated(Distribution):
 
     @classmethod
     def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
-        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+        if not (
+            isinstance(dist, TensorVariable)
+            and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
+        ):
             if isinstance(dist.owner.op, SymbolicRandomVariable):
                 raise NotImplementedError(
-                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
+                    f"You can try wrapping the distribution inside a CustomDist instead."
                 )
             raise ValueError(
                 f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
@@ -174,46 +192,54 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
         if size is None:
             size = pt.broadcast_shape(dist, lower, upper)
         dist = change_dist_size(dist, new_size=size)
+        rv_inputs = [
+            inp
+            if not isinstance(inp.type, RandomType)
+            else pytensor.shared(np.random.default_rng())
+            for inp in dist.owner.inputs
+        ]
+        graph_inputs = [*rv_inputs, lower, upper]
 
         # Variables with `_` suffix identify dummy inputs for the OpFromGraph
-        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
-        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        graph_inputs_ = [
+            inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
+        ]
         *rv_inputs_, lower_, upper_ = graph_inputs_
 
-        # We will use a Shared RNG variable because Scan demands it, even though it
-        # would not be necessary for the OpFromGraph inverse cdf.
-        rng = pytensor.shared(np.random.default_rng())
-        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+        rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()
 
         # Try to use inverted cdf sampling
+        # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
         try:
-            # For left truncated discrete RVs, we need to include the whole lower bound.
-            # This may result in draws below the truncation range, if any uniform == 0
-            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
-            cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
-            cdf_upper_ = pt.exp(logcdf(rv_, upper_))
-            # It's okay to reuse the same rng here, because the rng in rv_ will not be
-            # used by either the logcdf of icdf functions
+            logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
+            # We use the first RNG from the base RV, so we don't have to introduce a new one
+            # This is not problematic because the RNG won't be used in the RV logcdf graph
+            uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
             uniform_next_rng_, uniform_ = pt.random.uniform(
-                cdf_lower_,
-                cdf_upper_,
-                rng=rng,
-                size=rv_inputs_[0],
+                pt.exp(logcdf_lower_),
+                pt.exp(logcdf_upper_),
+                rng=uniform_rng_,
+                size=rv_.shape,
             ).owner.outputs
-            truncated_rv_ = icdf(rv_, uniform_)
+            truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
             return TruncatedRV(
                 base_rv_op=dist.owner.op,
-                inputs=[*graph_inputs_, rng],
-                outputs=[uniform_next_rng_, truncated_rv_],
+                inputs=graph_inputs_,
+                outputs=[truncated_rv_, uniform_next_rng_],
                 ndim_supp=0,
                 max_n_steps=max_n_steps,
-            )(*graph_inputs, rng)
+            )(*graph_inputs)
         except NotImplementedError:
             pass
 
         # Fallback to rejection sampling
-        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
-            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+        # truncated_rv = zeros(rv.shape)
+        # reject_draws = ones(rv.shape, dtype=bool)
+        # while any(reject_draws):
+        #    truncated_rv[reject_draws] = draw(rv)[reject_draws]
+        #    reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
+        def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
+            new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
             # Avoid scalar boolean indexing
             if truncated_rv.type.ndim == 0:
                 truncated_rv = new_truncated_rv
@@ -226,7 +252,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
 
             return (
                 (truncated_rv, reject_draws),
-                [(rng, next_rng)],
+                collect_default_updates(new_truncated_rv),
                 until(~pt.any(reject_draws)),
             )
 
@@ -236,7 +262,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
                 pt.zeros_like(rv_),
                 pt.ones_like(rv_, dtype=bool),
             ],
-            non_sequences=[lower_, upper_, rng, *rv_inputs_],
+            non_sequences=[lower_, upper_, *rv_inputs_],
             n_steps=max_n_steps,
             strict=True,
         )
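The comment block added above spells out the rejection strategy that the Scan implements. For intuition, a standalone NumPy equivalent of that pseudocode (a hypothetical helper, not code from this commit):

import numpy as np

def rejection_truncate(draw_fn, lower, upper, shape, max_n_steps=10_000):
    # Redraw rejected entries until every draw falls inside [lower, upper]
    truncated = np.zeros(shape)
    reject = np.ones(shape, dtype=bool)
    for _ in range(max_n_steps):
        new_draws = draw_fn(shape)
        truncated[reject] = new_draws[reject]
        reject = (truncated < lower) | (truncated > upper)
        if not reject.any():
            return truncated
    raise RuntimeError(f"Truncation did not converge in {max_n_steps} steps")

rng = np.random.default_rng(42)
samples = rejection_truncate(lambda s: rng.normal(size=s), -1.0, 1.0, (1000,))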
@@ -246,24 +272,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
         truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
             truncated_rv_, convergence_
         )
+        # Sort updates of each RNG so that they show in the same order as the input RNGs
+
+        def sort_updates(update):
+            rng, next_rng = update
+            return graph_inputs.index(rng)
+
+        next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]
 
-        [next_rng] = updates.values()
         return TruncatedRV(
             base_rv_op=dist.owner.op,
-            inputs=[*graph_inputs_, rng],
-            outputs=[next_rng, truncated_rv_],
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, *next_rngs],
             ndim_supp=0,
             max_n_steps=max_n_steps,
-        )(*graph_inputs, rng)
+        )(*graph_inputs)
+
+    @staticmethod
+    def _create_logcdf_exprs(
+        base_rv: TensorVariable,
+        value: TensorVariable,
+        lower: TensorVariable,
+        upper: TensorVariable,
+    ) -> tuple[TensorVariable, TensorVariable]:
+        """Create lower and upper logcdf expressions for base_rv.
+
+        Uses `value` as a template for broadcasting.
+        """
+        # For left truncated discrete RVs, we need to include the whole lower bound.
+        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
+        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
+        upper_value = pt.full_like(value, upper, dtype=config.floatX)
+        lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
+        upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
+        return lower_logcdf, upper_logcdf
 
 
 @_change_dist_size.register(TruncatedRV)
-def change_truncated_size(op, dist, new_size, expand):
-    *rv_inputs, lower, upper, rng = dist.owner.inputs
-    # Recreate the original untruncated RV
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
+    *rv_inputs, lower, upper = truncated_rv.owner.inputs
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+
     if expand:
-        new_size = to_tuple(new_size) + tuple(dist.shape)
+        new_size = to_tuple(new_size) + tuple(truncated_rv.shape)
 
     return Truncated.rv_op(
         untruncated_rv,
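The inverse-CDF branch earlier in `rv_op` relies on the standard identity that if u ~ Uniform(F(lower), F(upper)), then F⁻¹(u) is a draw from the truncated distribution; the new `_create_logcdf_exprs` helper builds both CDF bounds from a single logcdf graph via `graph_replace`. A SciPy sketch of the sampling identity (illustrative only):

import numpy as np
from scipy import stats

rng = np.random.default_rng(123)
dist, lower, upper = stats.norm(), -1.0, 2.0

# u ~ Uniform(F(lower), F(upper)); the inverse CDF maps it into [lower, upper]
u = rng.uniform(dist.cdf(lower), dist.cdf(upper), size=1000)
draws = dist.ppf(u)
assert draws.min() >= lower and draws.max() <= upper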
@@ -275,11 +326,11 @@ def change_truncated_size(op, dist, new_size, expand):
 
 
 @_support_point.register(TruncatedRV)
-def truncated_support_point(op, rv, *inputs):
-    *rv_inputs, lower, upper, rng = inputs
+def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
+    *rv_inputs, lower, upper = inputs
 
     # recreate untruncated rv and respective support_point
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
     untruncated_support_point = support_point(untruncated_rv)
 
     fallback_support_point = pt.switch(
@@ -300,31 +351,25 @@ def truncated_support_point(op, rv, *inputs):
 
 
 @_default_transform.register(TruncatedRV)
-def truncated_default_transform(op, rv):
+def truncated_default_transform(op, truncated_rv):
     # Don't transform discrete truncated distributions
-    if op.base_rv_op.dtype.startswith("int"):
+    if truncated_rv.type.dtype.startswith("int"):
         return None
-    # Lower and Upper are the arguments -3 and -2
-    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+    # Lower and Upper are the arguments -2 and -1
+    return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
 
 
 @_logprob.register(TruncatedRV)
 def truncated_logprob(op, values, *inputs, **kwargs):
     (value,) = values
-
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
+    *rv_inputs, lower, upper = inputs
 
     base_rv_op = op.base_rv_op
-    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
-    # For left truncated RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
-
+    base_rv = base_rv_op.make_node(*rv_inputs).default_output()
+    base_logp = logp(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
     if base_rv_op.name:
-        logp.name = f"{base_rv_op}_logprob"
+        base_logp.name = f"{base_rv_op}_logprob"
         lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
         upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
 
@@ -339,37 +384,31 @@ def truncated_logprob(op, values, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf
 
-    logp = logp - lognorm
+    truncated_logp = base_logp - lognorm
 
     if is_lower_bounded:
-        logp = pt.switch(value < lower, -np.inf, logp)
+        truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)
 
     if is_upper_bounded:
-        logp = pt.switch(value <= upper, logp, -np.inf)
+        truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)
 
     if is_lower_bounded and is_upper_bounded:
-        logp = check_parameters(
-            logp,
+        truncated_logp = check_parameters(
+            truncated_logp,
             pt.le(lower, upper),
             msg="lower_bound <= upper_bound",
         )
 
-    return logp
+    return truncated_logp
 
 
 @_logcdf.register(TruncatedRV)
-def truncated_logcdf(op, value, *inputs, **kwargs):
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
-
-    base_rv_op = op.base_rv_op
-    logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
+def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
+    *rv_inputs, lower, upper = inputs
 
-    # For left truncated discrete RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+    base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+    base_logcdf = logcdf(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
 
     is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
     is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
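The rewritten logprob keeps the textbook truncated density: subtract the log normalizer log(F(upper) − F(lower)) from the base log density and mask values outside the bounds with −inf. A quick SciPy cross-check of that formula (illustrative):

import numpy as np
from scipy import stats

dist, lower, upper, x = stats.norm(), -1.0, 2.0, 0.5
lognorm = np.log(dist.cdf(upper) - dist.cdf(lower))
truncated_logp = dist.logpdf(x) - lognorm
# Agrees with scipy's built-in truncated normal
assert np.isclose(truncated_logp, stats.truncnorm(lower, upper).logpdf(x))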
@@ -382,7 +421,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf
 
-    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
+    logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
     logcdf_trunc = logcdf_numerator - lognorm
 
     if is_lower_bounded:
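For reference, the truncated logcdf assembled here is log[(F(x) − F(lower)) / (F(upper) − F(lower))], computed stably in log space. A NumPy verification of the identity, where `logdiffexp` is a local stand-in for `pymc.math.logdiffexp`:

import numpy as np
from scipy import stats

def logdiffexp(a, b):
    # log(exp(a) - exp(b)), assuming a >= b elementwise
    return a + np.log1p(-np.exp(b - a))

dist, lower, upper, x = stats.norm(), -1.0, 2.0, 0.5
logcdf_numerator = logdiffexp(dist.logcdf(x), dist.logcdf(lower))
lognorm = logdiffexp(dist.logcdf(upper), dist.logcdf(lower))
assert np.isclose(logcdf_numerator - lognorm, stats.truncnorm(lower, upper).logcdf(x))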

tests/distributions/test_mixture.py

+6-6
@@ -1588,8 +1588,8 @@ def test_hurdle_negativebinomial_graph(self):
         _, nonzero_dist = self.check_hurdle_mixture_graph(dist)
 
         assert isinstance(nonzero_dist.owner.op.base_rv_op, NegativeBinomial)
-        assert nonzero_dist.owner.inputs[2].data == n
-        assert nonzero_dist.owner.inputs[3].data == p
+        assert nonzero_dist.owner.inputs[-4].data == n
+        assert nonzero_dist.owner.inputs[-3].data == p
 
     def test_hurdle_gamma_graph(self):
         psi, alpha, beta = 0.25, 3, 4
@@ -1599,17 +1599,17 @@
         # Under the hood it uses the shape-scale parametrization of the Gamma distribution.
         # So the second value is the reciprocal of the rate (i.e. 1 / beta)
         assert isinstance(nonzero_dist.owner.op.base_rv_op, Gamma)
-        assert nonzero_dist.owner.inputs[2].data == alpha
-        assert nonzero_dist.owner.inputs[3].eval() == 1 / beta
+        assert nonzero_dist.owner.inputs[-4].data == alpha
+        assert nonzero_dist.owner.inputs[-3].eval() == 1 / beta
 
     def test_hurdle_lognormal_graph(self):
         psi, mu, sigma = 0.1, 2, 2.5
         dist = HurdleLogNormal.dist(psi=psi, mu=mu, sigma=sigma)
         _, nonzero_dist = self.check_hurdle_mixture_graph(dist)
 
         assert isinstance(nonzero_dist.owner.op.base_rv_op, LogNormal)
-        assert nonzero_dist.owner.inputs[2].data == mu
-        assert nonzero_dist.owner.inputs[3].data == sigma
+        assert nonzero_dist.owner.inputs[-4].data == mu
+        assert nonzero_dist.owner.inputs[-3].data == sigma
 
     @pytest.mark.parametrize(
         "dist, psi, non_psi_args",
