Skip to content

Add jax implementation of pt.linalg.pinv #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 13, 2023

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented May 13, 2023

Motivation for these changes

Add a jax implementation of pt.linalg.pinv

Implementation details

Not the most elaborate PR:

@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
    def pinv(x):
        return jnp.linalg.pinv(x)

    return pinv

Checklist

Major / Breaking Changes

None

New features

You can compile graphs with pinv to JAX

Bugfixes

None

Documentation

None

Maintenance

None

@ricardoV94
Copy link
Member

Thanks @jessegrabowski

@jessegrabowski jessegrabowski deleted the jax_pinv branch May 13, 2023 23:44
@ricardoV94
Copy link
Member

ricardoV94 commented May 14, 2023

Oops I missed one float32 test that was failing: https://github.com/pymc-devs/pytensor/actions/runs/4968627533/jobs/8891362025?pr=294

@jessegrabowski
Copy link
Member Author

jessegrabowski commented May 14, 2023

I didn't add astype(config.floatX) to the test array (again...). Not sure how to proceed -- I deleted the branch on my fork already so it's not auto-pushing the change into this PR.

@ricardoV94
Copy link
Member

No worries, I opened a fix PR in #296

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants