Open
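Description

Unfinished optimization: when several `solve` calls share the same coefficient matrix `A` (possibly transposed, or with left-expanded batch dimensions), factorize `A` once with `lu_factor` and replace each call with an `lu_solve` on that factorization: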
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import GraphRewriter
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.slinalg import Solve
class GlobalSolveToLUSolve(GraphRewriter):
    """Rewrite ``Solve`` nodes that share a coefficient matrix into LU solves.

    Every ``Solve`` node in the graph is grouped by its base coefficient
    matrix ``A``. Nodes wrapped in another op that exposes a ``core_op``
    attribute (e.g. Blockwise) are matched too. If the coefficient input is a
    matrix transpose of ``A`` or a left ``expand_dims`` of ``A``, it is traced
    back to ``A``. A transpose is remembered so the replacement can be solved
    with ``trans=True``. Each selected group is replaced by one
    ``lu_factor(A)`` and one ``lu_solve`` per original solve.

    Parameters
    ----------
    eager
        If True, rewrite every group, including groups with a single solve.
        Otherwise a group is rewritten only when it contains more than one
        solve, or when ``A_is_broadcasted`` (not yet implemented) reports that
        it should be.
    """

    def __init__(self, eager: bool):
        self.eager = eager

    @staticmethod
    def _is_matrix_transpose(node) -> bool:
        """Return True if ``node`` is a DimShuffle that swaps its last two axes.

        Leading ``"x"`` (expand_dims) axes are allowed. Dropped axes are not.
        """
        if not isinstance(node.op, DimShuffle) or node.op.drop:
            return False
        order = list(node.op.new_order)
        # Skip leading "x" axes. The emptiness check avoids an IndexError
        # when every axis is "x".
        while order and order[0] == "x":
            order.pop(0)
        # A transpose needs at least two real axes. Without this check a
        # single-axis order would compare equal to its own "reversal".
        if len(order) < 2:
            return False
        swapped = list(range(len(order)))
        swapped[-2:] = reversed(swapped[-2:])
        return order == swapped

    @classmethod
    def _base_matrix(cls, A):
        """Return ``(base_A, transposed)`` for a Solve's coefficient input ``A``.

        If ``A`` is produced by a matrix-transpose DimShuffle or a left
        expand_dims DimShuffle, ``base_A`` is that DimShuffle's input and
        ``transposed`` says whether the last two axes were swapped. Otherwise
        ``A`` is returned unchanged with ``transposed=False``.
        """
        node = A.owner
        if node is None or not isinstance(node.op, DimShuffle):
            return A, False
        if cls._is_matrix_transpose(node):
            return node.inputs[0], True
        if node.op.is_left_expand_dims:
            return node.inputs[0], False
        return A, False

    def apply(self, fgraph):
        """Replace groups of solves sharing a base ``A`` with LU-based solves.

        Returns None. If no group qualifies, ``fgraph`` is left untouched.
        """
        # Group every solve exactly once, keyed by the identity of its base
        # matrix. Grouping up front means a node is never processed again
        # after it has been scheduled for replacement.
        groups = {}
        for node in fgraph.toposort():
            # Blockwise-style wrappers expose the wrapped op as ``core_op``.
            core_op = getattr(node.op, "core_op", node.op)
            if not isinstance(core_op, Solve):
                continue
            # NOTE(review): Solve variants with a structured ``assume_a``
            # (e.g. diagonal/banded) may be cheaper than LU; confirm whether
            # they should be excluded here.
            A, b = node.inputs
            base_A, transposed = self._base_matrix(A)
            _, entries = groups.setdefault(id(base_A), (base_A, []))
            entries.append((b, transposed, core_op.b_ndim, node.outputs[0]))

        replacements = []
        for base_A, entries in groups.values():
            first_old_out = entries[0][3]
            # TODO: A_is_broadcasted is not implemented yet. Because of
            # short-circuiting it is only evaluated for non-eager,
            # single-solve groups.
            if not (
                self.eager
                or len(entries) > 1
                or A_is_broadcasted(first_old_out)
            ):
                continue
            # Factorize the coefficient matrix (not ``b``) once per group.
            lu_and_pivots = pt.linalg.lu_factor(base_A)
            replacements.extend(
                (
                    old_out,
                    pt.linalg.lu_solve(
                        lu_and_pivots, b, trans=transposed, b_ndim=b_ndim
                    ),
                )
                for b, transposed, b_ndim, old_out in entries
            )

        if replacements:
            # All groups are replaced in a single call, so a ``b`` that is the
            # output of another solve is never captured as a stale variable.
            # toposort_replace is still to be brought over from PyMC.
            # NOTE(review): confirm its ordering (``reverse``) keeps chained
            # solves connected.
            toposort_replace(fgraph, replacements)
TODO:
- Implement `A_is_broadcasted`.
- Bring `toposort_replace` over from PyMC.
- Add a Scan rewrite that applies this rewrite eagerly in the inner Scan graph, before the `pushout_nonsequences` rewrite is triggered.