Skip to content

Solve to Solve LU optimization #1374

Open
@ricardoV94

Description

@ricardoV94

Description

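Unfinished optimization. When several Solve nodes share the same matrix A (directly or through a transpose), factor A once with lu_factor and replace each solve with an lu_solve on that factorization: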

from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.slinalg import Solve

class GlobalSolveToLUSolve(GraphRewriter):
    """Replace ``Solve`` nodes that share a coefficient matrix with a single LU
    factorization plus one ``lu_solve`` per right-hand side.

    ``Solve`` nodes are grouped by their "base" matrix: the ``A`` input itself,
    or the input of ``A`` when ``A`` is produced by a ``DimShuffle`` that is a
    matrix transpose or a pure left expand_dims. A transposed ``A`` is solved
    with ``trans=True`` against the factorization of the base matrix.

    Only ``assume_a == "gen"`` solves are rewritten. For the structured variants
    (e.g. ``"sym"``/``"pos"``) the solver may read only one triangle of ``A``,
    so an LU of the full matrix is not guaranteed to reproduce the result.

    Parameters
    ----------
    eager
        If True, rewrite every eligible solve, even when its matrix is not
        shared with any other solve.
    """

    def __init__(self, eager: bool):
        super().__init__()
        self.eager = eager

    @staticmethod
    def _is_matrix_transpose(node) -> bool:
        """Return True if the DimShuffle ``node`` swaps its last two axes and
        keeps every other axis in place, ignoring leading "x" (new) axes."""
        if node.op.drop:
            return False

        order = list(node.op.new_order)
        # Skip leading "x" entries; the emptiness check avoids an IndexError
        # when the order consists only of "x".
        while order and order[0] == "x":
            order.pop(0)
        # A transpose needs at least two real axes.
        if len(order) < 2:
            return False

        expected = list(range(len(order)))
        expected[-2:] = reversed(expected[-2:])
        return order == expected

    @classmethod
    def _split_dimshuffle(cls, A):
        """Return ``(base, transposed)`` for the matrix input ``A`` of a solve.

        If ``A`` comes from a ``DimShuffle`` that is a pure left expand_dims,
        ``base`` is that DimShuffle's input and ``transposed`` is False. If the
        DimShuffle is a matrix transpose, ``base`` is its input and
        ``transposed`` is True. Otherwise ``A`` is returned with False.
        """
        node = A.owner
        if node is None or not isinstance(node.op, DimShuffle):
            return A, False
        if node.op.is_left_expand_dims:
            return node.inputs[0], False
        if cls._is_matrix_transpose(node):
            return node.inputs[0], True
        return A, False

    def apply(self, fgraph):
        """Group the ``Solve`` nodes of ``fgraph`` by base matrix and rewrite them.

        A group is rewritten when ``self.eager`` is set, when more than one
        solve shares the matrix, or when ``A_is_broadcasted`` (a TODO helper
        defined outside this class) says so for the group's last solve. All
        replacements are applied in one ``toposort_replace`` call (a TODO
        helper to be ported from PyMC). Graphs with no eligible solves are
        left untouched.
        """
        # base matrix -> [(b, transposed, old_output), ...]. Dicts keep
        # insertion order, so groups and their entries follow fgraph.toposort().
        groups = {}
        for node in fgraph.toposort():
            if not isinstance(node.op, Solve) or node.op.assume_a != "gen":
                continue
            A, b = node.inputs
            base_A, transposed = self._split_dimshuffle(A)
            groups.setdefault(base_A, []).append((b, transposed, node.outputs[0]))

        replacements = []
        for base_A, info in groups.items():
            if not (self.eager or len(info) > 1 or A_is_broadcasted(info[-1][-1])):
                continue
            # Factor the shared matrix once; every solve in the group reuses it.
            lu_and_pivots = pt.linalg.lu_factor(base_A)
            replacements.extend(
                (old_out, pt.linalg.lu_solve(lu_and_pivots, b, trans))
                for b, trans, old_out in info
            )

        if replacements:
            toposort_replace(fgraph, replacements)

TODO:

  • Implement A_is_broadcasted
  • Port toposort_replace from PyMC
  • Add a Scan rewrite that applies this rewrite eagerly in the inner Scan graph, before the pushout_nonsequences rewrite is triggered (so that the factorization of a loop-invariant A can be pushed out of the loop)

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions