|
4 | 4 | from typing import DefaultDict, Generator, List, Set, Tuple, TypeVar
|
5 | 5 | from warnings import warn
|
6 | 6 |
|
| 7 | +import numpy as np |
| 8 | + |
7 | 9 | import pytensor
|
8 | 10 | import pytensor.scalar.basic as aes
|
9 | 11 | from pytensor import clone_replace, compile
|
|
28 | 30 | MakeVector,
|
29 | 31 | alloc,
|
30 | 32 | cast,
|
| 33 | + constant, |
31 | 34 | get_underlying_scalar_constant_value,
|
32 | 35 | )
|
33 | 36 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
34 | 37 | from pytensor.tensor.exceptions import NotScalarConstantError
|
35 | 38 | from pytensor.tensor.math import exp
|
36 |
| -from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize |
| 39 | +from pytensor.tensor.rewriting.basic import ( |
| 40 | + broadcast_like, |
| 41 | + register_canonicalize, |
| 42 | + register_specialize, |
| 43 | +) |
37 | 44 | from pytensor.tensor.shape import shape_padleft
|
38 |
| -from pytensor.tensor.var import TensorConstant |
| 45 | +from pytensor.tensor.var import TensorConstant, get_unique_constant_value |
39 | 46 |
|
40 | 47 |
|
41 | 48 | class InplaceElemwiseOptimizer(GraphRewriter):
|
@@ -1295,6 +1302,65 @@ def local_inline_composite_constants(fgraph, node):
|
1295 | 1302 | )
|
1296 | 1303 |
|
1297 | 1304 |
|
@node_rewriter([Elemwise])
def local_replace_broadcasted_constants(fgraph, node):
    """Shrink broadcasted constant inputs of an Elemwise to length-1 dims.

    Every input that has at least one non-broadcastable dimension, and for
    which ``get_unique_constant_value`` returns something other than ``None``,
    is swapped for a new constant built from that value. Inputs that do not
    qualify are passed through unchanged:

        Elemwise(matrix, ones((3, 4))) -> Elemwise(matrix, ones((1, 1)))

    If the rebuilt first output has a different broadcastable pattern than the
    original first output (the constant was what enforced the output shape),
    every new output is passed through ``broadcast_like`` using the original
    first output as the template:

        Elemwise(row, ones((3, 4))) -> Alloc(Elemwise(row, ones((1, 1))), 3, 4)

    The goal is to avoid useless iterations over constant arrays.

    Returns the list of replacement outputs, or ``None`` when the node has a
    single input or when no input was replaced.
    """
    if len(node.inputs) == 1:
        return None

    original_out = node.outputs[0]
    # One leading axis per output dimension.
    # NOTE(review): assumes the unique value is a scalar, so the expanded
    # constant has shape (1,) * ndim -- confirm against get_unique_constant_value.
    leading_axes = tuple(range(original_out.type.ndim))

    def _collapsed_constant(inp):
        # Inputs that are broadcastable in every dimension are left alone.
        if all(inp.type.broadcastable):
            return None
        unique_value = get_unique_constant_value(inp)
        if unique_value is None:
            return None
        expanded = np.expand_dims(unique_value, axis=leading_axes)
        return constant(expanded.astype(inp.type.dtype))

    substitutes = [_collapsed_constant(inp) for inp in node.inputs]
    if all(sub is None for sub in substitutes):
        # Nothing to replace
        return None

    rebuilt_inputs = [
        inp if sub is None else sub for inp, sub in zip(node.inputs, substitutes)
    ]
    rebuilt_outs = node.op.make_node(*rebuilt_inputs).outputs

    # The constants were needed to enforce the output shape, so restore it
    if original_out.type.broadcastable != rebuilt_outs[0].type.broadcastable:
        rebuilt_outs = [
            broadcast_like(out, template=original_out, fgraph=fgraph)
            for out in rebuilt_outs
        ]

    copy_stack_trace(node.outputs, rebuilt_outs)
    return rebuilt_outs
| 1352 | + |
| 1353 | + |
# Registered at position 49.01 so that it runs immediately after the fusion
# database: the Allocs this rewrite may introduce must not break up the
# fusion rewrites.
compile.optdb.register(
    "local_replace_broadcasted_constants",
    in2out(local_replace_broadcasted_constants),
    "fast_run",
    position=49.01,
)
| 1362 | + |
| 1363 | + |
1298 | 1364 | def _rebuild_partial_2f1grad_loop(node, wrt):
|
1299 | 1365 | a, b, c, log_z, sign_z = node.inputs[-5:]
|
1300 | 1366 | z = exp(log_z) * sign_z
|
|
0 commit comments