Commit 7002018

Implement shape inference for boolean advanced indexing
1 parent d108ebb commit 7002018

4 files changed (+160, -41 lines)

pytensor/tensor/extra_ops.py

Lines changed: 40 additions & 17 deletions
@@ -21,14 +21,18 @@
 from pytensor.raise_op import Assert
 from pytensor.scalar import int32 as int_t
 from pytensor.scalar import upcast
+from pytensor.tensor import as_tensor_variable
 from pytensor.tensor import basic as at
 from pytensor.tensor import get_vector_length
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.math import abs as at_abs
+from pytensor.tensor.math import abs as pt_abs
 from pytensor.tensor.math import all as pt_all
 from pytensor.tensor.math import eq as pt_eq
-from pytensor.tensor.math import ge, lt, maximum, minimum, prod
+from pytensor.tensor.math import ge, lt
+from pytensor.tensor.math import max as pt_max
+from pytensor.tensor.math import maximum, minimum, prod
 from pytensor.tensor.math import sum as at_sum
+from pytensor.tensor.math import switch
 from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from pytensor.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
 from pytensor.tensor.var import TensorVariable
@@ -1063,7 +1067,7 @@ def grad(self, inp, cost_grad):
         # only valid for matrices
         wr_a = fill_diagonal_offset(grad, 0, offset)

-        offset_abs = at_abs(offset)
+        offset_abs = pt_abs(offset)
         pos_offset_flag = ge(offset, 0)
         neg_offset_flag = lt(offset, 0)
         min_wh = minimum(width, height)
@@ -1442,6 +1446,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
     "axes that have a statically known length 1. Use `specify_broadcastable` to "
     "inform PyTensor of a known shape."
 )
+_runtime_broadcast_assert = Assert("Could not broadcast dimensions.")


 def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
@@ -1465,6 +1470,7 @@ def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
 def broadcast_shape_iter(
     arrays: Iterable[Union[TensorVariable, Tuple[TensorVariable, ...]]],
     arrays_are_shapes: bool = False,
+    allow_runtime_broadcast: bool = False,
 ) -> Tuple[aes.ScalarVariable, ...]:
     r"""Compute the shape resulting from broadcasting arrays.

@@ -1480,22 +1486,24 @@
     arrays
         An iterable of tensors, or a tuple of shapes (as tuples),
         for which the broadcast shape is computed.
-    arrays_are_shapes
+    arrays_are_shapes: bool, default False
         Indicates whether or not the `arrays` contains shape tuples.
         If you use this approach, make sure that the broadcastable dimensions
         are (scalar) constants with the value ``1``--or simply the integer
-        ``1``.
+        ``1``. This is not relevant if `allow_runtime_broadcast` is True.
+    allow_runtime_broadcast: bool, default False
+        Whether to allow non-statically known broadcasting in the shape computation.

     """
-    one_at = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)
+    one = pytensor.scalar.ScalarConstant(pytensor.scalar.int64, 1)

     if arrays_are_shapes:
         max_dims = max(len(a) for a in arrays)

         array_shapes = [
-            (one_at,) * (max_dims - len(a))
+            (one,) * (max_dims - len(a))
             + tuple(
-                one_at
+                one
                 if sh == 1 or isinstance(sh, Constant) and sh.value == 1
                 else (aes.as_scalar(sh) if not isinstance(sh, Variable) else sh)
                 for sh in a
@@ -1508,10 +1516,8 @@ def broadcast_shape_iter(
         _arrays = tuple(at.as_tensor_variable(a) for a in arrays)

         array_shapes = [
-            (one_at,) * (max_dims - a.ndim)
-            + tuple(
-                one_at if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape)
-            )
+            (one,) * (max_dims - a.ndim)
+            + tuple(one if t_sh == 1 else sh for sh, t_sh in zip(a.shape, a.type.shape))
             for a in _arrays
         ]

@@ -1520,11 +1526,11 @@
     for dim_shapes in zip(*array_shapes):
         # Get the shapes in this dimension that are not broadcastable
         # (i.e. not symbolically known to be broadcastable)
-        non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
+        non_bcast_shapes = [shape for shape in dim_shapes if shape != one]

         if len(non_bcast_shapes) == 0:
             # Every shape was broadcastable in this dimension
-            result_dims.append(one_at)
+            result_dims.append(one)
         elif len(non_bcast_shapes) == 1:
             # Only one shape might not be broadcastable in this dimension
             result_dims.extend(non_bcast_shapes)
@@ -1554,9 +1560,26 @@
                 result_dims.append(first_length)
                 continue

-            # Add assert that all remaining shapes are equal
-            condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
-            result_dims.append(_broadcast_assert(first_length, condition))
+            if not allow_runtime_broadcast:
+                # Add assert that all remaining shapes are equal
+                condition = pt_all(
+                    [pt_eq(first_length, other) for other in other_lengths]
+                )
+                result_dims.append(_broadcast_assert(first_length, condition))
+            else:
+                lengths = as_tensor_variable((first_length, *other_lengths))
+                runtime_broadcastable = pt_eq(lengths, one)
+                result_dim = pt_abs(
+                    pt_max(switch(runtime_broadcastable, -one, lengths))
+                )
+                condition = pt_all(
+                    switch(
+                        ~runtime_broadcastable,
+                        pt_eq(lengths, result_dim),
+                        np.array(True),
+                    )
+                )
+                result_dims.append(_runtime_broadcast_assert(result_dim, condition))

     return tuple(result_dims)
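
Note: for intuition, the `allow_runtime_broadcast` branch above computes each
output dimension numerically instead of asserting static equality. A plain-NumPy
sketch of the same arithmetic (the helper `runtime_broadcast_dim` is hypothetical,
not part of this commit):

    import numpy as np

    def runtime_broadcast_dim(lengths):
        # Lengths equal to 1 may broadcast at runtime; map them to -1 so that
        # max() picks out any non-broadcastable length, and abs() recovers 1
        # when every length was broadcastable (max of all -1s is -1).
        lengths = np.asarray(lengths)
        runtime_broadcastable = lengths == 1
        result_dim = np.abs(np.max(np.where(runtime_broadcastable, -1, lengths)))
        # Every non-broadcastable length must equal the result, mirroring the
        # graph-level Assert("Could not broadcast dimensions.")
        assert np.all(np.where(~runtime_broadcastable, lengths == result_dim, True))
        return result_dim

    runtime_broadcast_dim([1, 5, 1])  # -> 5
    runtime_broadcast_dim([1, 1, 1])  # -> 1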

pytensor/tensor/subtensor.py

Lines changed: 41 additions & 20 deletions
@@ -20,15 +20,11 @@
 from pytensor.printing import Printer, pprint, set_precedence
 from pytensor.scalar.basic import ScalarConstant
 from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
-from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
+from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value, nonzero
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.exceptions import (
-    AdvancedIndexingError,
-    NotScalarConstantError,
-    ShapeError,
-)
+from pytensor.tensor.exceptions import AdvancedIndexingError, NotScalarConstantError
 from pytensor.tensor.math import clip
-from pytensor.tensor.shape import Reshape, specify_broadcastable
+from pytensor.tensor.shape import Reshape, shape_i, specify_broadcastable
 from pytensor.tensor.type import (
     TensorType,
     bscalar,
@@ -510,7 +506,11 @@ def indexed_result_shape(array_shape, indices, indices_are_shapes=False):
             from pytensor.tensor.extra_ops import broadcast_shape

             res_shape += broadcast_shape(
-                *grp_indices, arrays_are_shapes=indices_are_shapes
+                *grp_indices,
+                arrays_are_shapes=indices_are_shapes,
+                # The AdvancedIndexing Op relies on the NumPy implementation, which allows runtime broadcasting.
+                # As long as that is true, the shape inference has to respect that this is not an error.
+                allow_runtime_broadcast=True,
             )

     res_shape += tuple(array_shape[dim] for dim in remaining_dims)
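
Note: the comment added above refers to NumPy's runtime behavior: index arrays
within one advanced-index group broadcast against each other even when a
length-1 index is not statically known to be broadcastable. A small NumPy
illustration:

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    rows = np.array([True, False])        # nonzero() yields 1 index
    cols = np.array([True, False, True])  # nonzero() yields 2 indices
    # The length-1 row index broadcasts against the length-2 column index
    # at runtime, so shape inference must not treat this as an error:
    x[rows, cols]  # -> array([0, 2])
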
@@ -2584,26 +2584,47 @@ def R_op(self, inputs, eval_points):
         return self.make_node(eval_points[0], *inputs[1:]).outputs

     def infer_shape(self, fgraph, node, ishapes):
-        indices = node.inputs[1:]
-        index_shapes = list(ishapes[1:])
-        for i, idx in enumerate(indices):
-            if (
+        def is_bool_index(idx):
+            return (
                 isinstance(idx, (np.bool_, bool))
                 or getattr(idx, "dtype", None) == "bool"
-            ):
-                raise ShapeError(
-                    "Shape inference for boolean indices is not implemented"
+            )
+
+        indices = node.inputs[1:]
+        index_shapes = []
+        for idx, ishape in zip(indices, ishapes[1:]):
+            # Mixed bool indexes are converted to nonzero entries
+            if is_bool_index(idx):
+                index_shapes.extend(
+                    (shape_i(nz_dim, 0, fgraph=fgraph),) for nz_dim in nonzero(idx)
                 )
             # The `ishapes` entries for `SliceType`s will be None, and
             # we need to give `indexed_result_shape` the actual slices.
-            if isinstance(getattr(idx, "type", None), SliceType):
-                index_shapes[i] = idx
+            elif isinstance(getattr(idx, "type", None), SliceType):
+                index_shapes.append(idx)
+            else:
+                index_shapes.append(ishape)

-        res_shape = indexed_result_shape(
-            ishapes[0], index_shapes, indices_are_shapes=True
+        res_shape = list(
+            indexed_result_shape(ishapes[0], index_shapes, indices_are_shapes=True)
         )
+
+        adv_indices = [idx for idx in indices if not is_basic_idx(idx)]
+        bool_indices = [idx for idx in adv_indices if is_bool_index(idx)]
+
+        # Special logic when the only advanced index group is of bool type.
+        # We can replace the nonzeros by a sum of the whole bool variable.
+        if len(bool_indices) == 1 and len(adv_indices) == 1:
+            [bool_index] = bool_indices
+            # Find the output dim associated with the bool index group
+            # Because there are no more advanced index groups, there is exactly
+            # one output dim per index variable up to the bool group.
+            # Note: Scalar integer indexing counts as advanced indexing.
+            start_dim = indices.index(bool_index)
+            res_shape[start_dim] = bool_index.sum()
+
         assert node.outputs[0].ndim == len(res_shape)
-        return [list(res_shape)]
+        return [res_shape]

     def perform(self, node, inputs, out_):
         (out,) = out_
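
Note: a minimal NumPy check of the single-boolean-group shortcut used above. The
output dim contributed by the mask equals its number of True entries, so
`infer_shape` can use `mask.sum()` instead of the lengths of its `nonzero()`
results:

    import numpy as np

    x = np.arange(6).reshape(2, 3)
    mask = x > 2
    assert x[mask].shape == (mask.sum(),)  # (3,)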

tests/tensor/test_extra_ops.py

Lines changed: 9 additions & 1 deletion
@@ -1087,9 +1087,17 @@ def shape_tuple(x, use_bcast=True):
     assert any(
         isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
     )
-    # This should fail because it would need dynamic broadcasting
     with pytest.raises(AssertionError):
         assert np.array_equal([z.eval() for z in b_at], b.shape)
+    # But fine if we allow_runtime_broadcast
+    b_at = broadcast_shape(
+        shape_tuple(x_at, use_bcast=False),
+        shape_tuple(y_at, use_bcast=False),
+        arrays_are_shapes=True,
+        allow_runtime_broadcast=True,
+    )
+    assert np.array_equal([z.eval() for z in b_at], b.shape)
+    # Or if static bcast is known
     b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
     assert np.array_equal([z.eval() for z in b_at], b.shape)
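
Note: a sketch of what this test exercises from user code (assuming the usual
`pytensor.tensor as pt` import convention):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.tensor.extra_ops import broadcast_shape

    x = pt.vector("x")
    y = pt.vector("y")
    # Neither input is statically known to be broadcastable, but with
    # allow_runtime_broadcast=True the dim is resolved at runtime:
    (dim,) = broadcast_shape(x, y, allow_runtime_broadcast=True)
    dim.eval({x: np.ones(1), y: np.ones(5)})  # -> 5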

tests/tensor/test_subtensor.py

Lines changed: 70 additions & 3 deletions
@@ -63,6 +63,7 @@
     tensor,
     tensor3,
     tensor4,
+    tensor5,
     vector,
 )
 from pytensor.tensor.type_other import NoneConst, SliceConstant, make_slice, slicetype
@@ -2150,6 +2151,12 @@ def fun(x, y):


 class TestInferShape(utt.InferShapeTester):
+    @staticmethod
+    def random_bool_mask(shape, rng=None):
+        if rng is None:
+            rng = np.random.default_rng()
+        return rng.binomial(n=1, p=0.5, size=shape).astype(bool)
+
     def test_IncSubtensor(self):
         admat = dmatrix()
         bdmat = dmatrix()
@@ -2439,25 +2446,85 @@ def test_AdvancedSubtensor_bool(self):
         n = dmatrix()
         n_val = np.arange(6).reshape((2, 3))

-        # infer_shape is not implemented, but it should not crash
+        # Shape inference requires runtime broadcasting between the nonzero() shapes
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2, n[0, :] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
         )
         self._compile_and_check(
             [n],
             [n[n[:, 0] > 2]],
             [n_val],
             AdvancedSubtensor,
-            check_topo=False,
+        )
+        self._compile_and_check(
+            [n],
+            [n[:, np.array([True, False, True])]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([False, False]), 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[np.array([True, True]), 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[self.random_bool_mask(n_val.shape)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[None, self.random_bool_mask(n_val.shape), None]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            [n[slice(5, None), self.random_bool_mask(n_val.shape[1])]],
+            [n_val],
+            AdvancedSubtensor,
         )

         abs_res = n[~isinf(n)]
         assert abs_res.type.shape == (None,)

+    def test_AdvancedSubtensor_bool_mixed(self):
+        n = tensor5("x", dtype="float64")
+        shape = (18, 3, 4, 5, 6)
+        n_val = np.arange(np.prod(shape)).reshape(shape)
+        self._compile_and_check(
+            [n],
+            # Consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 0, 1:]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3, 4)), 1:, 0]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+        self._compile_and_check(
+            [n],
+            # Non-consecutive advanced index
+            [n[1:, self.random_bool_mask((3,)), 1:, None, np.zeros((6, 1), dtype=int)]],
+            [n_val],
+            AdvancedSubtensor,
+        )
+

 @config.change_flags(compute_test_value="raise")
 def test_basic_shape():
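
Note: background for the consecutive/non-consecutive comments in the new test.
NumPy keeps the broadcasted advanced-index dims in place only when the advanced
indices (including scalar integers) are adjacent; otherwise it moves them to the
front of the result. With a deterministic all-True mask:

    import numpy as np

    x = np.arange(18 * 3 * 4 * 5 * 6).reshape(18, 3, 4, 5, 6)
    mask = np.ones((3, 4), dtype=bool)  # 12 True entries

    # mask (dims 1-2) and the scalar 0 (dim 3) are adjacent, so the
    # advanced group's dim stays in place:
    x[1:, mask, 0, 1:].shape  # -> (17, 12, 5)

    # A slice separates mask from the scalar 0 (now dim 4), so the
    # advanced group's dim moves to the front:
    x[1:, mask, 1:, 0].shape  # -> (12, 17, 4)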
