|
4 | 4 | from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
|
5 | 5 | from warnings import warn
|
6 | 6 |
|
| 7 | +import numpy as np |
| 8 | + |
7 | 9 | import pytensor
|
8 | 10 | import pytensor.scalar.basic as aes
|
9 | 11 | from pytensor import clone_replace, compile
|
|
28 | 30 | MakeVector,
|
29 | 31 | alloc,
|
30 | 32 | cast,
|
| 33 | + constant, |
31 | 34 | get_underlying_scalar_constant_value,
|
32 | 35 | )
|
33 | 36 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
34 | 37 | from pytensor.tensor.exceptions import NotScalarConstantError
|
35 | 38 | from pytensor.tensor.math import exp
|
36 |
| -from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize |
| 39 | +from pytensor.tensor.rewriting.basic import ( |
| 40 | + broadcast_like, |
| 41 | + register_canonicalize, |
| 42 | + register_specialize, |
| 43 | +) |
37 | 44 | from pytensor.tensor.shape import shape_padleft
|
38 |
| -from pytensor.tensor.var import TensorConstant |
| 45 | +from pytensor.tensor.var import TensorConstant, get_unique_constant_value |
39 | 46 |
|
40 | 47 |
|
41 | 48 | class InplaceElemwiseOptimizer(GraphRewriter):
|
@@ -552,6 +559,50 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
|
552 | 559 | return rval
|
553 | 560 |
|
554 | 561 |
|
@register_specialize
@node_rewriter([Elemwise])
def local_replace_broadcasted_constant(fgraph, node):
    """Remove broadcasted constants from Elemwise graphs.

    Elemwise(scalar, ones((3, 4))) -> Alloc(Elemwise(scalar, ones((1, 1))), 3, 4)

    This avoids useless iteration over constant arrays: the Elemwise is
    evaluated once on a (1, ..., 1)-shaped constant, and the result is then
    broadcast back to the required output shape.

    Parameters
    ----------
    fgraph
        The function graph being rewritten.
    node
        An ``Elemwise`` apply node.

    Returns
    -------
    list of Variable or None
        Replacement outputs, or ``None`` when no input is a broadcastable
        constant (or the node has a single input).
    """
    # With a single input there is nothing to fold against: shrinking the
    # only input would just shrink the output as well.
    if len(node.inputs) == 1:
        return None

    new_elem_inps = []
    ndims = node.outputs[0].type.ndim
    found_const = False
    for inp in node.inputs:
        # Only consider inputs that have at least one non-broadcastable dim;
        # fully-broadcastable inputs are already as small as possible.
        if not all(inp.type.broadcastable):
            constant_value = get_unique_constant_value(inp)
            if constant_value is not None:
                # Replace the full-size constant by an equivalent
                # (1, ..., 1)-shaped constant with the same dtype and ndim.
                constant_value = np.expand_dims(
                    constant_value, axis=tuple(range(ndims))
                ).astype(inp.type.dtype)
                new_elem_inps.append(constant(constant_value))
                found_const = True
                continue

        new_elem_inps.append(inp)

    if not found_const:
        return None

    new_outs = node.op.make_node(*new_elem_inps).outputs
    # The constants were needed to enforce the output shape: if shrinking them
    # changed the broadcast pattern, restore it with an explicit broadcast.
    if node.outputs[0].type.broadcastable != new_outs[0].type.broadcastable:
        new_outs = [
            broadcast_like(new_out, template=node.outputs[0], fgraph=fgraph)
            for new_out in new_outs
        ]

    copy_stack_trace(node.outputs, new_outs)
    return new_outs
| 605 | + |
555 | 606 | @node_rewriter([Elemwise])
|
556 | 607 | def local_add_mul_fusion(fgraph, node):
|
557 | 608 | """Fuse consecutive add or mul in one such node with more inputs.
|
|
0 commit comments