
Commit 7a82a3f

Fix advanced indexing in subtensor_rv_lift
Also excludes the following cases:
1. expand_dims via broadcasting
2. multi-dimensional integer indexing (which could lead to duplicates, inconsistent with the lifted RV graph)
1 parent 9df54a5
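
The snippet below is an illustrative sketch (not part of the commit) of why case 2 is excluded; the variable names are hypothetical:

    import numpy as np
    from pytensor.tensor.random.basic import normal

    x = normal(size=(3,))      # a vector of 3 iid draws
    y = x[np.array([0, 0])]    # repeats the *same* draw twice
    # Lifting the index into the RV (e.g. normal(size=(2,))) would instead
    # produce two independent draws, so the two graphs are not equivalent.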

2 files changed: +328 -93 lines
pytensor/tensor/random/rewriting/basic.py (+111 -61)
@@ -1,14 +1,15 @@
-from itertools import zip_longest
+from itertools import chain
 
 from pytensor.compile import optdb
 from pytensor.configdefaults import config
+from pytensor.graph import ancestors
 from pytensor.graph.op import compute_test_value
-from pytensor.graph.rewriting.basic import in2out, node_rewriter
+from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter
+from pytensor.scalar import integer_types
 from pytensor.tensor import NoneConst
 from pytensor.tensor.basic import constant, get_vector_length
 from pytensor.tensor.elemwise import DimShuffle
 from pytensor.tensor.extra_ops import broadcast_to
-from pytensor.tensor.math import sum as at_sum
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import broadcast_params
 from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
@@ -18,7 +19,6 @@
     Subtensor,
     as_index_variable,
     get_idx_list,
-    indexed_result_shape,
 )
 from pytensor.tensor.type_other import SliceType
 
@@ -207,98 +207,148 @@ def local_subtensor_rv_lift(fgraph, node):
     ``mvnormal(mu, cov, size=(2,))[0, 0]``.
     """
 
-    st_op = node.op
+    def is_nd_advanced_idx(idx, dtype):
+        if isinstance(dtype, str):
+            return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1)
+        else:
+            return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1)
 
-    if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
-        return False
+    subtensor_op = node.op
 
+    old_subtensor = node.outputs[0]
     rv = node.inputs[0]
     rv_node = rv.owner
 
     if not (rv_node and isinstance(rv_node.op, RandomVariable)):
         return False
 
+    shape_feature = getattr(fgraph, "shape_feature", None)
+    if not shape_feature:
+        return None
+
+    # Use shape_feature to facilitate inferring final shape.
+    # Check that neither the RV nor the old Subtensor are in the shape graph.
+    output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None)
+    if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)):
+        return None
+
     rv_op = rv_node.op
     rng, size, dtype, *dist_params = rv_node.inputs
 
     # Parse indices
-    idx_list = getattr(st_op, "idx_list", None)
+    idx_list = getattr(subtensor_op, "idx_list", None)
     if idx_list:
-        cdata = get_idx_list(node.inputs, idx_list)
+        idx_vars = get_idx_list(node.inputs, idx_list)
     else:
-        cdata = node.inputs[1:]
-    st_indices, st_is_bool = zip(
-        *tuple(
-            (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata
-        )
-    )
+        idx_vars = node.inputs[1:]
+    indices = tuple(as_index_variable(idx) for idx in idx_vars)
+
+    # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates)
+    # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis).
+    # If we wanted to support that we could rewrite it as subtensor + dimshuffle
+    # and make use of the dimshuffle lift rewrite
+    integer_dtypes = {type.dtype for type in integer_types}
+    if any(
+        is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx)
+        for idx in indices
+    ):
+        return False
 
     # Check that indexing does not act on support dims
-    batched_ndims = rv.ndim - rv_op.ndim_supp
-    if len(st_indices) > batched_ndims:
-        # If the last indexes are just dummy `slice(None)` we discard them
-        st_is_bool = st_is_bool[:batched_ndims]
-        st_indices, supp_indices = (
-            st_indices[:batched_ndims],
-            st_indices[batched_ndims:],
+    batch_ndims = rv.ndim - rv_op.ndim_supp
+    # We decompose the boolean indexes, which makes it clear whether they act on support dims or not
+    non_bool_indices = tuple(
+        chain.from_iterable(
+            idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,)
+            for idx in indices
+        )
+    )
+    if len(non_bool_indices) > batch_ndims:
+        # If the last indexes are just dummy `slice(None)` we discard them instead of quitting
+        non_bool_indices, supp_indices = (
+            non_bool_indices[:batch_ndims],
+            non_bool_indices[batch_ndims:],
         )
-        for index in supp_indices:
+        for idx in supp_indices:
             if not (
-                isinstance(index.type, SliceType)
-                and all(NoneConst.equals(i) for i in index.owner.inputs)
+                isinstance(idx.type, SliceType)
+                and all(NoneConst.equals(i) for i in idx.owner.inputs)
             ):
                 return False
+        n_discarded_idxs = len(supp_indices)
+        indices = indices[:-n_discarded_idxs]
 
     # If no one else is using the underlying `RandomVariable`, then we can
    # do this; otherwise, the graph would be internally inconsistent.
     if is_rv_used_in_graph(rv, node, fgraph):
         return False
 
     # Update the size to reflect the indexed dimensions
-    # TODO: Could use `ShapeFeature` info. We would need to be sure that
-    # `node` isn't in the results, though.
-    # if hasattr(fgraph, "shape_feature"):
-    #     output_shape = fgraph.shape_feature.shape_of(node.outputs[0])
-    # else:
-    output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices)
-    new_size_ignoring_boolean = (
-        output_shape_ignoring_bool
-        if rv_op.ndim_supp == 0
-        else output_shape_ignoring_bool[: -rv_op.ndim_supp]
-    )
+    new_size = output_shape[: len(output_shape) - rv_op.ndim_supp]
 
-    # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used).
-    # The `indexed_result_shape` helper does not consider this
-    if any(st_is_bool):
-        new_size = tuple(
-            at_sum(idx) if is_bool else s
-            for s, is_bool, idx in zip_longest(
-                new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False
-            )
-        )
-    else:
-        new_size = new_size_ignoring_boolean
-
-    # Update the parameters to reflect the indexed dimensions
+    # Propagate indexing to the parameters' batch dims.
+    # We try to avoid broadcasting the parameters together (and with size), by only indexing
+    # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size
+    # should still correctly broadcast any degenerate parameter dims.
     new_dist_params = []
     for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
-        # Apply indexing on the batched dimensions of the parameter
-        batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp)
-        batched_param = shape_padleft(param, batched_param_dims_missing)
-        batched_st_indices = []
-        for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape):
-            # If we have a degenerate dimension indexing it should always do the job
-            if batched_param_shape == 1:
-                batched_st_indices.append(0)
+        # We first expand any missing parameter dims (and later index them away or keep them with none-slicing)
+        batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp)
+        batch_param = (
+            shape_padleft(param, batch_param_dims_missing)
+            if batch_param_dims_missing
+            else param
+        )
+        # Check which dims are actually broadcasted
+        bcast_batch_param_dims = tuple(
+            dim
+            for dim, (param_dim, output_dim) in enumerate(
+                zip(batch_param.type.shape, rv.type.shape)
+            )
+            if (param_dim == 1) and (output_dim != 1)
+        )
+        batch_indices = []
+        curr_dim = 0
+        for idx in indices:
+            # Advanced boolean indexing
+            if is_nd_advanced_idx(idx, "bool"):
+                # Check if any broadcasted dim overlaps with advanced boolean indexing.
+                # If not, we use that directly, instead of the more inefficient `nonzero` form
+                bool_dims = range(curr_dim, curr_dim + idx.type.ndim)
+                # There's an overlap, we have to decompose the boolean mask as a `nonzero`
+                if set(bool_dims) & set(bcast_batch_param_dims):
+                    int_indices = list(idx.nonzero())
+                    # Indexing by 0 drops the degenerate dims
+                    for bool_dim in bool_dims:
+                        if bool_dim in bcast_batch_param_dims:
+                            int_indices[bool_dim - curr_dim] = 0
+                    batch_indices.extend(int_indices)
+                # No overlap, use index as is
+                else:
+                    batch_indices.append(idx)
+                curr_dim += len(bool_dims)
+            # Basic-indexing (slice or integer)
             else:
-                batched_st_indices.append(st_index)
-        new_dist_params.append(batched_param[tuple(batched_st_indices)])
+                # Broadcasted dim
+                if curr_dim in bcast_batch_param_dims:
+                    # Slice indexing, keep degenerate dim by none-slicing
+                    if isinstance(idx.type, SliceType):
+                        batch_indices.append(slice(None))
+                    # Integer indexing, drop degenerate dim by 0-indexing
+                    else:
+                        batch_indices.append(0)
+                # Non-broadcasted dim
+                else:
+                    # Use index as is
+                    batch_indices.append(idx)
                curr_dim += 1
+
+        new_dist_params.append(batch_param[tuple(batch_indices)])
 
     # Create new RV
     new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
     new_rv = new_node.default_output()
 
-    if config.compute_test_value != "off":
-        compute_test_value(new_node)
+    copy_stack_trace(rv, new_rv)
 
     return [new_rv]
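
For context, a minimal sketch (not taken from the diff, assuming the scalar normal RV) of the kind of equivalence local_subtensor_rv_lift exploits when the rewrite does apply:

    import numpy as np
    from pytensor.tensor.random.basic import normal

    mu = np.arange(3.0)

    # Before the rewrite: draw a (2, 3) batch, then keep only row 0.
    before = normal(mu, 1.0, size=(2, 3))[0]

    # After the rewrite: the index is pushed into `size` (and, where needed,
    # into the batch dims of the parameters), so only the requested row is drawn.
    after = normal(mu, 1.0, size=(3,))

Boolean masks are handled by decomposing them with nonzero(), relying on the fact that x[mask] is equivalent to x[mask.nonzero()], which lets each resulting integer index be matched against a single batch dimension.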
