Description
Currently we have an Op that calls `np.tri`, but we can very easily build lower triangular mask matrices with `_iota`:
```python
from pytensor.tensor.einsum import _iota

def tri(M, N, k):
    # Ones at and below the k-th diagonal: row_idx + k >= col_idx
    return ((_iota((M, N), 0) + k) >= _iota((M, N), 1)).astype(int)
```
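As a quick sanity check, the graph can be evaluated eagerly and compared against `np.tri` for a few diagonal offsets (a minimal sketch using the `tri` helper above):

```python
import numpy as np

# The _iota-based mask should match np.tri exactly
for k in (-1, 0, 2):
    np.testing.assert_array_equal(
        tri(4, 5, k).eval(),
        np.tri(4, 5, k, dtype=int),
    )
```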
This is what JAX does. The benefit of doing things this way is that the decomposed graph will automatically work in the Numba backend (numba supports `np.tri`, but only under specific circumstances -- I tried a naive dispatch and it didn't work) and the PyTorch backend (#821 asks for `Tri`, so this would check off that box).
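For instance, something along these lines should compile out of the box, since the graph only contains `ARange`, broadcasting, and comparison Ops that the backends already dispatch (a rough, untested sketch; `mode="NUMBA"` selects the Numba backend):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

M = pt.lscalar("M")
k = pt.lscalar("k")

# No dedicated Tri dispatch needed: only basic Ops end up in the graph
fn = pytensor.function([M, k], tri(M, M, k), mode="NUMBA")
np.testing.assert_array_equal(fn(4, 0), np.tri(4, dtype=int))
```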
I suggest we wrap this in a dummy `OpFromGraph` like we do for `Kron` and `AllocDiag`, so that the dprints are nicer. We can also overload the `L_op` if we want? The current `tri` has `grad_undefined`, so we could keep that if it's correct. Or just keep the autodiff solution -- the proposed `_iota` function should be differentiable.
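A rough sketch of what that wrapper could look like, following the `AllocDiag` pattern (hypothetical names; the `L_op` override here mirrors the current Op's `grad_undefined` behavior, but could be dropped to fall back on autodiff through the inner graph):

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad_undefined
from pytensor.tensor.einsum import _iota


class Tri(OpFromGraph):
    """Dummy wrapper so the decomposed graph dprints as a single node."""

    def L_op(self, inputs, outputs, output_grads):
        # Keep the current Tri Op's behavior: gradient is undefined
        return [grad_undefined(self, i, inp) for i, inp in enumerate(inputs)]


def tri(M, N, k):
    # Note: a real implementation would need to handle constant inputs,
    # which OpFromGraph does not accept as root variables
    M, N, k = (pt.as_tensor_variable(x) for x in (M, N, k))
    out = ((_iota((M, N), 0) + k) >= _iota((M, N), 1)).astype(int)
    return Tri(inputs=[M, N, k], outputs=[out])(M, N, k)
```

Used with symbolic sizes, the dprint then shows a single `Tri` node instead of the full `arange`/`broadcast`/comparison graph:

```python
M, N, k = pt.lscalar("M"), pt.lscalar("N"), pt.lscalar("k")
tri(M, N, k).dprint()
```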