@@ -17,29 +17,32 @@
 import pytensor
 import pytensor.tensor as pt
 
-from pytensor import scan
+from pytensor import config, graph_replace, scan
 from pytensor.graph import Op
 from pytensor.graph.basic import Node
 from pytensor.raise_op import CheckAndRaise
 from pytensor.scan import until
 from pytensor.tensor import TensorConstant, TensorVariable
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
+    CustomSymbolicDistRV,
     Distribution,
     SymbolicRandomVariable,
     _support_point,
     support_point,
 )
 from pymc.distributions.shape_utils import _change_dist_size, change_dist_size, to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import TruncationError
 from pymc.logprob.abstract import _logcdf, _logprob
-from pymc.logprob.basic import icdf, logcdf
+from pymc.logprob.basic import icdf, logcdf, logp
 from pymc.math import logdiffexp
+from pymc.pytensorf import collect_default_updates
 from pymc.util import check_dist_not_registered
 
 
@@ -49,11 +52,17 @@ class TruncatedRV(SymbolicRandomVariable):
     that represents a truncated univariate random variable.
     """
 
-    default_output = 1
-    base_rv_op = None
-    max_n_steps = None
-
-    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+    default_output: int = 0
+    base_rv_op: Op
+    max_n_steps: int
+
+    def __init__(
+        self,
+        *args,
+        base_rv_op: Op,
+        max_n_steps: int,
+        **kwargs,
+    ):
         self.base_rv_op = base_rv_op
         self.max_n_steps = max_n_steps
         self._print_name = (
@@ -63,8 +72,13 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
         super().__init__(*args, **kwargs)
 
     def update(self, node: Node):
-        """Return the update mapping for the internal RNG."""
-        return {node.inputs[-1]: node.outputs[0]}
+        """Return the update mapping for the internal RNGs.
+
+        TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
+        """
+        rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
+        next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
+        return dict(zip(rngs, next_rngs))
 
 
 @singledispatch
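
Note (not part of the diff): `update` pairs each RNG input of the OpFromGraph with the output carrying its successor state, so that compiled functions keep advancing the generators between calls. A minimal sketch of the mechanism this relies on, in plain PyTensor:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

# A RandomVariable node has outputs [next_rng, draw]; registering
# {rng: next_rng} as an update makes each call use fresh randomness.
rng = pytensor.shared(np.random.default_rng(123))
next_rng, draw = pt.random.normal(0, 1, rng=rng).owner.outputs
fn = pytensor.function([], draw, updates={rng: next_rng})
print(fn(), fn())  # almost surely two different draws
```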
@@ -141,10 +155,14 @@ class Truncated(Distribution):
 
     @classmethod
     def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
-        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+        if not (
+            isinstance(dist, TensorVariable)
+            and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
+        ):
             if isinstance(dist.owner.op, SymbolicRandomVariable):
                 raise NotImplementedError(
-                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
+                    f"You can try wrapping the distribution inside a CustomDist instead."
                 )
             raise ValueError(
                 f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
@@ -174,46 +192,54 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
         if size is None:
             size = pt.broadcast_shape(dist, lower, upper)
         dist = change_dist_size(dist, new_size=size)
+        rv_inputs = [
+            inp
+            if not isinstance(inp.type, RandomType)
+            else pytensor.shared(np.random.default_rng())
+            for inp in dist.owner.inputs
+        ]
+        graph_inputs = [*rv_inputs, lower, upper]
 
         # Variables with `_` suffix identify dummy inputs for the OpFromGraph
-        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
-        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        graph_inputs_ = [
+            inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
+        ]
         *rv_inputs_, lower_, upper_ = graph_inputs_
 
-        # We will use a Shared RNG variable because Scan demands it, even though it
-        # would not be necessary for the OpFromGraph inverse cdf.
-        rng = pytensor.shared(np.random.default_rng())
-        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+        rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()
 
         # Try to use inverted cdf sampling
+        # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
         try:
-            # For left truncated discrete RVs, we need to include the whole lower bound.
-            # This may result in draws below the truncation range, if any uniform == 0
-            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
-            cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
-            cdf_upper_ = pt.exp(logcdf(rv_, upper_))
-            # It's okay to reuse the same rng here, because the rng in rv_ will not be
-            # used by either the logcdf of icdf functions
+            logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
+            # We use the first RNG from the base RV, so we don't have to introduce a new one
+            # This is not problematic because the RNG won't be used in the RV logcdf graph
+            uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
             uniform_next_rng_, uniform_ = pt.random.uniform(
-                cdf_lower_,
-                cdf_upper_,
-                rng=rng,
-                size=rv_inputs_[0],
+                pt.exp(logcdf_lower_),
+                pt.exp(logcdf_upper_),
+                rng=uniform_rng_,
+                size=rv_.shape,
             ).owner.outputs
-            truncated_rv_ = icdf(rv_, uniform_)
+            truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
             return TruncatedRV(
                 base_rv_op=dist.owner.op,
-                inputs=[*graph_inputs_, rng],
-                outputs=[uniform_next_rng_, truncated_rv_],
+                inputs=graph_inputs_,
+                outputs=[truncated_rv_, uniform_next_rng_],
                 ndim_supp=0,
                 max_n_steps=max_n_steps,
-            )(*graph_inputs, rng)
+            )(*graph_inputs)
         except NotImplementedError:
             pass
 
         # Fallback to rejection sampling
-        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
-            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+        # truncated_rv = zeros(rv.shape)
+        # reject_draws = ones(rv.shape, dtype=bool)
+        # while any(reject_draws):
+        #    truncated_rv[reject_draws] = draw(rv)[reject_draws]
+        #    reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
+        def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
+            new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
             # Avoid scalar boolean indexing
             if truncated_rv.type.ndim == 0:
                 truncated_rv = new_truncated_rv
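
As a sanity check of the inverse-CDF identity the new comment documents (illustrative NumPy/SciPy, not part of the change): if u ~ Uniform(F(lower), F(upper)), then F^-1(u) follows the base distribution truncated to [lower, upper].

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
lower, upper = -1.0, 2.0
# Draw uniforms between the CDF values at the bounds, then invert the CDF
u = rng.uniform(stats.norm.cdf(lower), stats.norm.cdf(upper), size=50_000)
samples = stats.norm.ppf(u)

assert samples.min() >= lower and samples.max() <= upper
# Compare against scipy's reference truncated normal
ref = stats.truncnorm(lower, upper).rvs(50_000, random_state=rng)
print(samples.mean(), ref.mean())  # should agree to ~2 decimal places
```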
@@ -226,7 +252,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
 
             return (
                 (truncated_rv, reject_draws),
-                [(rng, next_rng)],
+                collect_default_updates(new_truncated_rv),
                 until(~pt.any(reject_draws)),
             )
 
@@ -236,7 +262,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
                 pt.zeros_like(rv_),
                 pt.ones_like(rv_, dtype=bool),
             ],
-            non_sequences=[lower_, upper_, *rv_inputs_],
+            non_sequences=[lower_, upper_, *rv_inputs_],
            n_steps=max_n_steps,
             strict=True,
         )
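
`collect_default_updates` is what lets the loop drop the hand-threaded `rng`: it walks the graph it is given and assembles the `{rng: next_rng}` pairs for every RNG it finds. A minimal sketch of the behaviour this relies on (simplified usage; see `pymc.pytensorf` for the full signature):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.pytensorf import collect_default_updates

rng = pytensor.shared(np.random.default_rng(1))
x = pt.random.normal(rng=rng)
updates = collect_default_updates(x)  # {rng: next_rng}
assert list(updates.keys()) == [rng]
```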
@@ -246,24 +272,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
         truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
             truncated_rv_, convergence_
         )
+        # Sort updates of each RNG so that they show in the same order as the input RNGs
+
+        def sort_updates(update):
+            rng, next_rng = update
+            return graph_inputs.index(rng)
+
+        next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]
 
-        [next_rng] = updates.values()
         return TruncatedRV(
             base_rv_op=dist.owner.op,
-            inputs=[*graph_inputs_, rng],
-            outputs=[next_rng, truncated_rv_],
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, *next_rngs],
             ndim_supp=0,
             max_n_steps=max_n_steps,
-        )(*graph_inputs, rng)
+        )(*graph_inputs)
+
+    @staticmethod
+    def _create_logcdf_exprs(
+        base_rv: TensorVariable,
+        value: TensorVariable,
+        lower: TensorVariable,
+        upper: TensorVariable,
+    ) -> tuple[TensorVariable, TensorVariable]:
+        """Create lower and upper logcdf expressions for base_rv.
+
+        Uses `value` as a template for broadcasting.
+        """
+        # For left truncated discrete RVs, we need to include the whole lower bound.
+        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
+        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
+        upper_value = pt.full_like(value, upper, dtype=config.floatX)
+        lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
+        upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
+        return lower_logcdf, upper_logcdf
 
 
 @_change_dist_size.register(TruncatedRV)
-def change_truncated_size(op, dist, new_size, expand):
-    *rv_inputs, lower, upper, rng = dist.owner.inputs
-    # Recreate the original untruncated RV
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
+    *rv_inputs, lower, upper = truncated_rv.owner.inputs
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+
     if expand:
-        new_size = to_tuple(new_size) + tuple(dist.shape)
+        new_size = to_tuple(new_size) + tuple(truncated_rv.shape)
 
     return Truncated.rv_op(
         untruncated_rv,
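
`_create_logcdf_exprs` derives the logcdf graph once at the lower bound and then clones it with `graph_replace`, swapping in the upper bound, instead of running logcdf inference twice. A toy illustration of the cloning primitive (not from the diff):

```python
import pytensor.tensor as pt
from pytensor import graph_replace

x = pt.scalar("x")
expr_at_x = pt.exp(-(x**2))  # stand-in for logcdf(base_rv, lower_value)
y = pt.scalar("y")
# Clone the graph with x replaced by y, no re-derivation needed
expr_at_y = graph_replace(expr_at_x, {x: y})
print(expr_at_y.eval({y: 0.0}))  # 1.0
```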
@@ -275,11 +326,11 @@ def change_truncated_size(op, dist, new_size, expand):
 
 
 @_support_point.register(TruncatedRV)
-def truncated_support_point(op, rv, *inputs):
-    *rv_inputs, lower, upper, rng = inputs
+def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
+    *rv_inputs, lower, upper = inputs
 
     # recreate untruncated rv and respective support_point
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
     untruncated_support_point = support_point(untruncated_rv)
 
     fallback_support_point = pt.switch(
@@ -300,31 +351,25 @@ def truncated_support_point(op, rv, *inputs):
 
 
 @_default_transform.register(TruncatedRV)
-def truncated_default_transform(op, rv):
+def truncated_default_transform(op, truncated_rv):
     # Don't transform discrete truncated distributions
-    if op.base_rv_op.dtype.startswith("int"):
+    if truncated_rv.type.dtype.startswith("int"):
         return None
-    # Lower and Upper are the arguments -3 and -2
-    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+    # Lower and Upper are the arguments -2 and -1
+    return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
 
 
 @_logprob.register(TruncatedRV)
 def truncated_logprob(op, values, *inputs, **kwargs):
     (value,) = values
-
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
+    *rv_inputs, lower, upper = inputs
 
     base_rv_op = op.base_rv_op
-    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
-    # For left truncated RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
-
+    base_rv = base_rv_op.make_node(*rv_inputs).default_output()
+    base_logp = logp(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
     if base_rv_op.name:
-        logp.name = f"{base_rv_op}_logprob"
+        base_logp.name = f"{base_rv_op}_logprob"
         lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
         upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
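
The renames in the next hunk make the normalization explicit: `truncated_logp = base_logp - log(F(upper) - F(lower))`, with `-inf` assigned outside the bounds. A numeric cross-check against SciPy's reference truncated normal (illustrative only, not part of the change):

```python
import numpy as np
from scipy import stats

lower, upper, value = -1.0, 2.0, 0.3
lognorm = np.log(stats.norm.cdf(upper) - stats.norm.cdf(lower))
truncated_logp = stats.norm.logpdf(value) - lognorm
np.testing.assert_allclose(
    truncated_logp, stats.truncnorm(lower, upper).logpdf(value)
)
```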
@@ -339,37 +384,31 @@ def truncated_logprob(op, values, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf
 
-    logp = logp - lognorm
+    truncated_logp = base_logp - lognorm
 
     if is_lower_bounded:
-        logp = pt.switch(value < lower, -np.inf, logp)
+        truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)
 
     if is_upper_bounded:
-        logp = pt.switch(value <= upper, logp, -np.inf)
+        truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)
 
     if is_lower_bounded and is_upper_bounded:
-        logp = check_parameters(
-            logp,
+        truncated_logp = check_parameters(
+            truncated_logp,
             pt.le(lower, upper),
             msg="lower_bound <= upper_bound",
         )
 
-    return logp
+    return truncated_logp
 
 
 @_logcdf.register(TruncatedRV)
-def truncated_logcdf(op, value, *inputs, **kwargs):
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
-
-    base_rv_op = op.base_rv_op
-    logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
+def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
+    *rv_inputs, lower, upper = inputs
 
-    # For left truncated discrete RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+    base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+    base_logcdf = logcdf(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
 
     is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
     is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
@@ -382,7 +421,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf
 
-    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
+    logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
     logcdf_trunc = logcdf_numerator - lognorm
 
     if is_lower_bounded:
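
The final hunk applies the same renaming to the CDF path, which computes `log((F(x) - F(lower)) / (F(upper) - F(lower)))`, using `logdiffexp` for numerical stability in log space. An illustrative check of that identity (not from the diff):

```python
import numpy as np
from scipy import stats

lower, upper, value = -1.0, 2.0, 0.3
num = stats.norm.cdf(value) - stats.norm.cdf(lower)
den = stats.norm.cdf(upper) - stats.norm.cdf(lower)
np.testing.assert_allclose(
    np.log(num / den), stats.truncnorm(lower, upper).logcdf(value)
)
```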