     normalize_size_param,
 )
 from pytensor.tensor.shape import shape_tuple
-from pytensor.tensor.type import TensorType, all_dtypes
+from pytensor.tensor.type import TensorType
 from pytensor.tensor.type_other import NoneConst
 from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
 from pytensor.tensor.variable import TensorVariable
@@ -64,7 +64,7 @@ def __init__(
         signature: str
             Numpy-like vectorized signature of the random variable.
         dtype: str (optional)
-            The dtype of the sampled output. If the value ``"floatX"`` is
+            The default dtype of the sampled output. If the value ``"floatX"`` is
             given, then ``dtype`` is set to ``pytensor.config.floatX``. If
             ``None`` (the default), the `dtype` keyword must be set when
             `RandomVariable.make_node` is called.
@@ -289,8 +289,8 @@ def extract_batch_shape(p, ps, n):
         return shape
 
     def infer_shape(self, fgraph, node, input_shapes):
-        _, size, _, *dist_params = node.inputs
-        _, size_shape, _, *param_shapes = input_shapes
+        _, size, *dist_params = node.inputs
+        _, size_shape, *param_shapes = input_shapes
 
         try:
             size_len = get_vector_length(size)
@@ -304,14 +304,34 @@ def infer_shape(self, fgraph, node, input_shapes):
         return [None, list(shape)]
 
     def __call__(self, *args, size=None, name=None, rng=None, dtype=None, **kwargs):
-        res = super().__call__(rng, size, dtype, *args, **kwargs)
+        if dtype is None:
+            dtype = self.dtype
+        if dtype == "floatX":
+            dtype = config.floatX
+
+        # We need to recreate the Op with the right dtype
+        if dtype != self.dtype:
+            # Check we are not switching from float to int
+            if self.dtype is not None:
+                if dtype.startswith("float") != self.dtype.startswith("float"):
+                    raise ValueError(
+                        f"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
+                    )
+            props = self._props_dict()
+            props["dtype"] = dtype
+            new_op = type(self)(**props)
+            return new_op.__call__(
+                *args, size=size, name=name, rng=rng, dtype=dtype, **kwargs
+            )
+
+        res = super().__call__(rng, size, *args, **kwargs)
 
         if name is not None:
             res.name = name
 
         return res
 
-    def make_node(self, rng, size, dtype, *dist_params):
+    def make_node(self, rng, size, *dist_params):
         """Create a random variable node.
 
         Parameters
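
The hunk above moves dtype resolution from graph-build time into `__call__`: a non-default `dtype` now rebuilds the Op through `_props_dict()` instead of becoming a symbolic input. A minimal sketch of the resulting behavior (not part of the commit), assuming the stock `normal` and `integers` RVs from `pytensor.tensor.random.basic`:

import pytensor.tensor.random.basic as ptb

x = ptb.normal(0, 1)                   # dtype=None falls back to self.dtype; "floatX" resolves to config.floatX
y = ptb.normal(0, 1, dtype="float32")  # a differing dtype rebuilds the Op via _props_dict()
assert y.dtype == "float32"

try:
    ptb.integers(0, 5, dtype="float64")  # switching the int RV to a float dtype
except ValueError as err:
    print(err)  # "Cannot change the dtype of a integers RV from int64 to float64"
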
@@ -351,22 +371,11 @@ def make_node(self, rng, size, dtype, *dist_params):
 
         shape = self._infer_shape(size, dist_params)
         _, static_shape = infer_static_shape(shape)
-        dtype = self.dtype or dtype
 
-        if dtype == "floatX":
-            dtype = config.floatX
-        elif dtype is None or (isinstance(dtype, str) and dtype not in all_dtypes):
-            raise TypeError("dtype is unspecified")
-
-        if isinstance(dtype, str):
-            dtype_idx = constant(all_dtypes.index(dtype), dtype="int64")
-        else:
-            dtype_idx = constant(dtype, dtype="int64")
-
-        dtype = all_dtypes[dtype_idx.data]
-
-        inputs = (rng, size, dtype_idx, *dist_params)
+        dtype = self.dtype
         out_var = TensorType(dtype=dtype, shape=static_shape)()
+
+        inputs = (rng, size, *dist_params)
         outputs = (rng.type(), out_var)
 
         return Apply(self, inputs, outputs)
@@ -381,12 +390,12 @@ def size_param(self, node) -> Variable:
 
     def dist_params(self, node) -> Sequence[Variable]:
         """Return the node inputs corresponding to dist params"""
-        return node.inputs[3:]
+        return node.inputs[2:]
 
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
 
-        rng, size, dtype, *args = inputs
+        rng, size, *args = inputs
 
         out_var = node.outputs[1]
 
@@ -462,7 +471,7 @@ class DefaultGeneratorMakerOp(AbstractRNGConstructor):
 
 @_vectorize_node.register(RandomVariable)
 def vectorize_random_variable(
-    op: RandomVariable, node: Apply, rng, size, dtype, *new_dist_params
+    op: RandomVariable, node: Apply, rng, size, *new_dist_params
 ) -> Apply:
     # If size was provided originally and a new size hasn't been provided,
     # we extend it to accommodate the new input batch dimensions.
@@ -494,4 +503,4 @@ def vectorize_random_variable(
         new_size_dims = new_size[:new_ndim]
         size = concatenate([new_size_dims, size])
 
-    return op.make_node(rng, size, dtype, *new_dist_params)
+    return op.make_node(rng, size, *new_dist_params)
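
With dtype removed from the inputs, a draw's owner node now holds only `(rng, size, *dist_params)`, which is why `dist_params` slices from index 2 and `perform` unpacks `(rng, size)` ahead of the parameters instead of `(rng, size, dtype)`. A short inspection sketch (not part of the commit), again assuming the stock `normal` RV:

import pytensor.tensor.random.basic as ptb

x = ptb.normal(0, 1)
rng, size, loc, scale = x.owner.inputs  # no dtype input anymore
assert list(x.owner.op.dist_params(x.owner)) == [loc, scale]
assert x.owner.op.dtype == x.dtype      # the dtype is carried by the Op itself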