Commit c798cc3

Fix advanced indexing in subtensor_rv_lift

Also excludes the following cases:
1. expand_dims via broadcasting
2. multi-dimensional integer indexing (could lead to duplicates, which is inconsistent with the lifted RV graph)
1 parent a85ecce commit c798cc3
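
An editor's illustration (not part of the commit) of the second excluded case: with a non-scalar integer index that repeats an entry, indexing the drawn sample produces duplicate values, whereas lifting the index onto the distribution's parameters and drawing afterwards would produce independent values. A minimal NumPy sketch of that mismatch:

import numpy as np

rng = np.random.default_rng(0)
mu = np.array([0.0, 10.0, 20.0])
idx = np.array([0, 0, 2])  # non-scalar integer index with a repeated entry

# Original graph: draw once, then apply the advanced index.
draw = rng.normal(loc=mu)
print(draw[idx][0] == draw[idx][1])  # True: both entries come from the same draw

# Naively "lifted" graph: index the parameters first, then draw.
lifted = rng.normal(loc=mu[idx])
print(lifted[0] == lifted[1])  # almost surely False: two independent draws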

File tree

2 files changed: +201 -71 lines changed

pytensor/tensor/random/rewriting/basic.py

Lines changed: 97 additions & 59 deletions
@@ -1,14 +1,15 @@
-from itertools import zip_longest
+from itertools import chain
 
 from pytensor.compile import optdb
 from pytensor.configdefaults import config
+from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
 from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.scalar import integer_types
 from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import constant, get_vector_length
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
-from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
 from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
@@ -18,7 +19,6 @@
     Subtensor,
     as_index_variable,
     get_idx_list,
-    indexed_result_shape,
 )
 from pytensor.tensor.type_other import SliceType
 
@@ -207,98 +207,136 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
-    st_op = node.op
+    def is_nd_advanced_idx(idx, dtype):
+        if isinstance(dtype, str):
+            return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
+        else:
+            return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
 
-    if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
-        return False
+    subtensor_op = node.op
 
     rv = node.inputs[0]
     rv_node = rv.owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
+    shape_feature = getattr(fgraph, "shape_feature", None)
+    if not shape_feature:
+        return None
+
+    # Use shape_feature to facilitate inferring final shape
+    output_shape = fgraph.shape_feature.shape_of.get(node.outputs[0], None)
+    if output_shape is None or rv in ancestors(output_shape):
+        return None
+
     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
 
     # Parse indices
-    idx_list = getattr(st_op, "idx_list", None)
+    idx_list = getattr(subtensor_op, "idx_list", None)
     if idx_list:
-        cdata = get_idx_list(node.inputs, idx_list)
+        idx_vars = get_idx_list(node.inputs, idx_list)
     else:
-        cdata = node.inputs[1:]
-    st_indices, st_is_bool = zip(
-        *tuple(
-            (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
-        )
-    )
+        idx_vars = node.inputs[1:]
+    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+
+    # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
+    # Note: For simplicity this also excludes subtensor-related expand_dims (and empty indexing).
+    # If we wanted to support that we could rewrite it as subtensor + dimshuffle
+    # and make use of the dimshuffle lift rewrite
+    integer_dtypes = {type.dtype for type in integer_types}
+    if any(
+        is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
+        for idx in indices
+    ):
+        return False
 
     # Check that indexing does not act on support dims
-    batched_ndims = rv.ndim - rv_op.ndim_supp
-    if len(st_indices) > batched_ndims:
+    batch_ndims = rv.ndim - rv_op.ndim_supp
+    # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
+    non_bool_indices = tuple(
+        chain.from_iterable(
+            idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
+            for idx in indices
+        )
+    )
+    if len(non_bool_indices) > batch_ndims:
         # If the last indexes are just dummy `slice(None)` we discard them
-        st_is_bool = st_is_bool[:batched_ndims]
-        st_indices, supp_indices = (
-            st_indices[:batched_ndims],
-            st_indices[batched_ndims:],
+        non_bool_indices, supp_indices = (
+            non_bool_indices[:batch_ndims],
+            non_bool_indices[batch_ndims:],
         )
-        for index in supp_indices:
+        for idx in supp_indices:
             if not (
-                isinstance(index.type, SliceType)
-                and all(NoneConst.equals(i) for i in index.owner.inputs)
+                isinstance(idx.type, SliceType)
+                and all(NoneConst.equals(i) for i in idx.owner.inputs)
             ):
                 return False
+        n_discarded_idxs = len(supp_indices)
+        indices = indices[:-n_discarded_idxs]
 
     # If no one else is using the underlying `RandomVariable`, then we can
     # do this; otherwise, the graph would be internally inconsistent.
     if is_rv_used_in_graph(rv, node, fgraph):
         return False
 
     # Update the size to reflect the indexed dimensions
-    # TODO: Could use `ShapeFeature` info. We would need to be sure that
-    #  `node` isn't in the results, though.
-    # if hasattr(fgraph, "shape_feature"):
-    #     output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
-    # else:
-    output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
-    new_size_ignoring_boolean = (
-        output_shape_ignoring_bool
-        if rv_op.ndim_supp == 0
-        else output_shape_ignoring_bool[: -rv_op.ndim_supp]
-    )
-
-    # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
-    # The `indexed_result_shape` helper does not consider this
-    if any(st_is_bool):
-        new_size = tuple(
-            at_sum(idx) if is_bool else s
-            for s, is_bool, idx in zip_longest(
-                new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
-            )
-        )
-    else:
-        new_size = new_size_ignoring_boolean
+    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
 
     # Update the parameters to reflect the indexed dimensions
+    # We try to avoid broadcasting the parameters, by applying the index on non-broadcastable dimensions
+    # and avoiding indexing on broadcastable ones.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Apply indexing on the batched dimensions of the parameter
-        batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
-        batched_param = shape_padleft(param, batched_param_dims_missing)
-        batched_st_indices = []
-        for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
-            # If we have a degenerate dimension indexing it should always do the job
-            if batched_param_shape == 1:
-                batched_st_indices.append(0)
+        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
+        batch_param = (
+            shape_padleft(param, batch_param_dims_missing)
+            if batch_param_dims_missing
+            else param
+        )
+        # Check which dims are actually broadcasted
+        bcast_batch_param_dims = tuple(
+            dim
+            for dim, (param_dim, output_dim) in enumerate(
+                zip(batch_param.type.shape, rv.type.shape)
+            )
+            if (param_dim == 1) and (output_dim != 1)
+        )
+        batch_indices = []
+        curr_dim = 0
+        for idx in indices:
+            # Check if any broadcasted dim overlaps with advanced boolean indexing.
+            # If not, we use that directly, instead of the more inefficient `nonzero` form
+            if is_nd_advanced_idx(idx, "bool"):
+                bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
+                if set(bool_dims) & set(bcast_batch_param_dims):
+                    # There's an overlap, we have to decompose the boolean mask
+                    int_indices = list(idx.nonzero())
+                    # We index by 0 in the broadcasted dims
+                    for bool_dim in bool_dims:
+                        if bool_dim in bcast_batch_param_dims:
+                            int_indices[bool_dim - curr_dim] = 0
+                    batch_indices.extend(int_indices)
+                else:
+                    # No overlap, use as is
+                    batch_indices.append(idx)
+                curr_dim += len(bool_dims)
+            # No boolean indexing
             else:
-                batched_st_indices.append(st_index)
-        new_dist_params.append(batched_param[tuple(batched_st_indices)])
+                if curr_dim in bcast_batch_param_dims:
+                    # This degenerate dim will be dropped, we index by 0 which should work
+                    # Note: This wouldn't be the case if we allowed non-scalar integer indexing
+                    batch_indices.append(0)
+                else:
+                    # We can use the index as is
+                    batch_indices.append(idx)
+                curr_dim += 1
+
+        new_dist_params.append(batch_param[tuple(batch_indices)])
 
     # Create new RV
     new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
     new_rv = new_node.default_output()
 
-    if config.compute_test_value != "off":
-        compute_test_value(new_node)
-
     return [new_rv]

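For reference, an editor's sketch (NumPy, not part of the commit) of the two indexing identities the new parameter handling above relies on: an n-dimensional boolean mask is equivalent to indexing with its nonzero() coordinate arrays, and a broadcast (length-1) parameter dimension can simply be indexed by 0 instead of being broadcast first:

import numpy as np

x = np.arange(12).reshape(3, 4)
mask = x % 2 == 0                      # 2-D boolean mask over both dims

# Decomposing the mask into integer coordinate arrays is equivalent...
assert np.array_equal(x[mask], x[mask.nonzero()])

# ...and a broadcast (length-1) dimension of a parameter can be indexed by 0
# rather than materialising the broadcast first.
param = np.arange(4).reshape(1, 4)     # broadcasts against x along axis 0
rows, cols = mask.nonzero()
assert np.array_equal(
    np.broadcast_to(param, x.shape)[mask],
    param[0, cols],                    # index the broadcast dim by 0
)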