Description
import pytensor.tensor as pt
A = pt.tensor3("A")
b = pt.vector("b")
lu_and_pivots = pt.linalg.lu_factor(A)
x = pt.linalg.lu_solve(lu_and_pivots, b, b_ndim=1) # ValueError: PivotToPermutations only works on 1-D inputs
We may need to use vectorize, because simply blockwising PivotToPermutations doesn't cut it. We could also consider fusing PivotToPermutations with the subsequent indexing, to avoid the copy introduced by AdvancedIndexing.
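To make the batching and fusion ideas concrete, here is a minimal NumPy sketch of the semantics only, not PyTensor's implementation. It assumes the pivot convention of `scipy.linalg.lu_factor`: 0-based indices where row `i` was interchanged with row `piv[i]`. The helpers `pivots_to_permutation` and `apply_pivots` are hypothetical names. `np.vectorize` with a core `signature` stands in for "vectorize over leading batch dimensions". The unfused path builds a permutation and then gathers with advanced indexing. The fused path applies the row swaps directly to `b`, the way LAPACK's `laswp` does, so no intermediate permutation array is materialized.

```python
import numpy as np


def pivots_to_permutation(piv):
    """Hypothetical helper: turn 0-based LAPACK-style row-swap pivots into a permutation.

    Assumes the scipy.linalg.lu_factor convention: row i was interchanged with
    row piv[i], applied in order i = 0, 1, ..., n-1. For the factored matrix A,
    the result perm satisfies A[perm] == L @ U.
    """
    piv = np.asarray(piv)
    perm = np.arange(piv.shape[-1])
    for i, p in enumerate(piv):
        perm[i], perm[p] = perm[p], perm[i]
    return perm


def apply_pivots(piv, b):
    """Hypothetical fused helper: apply the row swaps straight to a copy of b.

    Equivalent to b[pivots_to_permutation(piv)], but skips building the
    permutation and the separate advanced-indexing gather.
    """
    out = np.array(b, copy=True)
    for i, p in enumerate(piv):
        out[[i, p]] = out[[p, i]]  # fancy-indexed RHS is a copy, so this is a safe swap
    return out


# Batched versions: the core signature loops over any leading batch dimensions
# (a Python-level loop, used here only to illustrate the intended semantics).
batched_pivots_to_permutation = np.vectorize(pivots_to_permutation, signature="(n)->(n)")
batched_apply_pivots = np.vectorize(apply_pivots, signature="(n),(n)->(n)")

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n, batch = 4, 3
    # Valid LAPACK pivots satisfy i <= piv[i] < n.
    piv = np.array([[rng.integers(i, n) for i in range(n)] for _ in range(batch)])
    b = rng.normal(size=(batch, n))

    # Unfused: build the permutation, then gather (an extra copy via advanced indexing).
    unfused = np.take_along_axis(b, batched_pivots_to_permutation(piv), axis=-1)
    # Fused: swap rows of b directly.
    fused = batched_apply_pivots(piv, b)
    print(np.allclose(unfused, fused))
```

Both paths should agree because each swap keeps the invariant `out[k] == b[perm[k]]`. Whether PyTensor should expose the fused form as its own Op is left open here.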
This is required for #1374