@@ -256,6 +256,26 @@ def _obj_is_wrappable_as_tensor(x):
256
256
257
257
258
258
def get_scalar_constant_value (
259
+ v , elemwise = True , only_process_constants = False , max_recur = 10
260
+ ):
261
+ """
262
+ Checks whether 'v' is a scalar (ndim = 0).
263
+
264
+ If 'v' is a scalar then this function fetches the underlying constant by calling
265
+ 'get_underlying_scalar_constant_value()'.
266
+
267
+ If 'v' is not a scalar, it raises a NotScalarConstantError.
268
+
269
+ """
270
+ if isinstance (v , (Variable , np .ndarray )):
271
+ if v .ndim != 0 :
272
+ raise NotScalarConstantError ()
273
+ return get_underlying_scalar_constant_value (
274
+ v , elemwise , only_process_constants , max_recur
275
+ )
276
+
277
+
278
+ def get_underlying_scalar_constant_value (
259
279
orig_v , elemwise = True , only_process_constants = False , max_recur = 10
260
280
):
261
281
"""Return the constant scalar(0-D) value underlying variable `v`.
@@ -358,7 +378,7 @@ def get_scalar_constant_value(
358
378
elif isinstance (v .owner .op , CheckAndRaise ):
359
379
# check if all conditions are constant and true
360
380
conds = [
361
- get_scalar_constant_value (c , max_recur = max_recur )
381
+ get_underlying_scalar_constant_value (c , max_recur = max_recur )
362
382
for c in v .owner .inputs [1 :]
363
383
]
364
384
if builtins .all (0 == c .ndim and c != 0 for c in conds ):
@@ -372,7 +392,7 @@ def get_scalar_constant_value(
372
392
continue
373
393
if isinstance (v .owner .op , _scalar_constant_value_elemwise_ops ):
374
394
const = [
375
- get_scalar_constant_value (i , max_recur = max_recur )
395
+ get_underlying_scalar_constant_value (i , max_recur = max_recur )
376
396
for i in v .owner .inputs
377
397
]
378
398
ret = [[None ]]
@@ -391,7 +411,7 @@ def get_scalar_constant_value(
391
411
v .owner .op .scalar_op , _scalar_constant_value_elemwise_ops
392
412
):
393
413
const = [
394
- get_scalar_constant_value (i , max_recur = max_recur )
414
+ get_underlying_scalar_constant_value (i , max_recur = max_recur )
395
415
for i in v .owner .inputs
396
416
]
397
417
ret = [[None ]]
@@ -437,7 +457,7 @@ def get_scalar_constant_value(
437
457
):
438
458
idx = v .owner .op .idx_list [0 ]
439
459
if isinstance (idx , Type ):
440
- idx = get_scalar_constant_value (
460
+ idx = get_underlying_scalar_constant_value (
441
461
v .owner .inputs [1 ], max_recur = max_recur
442
462
)
443
463
try :
@@ -471,14 +491,14 @@ def get_scalar_constant_value(
471
491
):
472
492
idx = v .owner .op .idx_list [0 ]
473
493
if isinstance (idx , Type ):
474
- idx = get_scalar_constant_value (
494
+ idx = get_underlying_scalar_constant_value (
475
495
v .owner .inputs [1 ], max_recur = max_recur
476
496
)
477
497
# Python 2.4 does not support indexing with numpy.integer
478
498
# So we cast it.
479
499
idx = int (idx )
480
500
ret = v .owner .inputs [0 ].owner .inputs [idx ]
481
- ret = get_scalar_constant_value (ret , max_recur = max_recur )
501
+ ret = get_underlying_scalar_constant_value (ret , max_recur = max_recur )
482
502
# MakeVector can cast implicitly its input in some case.
483
503
return _asarray (ret , dtype = v .type .dtype )
484
504
@@ -493,7 +513,7 @@ def get_scalar_constant_value(
493
513
idx_list = op .idx_list
494
514
idx = idx_list [0 ]
495
515
if isinstance (idx , Type ):
496
- idx = get_scalar_constant_value (
516
+ idx = get_underlying_scalar_constant_value (
497
517
owner .inputs [1 ], max_recur = max_recur
498
518
)
499
519
grandparent = leftmost_parent .owner .inputs [0 ]
@@ -508,7 +528,7 @@ def get_scalar_constant_value(
508
528
509
529
if not (idx < ndim ):
510
530
msg = (
511
- "get_scalar_constant_value detected "
531
+ "get_underlying_scalar_constant_value detected "
512
532
f"deterministic IndexError: x.shape[{ int (idx )} ] "
513
533
f"when x.ndim={ int (ndim )} ."
514
534
)
@@ -1570,7 +1590,7 @@ def do_constant_folding(self, fgraph, node):
1570
1590
@_get_vector_length .register (Alloc )
1571
1591
def _get_vector_length_Alloc (var_inst , var ):
1572
1592
try :
1573
- return get_scalar_constant_value (var .owner .inputs [1 ])
1593
+ return get_underlying_scalar_constant_value (var .owner .inputs [1 ])
1574
1594
except NotScalarConstantError :
1575
1595
raise ValueError (f"Length of { var } cannot be determined" )
1576
1596
@@ -1821,17 +1841,17 @@ def perform(self, node, inp, out_):
1821
1841
1822
1842
def extract_constant (x , elemwise = True , only_process_constants = False ):
1823
1843
"""
1824
- This function is basically a call to tensor.get_scalar_constant_value .
1844
+ This function is basically a call to tensor.get_underlying_scalar_constant_value .
1825
1845
1826
1846
The main difference is the behaviour in case of failure. While
1827
- get_scalar_constant_value raises an TypeError, this function returns x,
1847
+ get_underlying_scalar_constant_value raises an TypeError, this function returns x,
1828
1848
as a tensor if possible. If x is a ScalarVariable from a
1829
1849
scalar_from_tensor, we remove the conversion. If x is just a
1830
1850
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
1831
1851
1832
1852
"""
1833
1853
try :
1834
- x = get_scalar_constant_value (x , elemwise , only_process_constants )
1854
+ x = get_underlying_scalar_constant_value (x , elemwise , only_process_constants )
1835
1855
except NotScalarConstantError :
1836
1856
pass
1837
1857
if isinstance (x , aes .ScalarVariable ) or isinstance (
@@ -2201,7 +2221,7 @@ def make_node(self, axis, *tensors):
2201
2221
2202
2222
if not isinstance (axis , int ):
2203
2223
try :
2204
- axis = int (get_scalar_constant_value (axis ))
2224
+ axis = int (get_underlying_scalar_constant_value (axis ))
2205
2225
except NotScalarConstantError :
2206
2226
pass
2207
2227
@@ -2450,7 +2470,7 @@ def infer_shape(self, fgraph, node, ishapes):
2450
2470
def _get_vector_length_Join (op , var ):
2451
2471
axis , * arrays = var .owner .inputs
2452
2472
try :
2453
- axis = get_scalar_constant_value (axis )
2473
+ axis = get_underlying_scalar_constant_value (axis )
2454
2474
assert axis == 0 and builtins .all (a .ndim == 1 for a in arrays )
2455
2475
return builtins .sum (get_vector_length (a ) for a in arrays )
2456
2476
except NotScalarConstantError :
@@ -2862,7 +2882,7 @@ def infer_shape(self, fgraph, node, i_shapes):
2862
2882
2863
2883
def is_constant_value (var , value ):
2864
2884
try :
2865
- v = get_scalar_constant_value (var )
2885
+ v = get_underlying_scalar_constant_value (var )
2866
2886
return np .all (v == value )
2867
2887
except NotScalarConstantError :
2868
2888
pass
@@ -3774,7 +3794,7 @@ def make_node(self, a, choices):
3774
3794
static_out_shape = ()
3775
3795
for s in out_shape :
3776
3796
try :
3777
- s_val = pytensor .get_scalar_constant_value (s )
3797
+ s_val = pytensor .get_underlying_scalar_constant (s )
3778
3798
except (NotScalarConstantError , AttributeError ):
3779
3799
s_val = None
3780
3800
@@ -4095,6 +4115,7 @@ def take_along_axis(arr, indices, axis=0):
4095
4115
"scalar_from_tensor" ,
4096
4116
"tensor_from_scalar" ,
4097
4117
"get_scalar_constant_value" ,
4118
+ "get_underlying_scalar_constant_value" ,
4098
4119
"constant" ,
4099
4120
"as_tensor_variable" ,
4100
4121
"as_tensor" ,
0 commit comments