|
1 |
| -from itertools import zip_longest |
| 1 | +from itertools import chain |
2 | 2 |
|
3 | 3 | from pytensor.compile import optdb
|
4 | 4 | from pytensor.configdefaults import config
|
| 5 | +from pytensor.graph import ancestors |
5 | 6 | from pytensor.graph.op import compute_test_value
|
6 |
| -from pytensor.graph.rewriting.basic import in2out, node_rewriter |
| 7 | +from pytensor.graph.rewriting.basic import copy_stack_trace, in2out, node_rewriter |
| 8 | +from pytensor.scalar import integer_types |
7 | 9 | from pytensor.tensor import NoneConst
|
8 | 10 | from pytensor.tensor.basic import constant, get_vector_length
|
9 | 11 | from pytensor.tensor.elemwise import DimShuffle
|
10 | 12 | from pytensor.tensor.extra_ops import broadcast_to
|
11 |
| -from pytensor.tensor.math import sum as at_sum |
12 | 13 | from pytensor.tensor.random.op import RandomVariable
|
13 | 14 | from pytensor.tensor.random.utils import broadcast_params
|
14 | 15 | from pytensor.tensor.shape import Shape, Shape_i, shape_padleft
|
|
18 | 19 | Subtensor,
|
19 | 20 | as_index_variable,
|
20 | 21 | get_idx_list,
|
21 |
| - indexed_result_shape, |
22 | 22 | )
|
23 | 23 | from pytensor.tensor.type_other import SliceType
|
24 | 24 |
|
@@ -207,98 +207,148 @@ def local_subtensor_rv_lift(fgraph, node):
|
207 | 207 | ``mvnormal(mu, cov, size=(2,))[0, 0]``.
|
208 | 208 | """
|
209 | 209 |
|
210 |
| - st_op = node.op |
| 210 | + def is_nd_advanced_idx(idx, dtype): |
| 211 | + if isinstance(dtype, str): |
| 212 | + return (getattr(idx.type, "dtype", None) == dtype) and (idx.type.ndim >= 1) |
| 213 | + else: |
| 214 | + return (getattr(idx.type, "dtype", None) in dtype) and (idx.type.ndim >= 1) |
211 | 215 |
|
212 |
| - if not isinstance(st_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)): |
213 |
| - return False |
| 216 | + subtensor_op = node.op |
214 | 217 |
|
| 218 | + old_subtensor = node.outputs[0] |
215 | 219 | rv = node.inputs[0]
|
216 | 220 | rv_node = rv.owner
|
217 | 221 |
|
218 | 222 | if not (rv_node and isinstance(rv_node.op, RandomVariable)):
|
219 | 223 | return False
|
220 | 224 |
|
| 225 | + shape_feature = getattr(fgraph, "shape_feature", None) |
| 226 | + if not shape_feature: |
| 227 | + return None |
| 228 | + |
| 229 | + # Use shape_feature to facilitate inferring final shape. |
| 230 | + # Check that neither the RV nor the old Subtensor are in the shape graph. |
| 231 | + output_shape = fgraph.shape_feature.shape_of.get(old_subtensor, None) |
| 232 | + if output_shape is None or {old_subtensor, rv} & set(ancestors(output_shape)): |
| 233 | + return None |
| 234 | + |
221 | 235 | rv_op = rv_node.op
|
222 | 236 | rng, size, dtype, *dist_params = rv_node.inputs
|
223 | 237 |
|
224 | 238 | # Parse indices
|
225 |
| - idx_list = getattr(st_op, "idx_list", None) |
| 239 | + idx_list = getattr(subtensor_op, "idx_list", None) |
226 | 240 | if idx_list:
|
227 |
| - cdata = get_idx_list(node.inputs, idx_list) |
| 241 | + idx_vars = get_idx_list(node.inputs, idx_list) |
228 | 242 | else:
|
229 |
| - cdata = node.inputs[1:] |
230 |
| - st_indices, st_is_bool = zip( |
231 |
| - *tuple( |
232 |
| - (as_index_variable(i), getattr(i, "dtype", None) == "bool") for i in cdata |
233 |
| - ) |
234 |
| - ) |
| 243 | + idx_vars = node.inputs[1:] |
| 244 | + indices = tuple(as_index_variable(idx) for idx in idx_vars) |
| 245 | + |
| 246 | + # The rewrite doesn't apply if advanced indexing could broadcast the samples (leading to duplicates) |
| 247 | + # Note: For simplicity this also excludes subtensor-related expand_dims (np.newaxis). |
| 248 | + # If we wanted to support that we could rewrite it as subtensor + dimshuffle |
| 249 | + # and make use of the dimshuffle lift rewrite |
| 250 | + integer_dtypes = {type.dtype for type in integer_types} |
| 251 | + if any( |
| 252 | + is_nd_advanced_idx(idx, integer_dtypes) or NoneConst.equals(idx) |
| 253 | + for idx in indices |
| 254 | + ): |
| 255 | + return False |
235 | 256 |
|
236 | 257 | # Check that indexing does not act on support dims
|
237 |
| - batched_ndims = rv.ndim - rv_op.ndim_supp |
238 |
| - if len(st_indices) > batched_ndims: |
239 |
| - # If the last indexes are just dummy `slice(None)` we discard them |
240 |
| - st_is_bool = st_is_bool[:batched_ndims] |
241 |
| - st_indices, supp_indices = ( |
242 |
| - st_indices[:batched_ndims], |
243 |
| - st_indices[batched_ndims:], |
| 258 | + batch_ndims = rv.ndim - rv_op.ndim_supp |
| 259 | + # We decompose the boolean indexes, which makes it clear whether they act on support dims or not |
| 260 | + non_bool_indices = tuple( |
| 261 | + chain.from_iterable( |
| 262 | + idx.nonzero() if is_nd_advanced_idx(idx, "bool") else (idx,) |
| 263 | + for idx in indices |
| 264 | + ) |
| 265 | + ) |
| 266 | + if len(non_bool_indices) > batch_ndims: |
| 267 | + # If the last indexes are just dummy `slice(None)` we discard them instead of quitting |
| 268 | + non_bool_indices, supp_indices = ( |
| 269 | + non_bool_indices[:batch_ndims], |
| 270 | + non_bool_indices[batch_ndims:], |
244 | 271 | )
|
245 |
| - for index in supp_indices: |
| 272 | + for idx in supp_indices: |
246 | 273 | if not (
|
247 |
| - isinstance(index.type, SliceType) |
248 |
| - and all(NoneConst.equals(i) for i in index.owner.inputs) |
| 274 | + isinstance(idx.type, SliceType) |
| 275 | + and all(NoneConst.equals(i) for i in idx.owner.inputs) |
249 | 276 | ):
|
250 | 277 | return False
|
| 278 | + n_discarded_idxs = len(supp_indices) |
| 279 | + indices = indices[:-n_discarded_idxs] |
251 | 280 |
|
252 | 281 | # If no one else is using the underlying `RandomVariable`, then we can
|
253 | 282 | # do this; otherwise, the graph would be internally inconsistent.
|
254 | 283 | if is_rv_used_in_graph(rv, node, fgraph):
|
255 | 284 | return False
|
256 | 285 |
|
257 | 286 | # Update the size to reflect the indexed dimensions
|
258 |
| - # TODO: Could use `ShapeFeature` info. We would need to be sure that |
259 |
| - # `node` isn't in the results, though. |
260 |
| - # if hasattr(fgraph, "shape_feature"): |
261 |
| - # output_shape = fgraph.shape_feature.shape_of(node.outputs[0]) |
262 |
| - # else: |
263 |
| - output_shape_ignoring_bool = indexed_result_shape(rv.shape, st_indices) |
264 |
| - new_size_ignoring_boolean = ( |
265 |
| - output_shape_ignoring_bool |
266 |
| - if rv_op.ndim_supp == 0 |
267 |
| - else output_shape_ignoring_bool[: -rv_op.ndim_supp] |
268 |
| - ) |
| 287 | + new_size = output_shape[: len(output_shape) - rv_op.ndim_supp] |
269 | 288 |
|
270 |
| - # Boolean indices can actually change the `size` value (compared to just *which* dimensions of `size` are used). |
271 |
| - # The `indexed_result_shape` helper does not consider this |
272 |
| - if any(st_is_bool): |
273 |
| - new_size = tuple( |
274 |
| - at_sum(idx) if is_bool else s |
275 |
| - for s, is_bool, idx in zip_longest( |
276 |
| - new_size_ignoring_boolean, st_is_bool, st_indices, fillvalue=False |
277 |
| - ) |
278 |
| - ) |
279 |
| - else: |
280 |
| - new_size = new_size_ignoring_boolean |
281 |
| - |
282 |
| - # Update the parameters to reflect the indexed dimensions |
| 289 | + # Propagate indexing to the parameters' batch dims. |
| 290 | + # We try to avoid broadcasting the parameters together (and with size), by only indexing |
| 291 | + # non-broadcastable (non-degenerate) parameter dims. These parameters and the new size |
| 292 | + # should still correctly broadcast any degenerate parameter dims. |
283 | 293 | new_dist_params = []
|
284 | 294 | for param, param_ndim_supp in zip(dist_params, rv_op.ndims_params):
|
285 |
| - # Apply indexing on the batched dimensions of the parameter |
286 |
| - batched_param_dims_missing = batched_ndims - (param.ndim - param_ndim_supp) |
287 |
| - batched_param = shape_padleft(param, batched_param_dims_missing) |
288 |
| - batched_st_indices = [] |
289 |
| - for st_index, batched_param_shape in zip(st_indices, batched_param.type.shape): |
290 |
| - # If we have a degenerate dimension indexing it should always do the job |
291 |
| - if batched_param_shape == 1: |
292 |
| - batched_st_indices.append(0) |
| 295 | + # We first expand any missing parameter dims (and later index them away or keep them with none-slicing) |
| 296 | + batch_param_dims_missing = batch_ndims - (param.ndim - param_ndim_supp) |
| 297 | + batch_param = ( |
| 298 | + shape_padleft(param, batch_param_dims_missing) |
| 299 | + if batch_param_dims_missing |
| 300 | + else param |
| 301 | + ) |
| 302 | + # Check which dims are actually broadcasted |
| 303 | + bcast_batch_param_dims = tuple( |
| 304 | + dim |
| 305 | + for dim, (param_dim, output_dim) in enumerate( |
| 306 | + zip(batch_param.type.shape, rv.type.shape) |
| 307 | + ) |
| 308 | + if (param_dim == 1) and (output_dim != 1) |
| 309 | + ) |
| 310 | + batch_indices = [] |
| 311 | + curr_dim = 0 |
| 312 | + for idx in indices: |
| 313 | + # Advanced boolean indexing |
| 314 | + if is_nd_advanced_idx(idx, "bool"): |
| 315 | + # Check if any broadcasted dim overlaps with advanced boolean indexing. |
| 316 | + # If not, we use that directly, instead of the more inefficient `nonzero` form |
| 317 | + bool_dims = range(curr_dim, curr_dim + idx.type.ndim) |
| 318 | + # There's an overlap, we have to decompose the boolean mask as a `nonzero` |
| 319 | + if set(bool_dims) & set(bcast_batch_param_dims): |
| 320 | + int_indices = list(idx.nonzero()) |
| 321 | + # Indexing by 0 drops the degenerate dims |
| 322 | + for bool_dim in bool_dims: |
| 323 | + if bool_dim in bcast_batch_param_dims: |
| 324 | + int_indices[bool_dim - curr_dim] = 0 |
| 325 | + batch_indices.extend(int_indices) |
| 326 | + # No overlap, use index as is |
| 327 | + else: |
| 328 | + batch_indices.append(idx) |
| 329 | + curr_dim += len(bool_dims) |
| 330 | + # Basic-indexing (slice or integer) |
293 | 331 | else:
|
294 |
| - batched_st_indices.append(st_index) |
295 |
| - new_dist_params.append(batched_param[tuple(batched_st_indices)]) |
| 332 | + # Broadcasted dim |
| 333 | + if curr_dim in bcast_batch_param_dims: |
| 334 | + # Slice indexing, keep degenerate dim by none-slicing |
| 335 | + if isinstance(idx.type, SliceType): |
| 336 | + batch_indices.append(slice(None)) |
| 337 | + # Integer indexing, drop degenerate dim by 0-indexing |
| 338 | + else: |
| 339 | + batch_indices.append(0) |
| 340 | + # Non-broadcasted dim |
| 341 | + else: |
| 342 | + # Use index as is |
| 343 | + batch_indices.append(idx) |
| 344 | + curr_dim += 1 |
| 345 | + |
| 346 | + new_dist_params.append(batch_param[tuple(batch_indices)]) |
296 | 347 |
|
297 | 348 | # Create new RV
|
298 | 349 | new_node = rv_op.make_node(rng, new_size, dtype, *new_dist_params)
|
299 | 350 | new_rv = new_node.default_output()
|
300 | 351 |
|
301 |
| - if config.compute_test_value != "off": |
302 |
| - compute_test_value(new_node) |
| 352 | + copy_stack_trace(rv, new_rv) |
303 | 353 |
|
304 | 354 | return [new_rv]
|
0 commit comments