Skip to content

Fix bug in storage_input alignment of the JAX backend #587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 12, 2024

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 12, 2024

Description

When transpiling functions with Shared RNGs to the JAX backend, we replace the original variables by copies so that we can update them in the same container without changing the data type (JAX PRNGs vs numpy Generator), so that the same shared variable can be reused in other functions, which may use distinct backends.

There was a bug in that this replacement could change the position of the input variable in the FunctionGraph, and disalign it with the pre-defined input_storage list. This PR fixes the bug by forcing the position of the new input variable in the FunctionGraph to be the same as the one in input_storage list.

AFAICT the order of the FunctionGraph inputs is not stored/referenced anywhere internally, so this should be safe.

This showed up when trying to use PyMC VI with the JAX backend.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

When replacing the Shared RNG variables, the input order of the FunctionGraph was not explicitly aligned with the input storage of the function being compiled.
@ricardoV94 ricardoV94 added bug Something isn't working jax random variables labels Jan 12, 2024
@ricardoV94 ricardoV94 requested a review from ferrine January 12, 2024 12:55
@ricardoV94
Copy link
Member Author

ricardoV94 commented Jan 12, 2024

CC @abdalazizrashid

@codecov-commenter
Copy link

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (c5b96d9) 80.92% compared to head (4f79fd4) 80.93%.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main     #587   +/-   ##
=======================================
  Coverage   80.92%   80.93%           
=======================================
  Files         162      162           
  Lines       46641    46644    +3     
  Branches    11399    11399           
=======================================
+ Hits        37746    37749    +3     
  Misses       6667     6667           
  Partials     2228     2228           
Files Coverage Δ
pytensor/link/jax/linker.py 95.23% <100.00%> (+0.36%) ⬆️

@ricardoV94 ricardoV94 merged commit 0666fd5 into pymc-devs:main Jan 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working jax random variables
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants