Skip to content

Commit fda240f

Browse files
authored
get_scalar_constant_value now raises for non-scalar inputs (#248)
* Rename old get_scalar_constant_value to get_underlying_scalar_constant
1 parent feccc41 commit fda240f

25 files changed

+204
-143
lines changed

doc/library/tensor/basic.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise.
577577
.. method:: round(mode="half_away_from_zero")
578578
:noindex:
579579
.. method:: trace()
580-
.. method:: get_scalar_constant_value()
580+
.. method:: get_underlying_scalar_constant_value()
581581
.. method:: zeros_like(model, dtype=None)
582582

583583
All the above methods are equivalent to NumPy for PyTensor on the current tensor.

pytensor/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _as_symbolic(x, **kwargs) -> Variable:
137137
# isort: on
138138

139139

140-
def get_scalar_constant_value(v):
140+
def get_underlying_scalar_constant(v):
141141
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
142142
143143
If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
@@ -153,8 +153,8 @@ def get_scalar_constant_value(v):
153153
if sparse and isinstance(v.type, sparse.SparseTensorType):
154154
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
155155
data = v.owner.inputs[0]
156-
return tensor.get_scalar_constant_value(data)
157-
return tensor.get_scalar_constant_value(v)
156+
return tensor.get_underlying_scalar_constant_value(data)
157+
return tensor.get_underlying_scalar_constant_value(v)
158158

159159

160160
# isort: off

pytensor/gradient.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1325,7 +1325,7 @@ def try_to_copy_if_needed(var):
13251325
f" {i}. Since this input is only connected "
13261326
"to integer-valued outputs, it should "
13271327
"evaluate to zeros, but it evaluates to"
1328-
f"{pytensor.get_scalar_constant_value(term)}."
1328+
f"{pytensor.get_underlying_scalar_constant(term)}."
13291329
)
13301330
raise ValueError(msg)
13311331

@@ -2086,7 +2086,7 @@ def _is_zero(x):
20862086

20872087
no_constant_value = True
20882088
try:
2089-
constant_value = pytensor.get_scalar_constant_value(x)
2089+
constant_value = pytensor.get_underlying_scalar_constant(x)
20902090
no_constant_value = False
20912091
except pytensor.tensor.exceptions.NotScalarConstantError:
20922092
pass

pytensor/link/jax/dispatch/tensor_basic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
ScalarFromTensor,
1919
Split,
2020
TensorFromScalar,
21-
get_scalar_constant_value,
21+
get_underlying_scalar_constant_value,
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424

@@ -106,7 +106,7 @@ def join(axis, *tensors):
106106
def jax_funcify_Split(op: Split, node, **kwargs):
107107
_, axis, splits = node.inputs
108108
try:
109-
constant_axis = get_scalar_constant_value(axis)
109+
constant_axis = get_underlying_scalar_constant_value(axis)
110110
except NotScalarConstantError:
111111
constant_axis = None
112112
warnings.warn(
@@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
116116
try:
117117
constant_splits = np.array(
118118
[
119-
get_scalar_constant_value(splits[i])
119+
get_underlying_scalar_constant_value(splits[i])
120120
for i in range(get_vector_length(splits))
121121
]
122122
)

pytensor/scan/basic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from pytensor.graph.utils import MissingInputError, TestValueError
1313
from pytensor.scan.op import Scan, ScanInfo
1414
from pytensor.scan.utils import expand_empty, safe_new, until
15-
from pytensor.tensor.basic import get_scalar_constant_value
15+
from pytensor.tensor.basic import get_underlying_scalar_constant_value
1616
from pytensor.tensor.exceptions import NotScalarConstantError
1717
from pytensor.tensor.math import minimum
1818
from pytensor.tensor.shape import shape_padleft, unbroadcast
@@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x):
147147
isStr = False
148148
if not isNaN and not isInf:
149149
try:
150-
val = get_scalar_constant_value(x)
150+
val = get_underlying_scalar_constant_value(x)
151151
isInf = np.isinf(val)
152152
isNaN = np.isnan(val)
153153
except Exception:
@@ -476,7 +476,7 @@ def wrap_into_list(x):
476476
n_fixed_steps = int(n_steps)
477477
else:
478478
try:
479-
n_fixed_steps = at.get_scalar_constant_value(n_steps)
479+
n_fixed_steps = at.get_underlying_scalar_constant_value(n_steps)
480480
except NotScalarConstantError:
481481
n_fixed_steps = None
482482

pytensor/scan/rewriting.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@
4949
safe_new,
5050
scan_can_remove_outs,
5151
)
52-
from pytensor.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
52+
from pytensor.tensor.basic import (
53+
Alloc,
54+
AllocEmpty,
55+
get_underlying_scalar_constant_value,
56+
)
5357
from pytensor.tensor.elemwise import DimShuffle, Elemwise
5458
from pytensor.tensor.exceptions import NotScalarConstantError
5559
from pytensor.tensor.math import Dot, dot, maximum, minimum
@@ -1956,13 +1960,13 @@ def belongs_to_set(self, node, set_nodes):
19561960

19571961
nsteps = node.inputs[0]
19581962
try:
1959-
nsteps = int(get_scalar_constant_value(nsteps))
1963+
nsteps = int(get_underlying_scalar_constant_value(nsteps))
19601964
except NotScalarConstantError:
19611965
pass
19621966

19631967
rep_nsteps = rep.inputs[0]
19641968
try:
1965-
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
1969+
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
19661970
except NotScalarConstantError:
19671971
pass
19681972

pytensor/tensor/basic.py

+37-16
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,26 @@ def _obj_is_wrappable_as_tensor(x):
256256

257257

258258
def get_scalar_constant_value(
259+
v, elemwise=True, only_process_constants=False, max_recur=10
260+
):
261+
"""
262+
Checks whether 'v' is a scalar (ndim = 0).
263+
264+
If 'v' is a scalar then this function fetches the underlying constant by calling
265+
'get_underlying_scalar_constant_value()'.
266+
267+
If 'v' is not a scalar, it raises a NotScalarConstantError.
268+
269+
"""
270+
if isinstance(v, (Variable, np.ndarray)):
271+
if v.ndim != 0:
272+
raise NotScalarConstantError()
273+
return get_underlying_scalar_constant_value(
274+
v, elemwise, only_process_constants, max_recur
275+
)
276+
277+
278+
def get_underlying_scalar_constant_value(
259279
orig_v, elemwise=True, only_process_constants=False, max_recur=10
260280
):
261281
"""Return the constant scalar(0-D) value underlying variable `v`.
@@ -358,7 +378,7 @@ def get_scalar_constant_value(
358378
elif isinstance(v.owner.op, CheckAndRaise):
359379
# check if all conditions are constant and true
360380
conds = [
361-
get_scalar_constant_value(c, max_recur=max_recur)
381+
get_underlying_scalar_constant_value(c, max_recur=max_recur)
362382
for c in v.owner.inputs[1:]
363383
]
364384
if builtins.all(0 == c.ndim and c != 0 for c in conds):
@@ -372,7 +392,7 @@ def get_scalar_constant_value(
372392
continue
373393
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
374394
const = [
375-
get_scalar_constant_value(i, max_recur=max_recur)
395+
get_underlying_scalar_constant_value(i, max_recur=max_recur)
376396
for i in v.owner.inputs
377397
]
378398
ret = [[None]]
@@ -391,7 +411,7 @@ def get_scalar_constant_value(
391411
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
392412
):
393413
const = [
394-
get_scalar_constant_value(i, max_recur=max_recur)
414+
get_underlying_scalar_constant_value(i, max_recur=max_recur)
395415
for i in v.owner.inputs
396416
]
397417
ret = [[None]]
@@ -437,7 +457,7 @@ def get_scalar_constant_value(
437457
):
438458
idx = v.owner.op.idx_list[0]
439459
if isinstance(idx, Type):
440-
idx = get_scalar_constant_value(
460+
idx = get_underlying_scalar_constant_value(
441461
v.owner.inputs[1], max_recur=max_recur
442462
)
443463
try:
@@ -471,14 +491,14 @@ def get_scalar_constant_value(
471491
):
472492
idx = v.owner.op.idx_list[0]
473493
if isinstance(idx, Type):
474-
idx = get_scalar_constant_value(
494+
idx = get_underlying_scalar_constant_value(
475495
v.owner.inputs[1], max_recur=max_recur
476496
)
477497
# Python 2.4 does not support indexing with numpy.integer
478498
# So we cast it.
479499
idx = int(idx)
480500
ret = v.owner.inputs[0].owner.inputs[idx]
481-
ret = get_scalar_constant_value(ret, max_recur=max_recur)
501+
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
482502
# MakeVector can cast implicitly its input in some case.
483503
return _asarray(ret, dtype=v.type.dtype)
484504

@@ -493,7 +513,7 @@ def get_scalar_constant_value(
493513
idx_list = op.idx_list
494514
idx = idx_list[0]
495515
if isinstance(idx, Type):
496-
idx = get_scalar_constant_value(
516+
idx = get_underlying_scalar_constant_value(
497517
owner.inputs[1], max_recur=max_recur
498518
)
499519
grandparent = leftmost_parent.owner.inputs[0]
@@ -508,7 +528,7 @@ def get_scalar_constant_value(
508528

509529
if not (idx < ndim):
510530
msg = (
511-
"get_scalar_constant_value detected "
531+
"get_underlying_scalar_constant_value detected "
512532
f"deterministic IndexError: x.shape[{int(idx)}] "
513533
f"when x.ndim={int(ndim)}."
514534
)
@@ -1570,7 +1590,7 @@ def do_constant_folding(self, fgraph, node):
15701590
@_get_vector_length.register(Alloc)
15711591
def _get_vector_length_Alloc(var_inst, var):
15721592
try:
1573-
return get_scalar_constant_value(var.owner.inputs[1])
1593+
return get_underlying_scalar_constant_value(var.owner.inputs[1])
15741594
except NotScalarConstantError:
15751595
raise ValueError(f"Length of {var} cannot be determined")
15761596

@@ -1821,17 +1841,17 @@ def perform(self, node, inp, out_):
18211841

18221842
def extract_constant(x, elemwise=True, only_process_constants=False):
18231843
"""
1824-
This function is basically a call to tensor.get_scalar_constant_value.
1844+
This function is basically a call to tensor.get_underlying_scalar_constant_value.
18251845
18261846
The main difference is the behaviour in case of failure. While
1827-
get_scalar_constant_value raises an TypeError, this function returns x,
1847+
get_underlying_scalar_constant_value raises an TypeError, this function returns x,
18281848
as a tensor if possible. If x is a ScalarVariable from a
18291849
scalar_from_tensor, we remove the conversion. If x is just a
18301850
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
18311851
18321852
"""
18331853
try:
1834-
x = get_scalar_constant_value(x, elemwise, only_process_constants)
1854+
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
18351855
except NotScalarConstantError:
18361856
pass
18371857
if isinstance(x, aes.ScalarVariable) or isinstance(
@@ -2201,7 +2221,7 @@ def make_node(self, axis, *tensors):
22012221

22022222
if not isinstance(axis, int):
22032223
try:
2204-
axis = int(get_scalar_constant_value(axis))
2224+
axis = int(get_underlying_scalar_constant_value(axis))
22052225
except NotScalarConstantError:
22062226
pass
22072227

@@ -2450,7 +2470,7 @@ def infer_shape(self, fgraph, node, ishapes):
24502470
def _get_vector_length_Join(op, var):
24512471
axis, *arrays = var.owner.inputs
24522472
try:
2453-
axis = get_scalar_constant_value(axis)
2473+
axis = get_underlying_scalar_constant_value(axis)
24542474
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
24552475
return builtins.sum(get_vector_length(a) for a in arrays)
24562476
except NotScalarConstantError:
@@ -2862,7 +2882,7 @@ def infer_shape(self, fgraph, node, i_shapes):
28622882

28632883
def is_constant_value(var, value):
28642884
try:
2865-
v = get_scalar_constant_value(var)
2885+
v = get_underlying_scalar_constant_value(var)
28662886
return np.all(v == value)
28672887
except NotScalarConstantError:
28682888
pass
@@ -3774,7 +3794,7 @@ def make_node(self, a, choices):
37743794
static_out_shape = ()
37753795
for s in out_shape:
37763796
try:
3777-
s_val = pytensor.get_scalar_constant_value(s)
3797+
s_val = pytensor.get_underlying_scalar_constant(s)
37783798
except (NotScalarConstantError, AttributeError):
37793799
s_val = None
37803800

@@ -4095,6 +4115,7 @@ def take_along_axis(arr, indices, axis=0):
40954115
"scalar_from_tensor",
40964116
"tensor_from_scalar",
40974117
"get_scalar_constant_value",
4118+
"get_underlying_scalar_constant_value",
40984119
"constant",
40994120
"as_tensor_variable",
41004121
"as_tensor",

pytensor/tensor/blas.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node):
18341834
xv = x.dimshuffle(0)
18351835
yv = y.dimshuffle(1)
18361836
try:
1837-
bval = at.get_scalar_constant_value(b)
1837+
bval = at.get_underlying_scalar_constant_value(b)
18381838
except NotScalarConstantError:
18391839
# b isn't a constant, GEMM is doing useful pre-scaling
18401840
return

pytensor/tensor/conv/abstract_conv.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
from pytensor.graph.basic import Apply, Variable
2525
from pytensor.graph.op import Op
2626
from pytensor.raise_op import Assert
27-
from pytensor.tensor.basic import as_tensor_variable, get_scalar_constant_value
27+
from pytensor.tensor.basic import (
28+
as_tensor_variable,
29+
get_underlying_scalar_constant_value,
30+
)
2831
from pytensor.tensor.exceptions import NotScalarConstantError
2932
from pytensor.tensor.var import TensorConstant, TensorVariable
3033

@@ -495,8 +498,8 @@ def check_dim(given, computed):
495498
if given is None or computed is None:
496499
return True
497500
try:
498-
given = get_scalar_constant_value(given)
499-
computed = get_scalar_constant_value(computed)
501+
given = get_underlying_scalar_constant_value(given)
502+
computed = get_underlying_scalar_constant_value(computed)
500503
return int(given) == int(computed)
501504
except NotScalarConstantError:
502505
# no answer possible, accept for now
@@ -532,7 +535,7 @@ def assert_conv_shape(shape):
532535
out_shape = []
533536
for i, n in enumerate(shape):
534537
try:
535-
const_n = get_scalar_constant_value(n)
538+
const_n = get_underlying_scalar_constant_value(n)
536539
if i < 2:
537540
if const_n < 0:
538541
raise ValueError(
@@ -2200,7 +2203,9 @@ def __init__(
22002203
if imshp_i is not None:
22012204
# Components of imshp should be constant or ints
22022205
try:
2203-
get_scalar_constant_value(imshp_i, only_process_constants=True)
2206+
get_underlying_scalar_constant_value(
2207+
imshp_i, only_process_constants=True
2208+
)
22042209
except NotScalarConstantError:
22052210
raise ValueError(
22062211
"imshp should be None or a tuple of constant int values"
@@ -2213,7 +2218,9 @@ def __init__(
22132218
if kshp_i is not None:
22142219
# Components of kshp should be constant or ints
22152220
try:
2216-
get_scalar_constant_value(kshp_i, only_process_constants=True)
2221+
get_underlying_scalar_constant_value(
2222+
kshp_i, only_process_constants=True
2223+
)
22172224
except NotScalarConstantError:
22182225
raise ValueError(
22192226
"kshp should be None or a tuple of constant int values"

pytensor/tensor/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def perform(self, node, inputs, output_storage):
759759
ufunc = self.ufunc
760760
elif not hasattr(node.tag, "ufunc"):
761761
# It happen that make_thunk isn't called, like in
762-
# get_scalar_constant_value
762+
# get_underlying_scalar_constant_value
763763
self.prepare_node(node, None, None, "py")
764764
# prepare_node will add ufunc to self or the tag
765765
# depending if we can reuse it or not. So we need to

pytensor/tensor/exceptions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ class ShapeError(Exception):
44

55
class NotScalarConstantError(Exception):
66
"""
7-
Raised by get_scalar_constant_value if called on something that is
7+
Raised by get_underlying_scalar_constant_value if called on something that is
88
not a scalar constant.
99
"""
1010

pytensor/tensor/extra_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def make_node(self, x, repeats):
671671
out_shape = [None]
672672
else:
673673
try:
674-
const_reps = at.get_scalar_constant_value(repeats)
674+
const_reps = at.get_underlying_scalar_constant_value(repeats)
675675
except NotScalarConstantError:
676676
const_reps = None
677677
if const_reps == 1:

0 commit comments

Comments
 (0)