
lu_solve doesn't work with batched inputs #1376

Open
@ricardoV94

Description


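```python
import pytensor.tensor as pt

A = pt.tensor3("A")  # batch of square matrices, shape (batch, n, n)
b = pt.vector("b")   # single right-hand side, shape (n,), shared by the whole batch

lu_and_pivots = pt.linalg.lu_factor(A)
x = pt.linalg.lu_solve(lu_and_pivots, b, b_ndim=1)  # ValueError: PivotToPermutations only works on 1-D inputs
```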
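For reference, here is a minimal NumPy/SciPy sketch of what the batched call is expected to compute: factor each matrix in the batch and solve it against the shared vector b. It simply loops over the leading dimension with scipy.linalg.lu_factor and scipy.linalg.lu_solve; batched_lu_solve is a hypothetical helper for illustration, not a PyTensor function.

```python
import numpy as np
from scipy.linalg import lu_factor, lu_solve


def batched_lu_solve(A, b):
    """Solve A[i] @ x[i] = b for every matrix in a batch of shape (B, n, n).

    b has shape (n,) and is broadcast against the batch dimension,
    mirroring b_ndim=1 in the failing PyTensor example.
    """
    xs = []
    for A_i in A:
        lu, piv = lu_factor(A_i)           # packed L/U factors plus row-swap pivots
        xs.append(lu_solve((lu, piv), b))  # apply pivots to b, then two triangular solves
    return np.stack(xs)


rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3, 3))
b = rng.normal(size=3)
x = batched_lu_solve(A, b)
print(x.shape)  # (4, 3)
print(np.allclose(np.einsum("bij,bj->bi", A, x), b))  # True
```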

We'll probably need to use vectorize here, because just wrapping PivotToPermutations in a Blockwise doesn't cut it. We may also consider fusing PivotToPermutations with the subsequent indexing step, to avoid the copy introduced by AdvancedIndexing.
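To illustrate both ideas, here is a minimal NumPy sketch. It does not describe PyTensor's actual PivotToPermutations implementation. It assumes LAPACK-style 0-based pivots, as returned by scipy.linalg.lu_factor, where row i was swapped with row piv[i] in order i = 0, 1, ..., n - 1. Because the swaps are sequential, converting them to a permutation is a loop per matrix, and batching means repeating that loop over every leading index. The fused variant applies the swaps directly to one copy of b, so neither a permutation array nor a separate indexing copy is materialized. Both function names are hypothetical.

```python
import numpy as np


def pivots_to_permutation(piv):
    """Convert LAPACK-style pivots of shape (..., n) to permutations of shape (..., n).

    The result perm satisfies A[perm] == L @ U for each batch entry.
    """
    piv = np.asarray(piv)
    n = piv.shape[-1]
    perm = np.broadcast_to(np.arange(n), piv.shape).copy()
    for idx in np.ndindex(piv.shape[:-1]):  # loop over all batch indices
        p = perm[idx]                        # 1-D view into perm for this entry
        for i, j in enumerate(piv[idx]):
            p[i], p[j] = p[j], p[i]          # replay the row swap on the index array
    return perm


def apply_pivots(piv, b):
    """Fused alternative: apply the row swaps directly to a single copy of b.

    piv has shape (..., n); b has shape (..., n) and is broadcast against
    piv's batch dims. Gives the same result as indexing b with
    pivots_to_permutation(piv) for each batch entry.
    """
    piv = np.asarray(piv)
    n = piv.shape[-1]
    batch = np.broadcast_shapes(piv.shape[:-1], np.shape(b)[:-1])
    piv = np.broadcast_to(piv, batch + (n,))
    out = np.broadcast_to(b, batch + (n,)).copy()  # the one copy of b
    for idx in np.ndindex(batch):
        row = out[idx]
        for i, j in enumerate(piv[idx]):
            row[i], row[j] = row[j], row[i]
    return out


piv = np.array([[2, 2, 2], [1, 1, 2]])  # two batch entries, n = 3
b = np.array([10.0, 20.0, 30.0])        # shared, unbatched right-hand side
print(pivots_to_permutation(piv).tolist())  # [[2, 0, 1], [1, 0, 2]]
print(apply_pivots(piv, b).tolist())        # [[30.0, 10.0, 20.0], [20.0, 10.0, 30.0]]
```

The fused version trades the gather for in-place swaps on the output buffer, which is the kind of saving the note above has in mind; in a real Op the loops would be compiled rather than run in Python.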

This is required for #1374
