
Replace our Tri op with an OpFromGraph #1265

Open
@jessegrabowski

Description


Currently we have an Op that calls np.tri, but we can easily build lower-triangular mask matrices with _iota:

from pytensor.tensor.einsum import _iota

def tri(N, M, k=0):
    # Ones where col <= row + k, matching np.tri(N, M, k)
    return ((_iota((N, M), 0) + k) >= _iota((N, M), 1)).astype(int)

This is what JAX does. The benefit of doing things this way is that we'll automatically have a dispatchable Op for Numba (Numba supports np.tri, but only under specific circumstances -- I tried a naive dispatch and it didn't work) and for PyTorch (#821 asks for Tri, so this would check off that box).
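For intuition, the mask identity can be checked in plain NumPy, with broadcast index grids playing the role of _iota (the helper name here is illustrative, not part of any library):

```python
import numpy as np

def tri_from_iota(N, M, k=0):
    # Row/column index grids, analogous to _iota((N, M), 0) and _iota((N, M), 1)
    rows = np.arange(N)[:, None]
    cols = np.arange(M)[None, :]
    # Lower-triangular mask: ones where col <= row + k, as in np.tri(N, M, k)
    return ((rows + k) >= cols).astype(int)
```

This reproduces np.tri for any shape and diagonal offset, which is what makes the `_iota`-based graph a drop-in replacement.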

I suggest we wrap this in a dummy OpFromGraph, as we do for Kron and AllocDiag, so that the dprints are nicer. We could also overload the L_op if we want: the current Tri has grad_undefined, which we could keep if it's correct. Or we could just keep the autodiff solution, since the proposed _iota-based graph should be differentiable.
