Skip to content

Simplify dispatch of JAX random variables by handling rng split automatically #1315

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 24, 2025

Conversation

educhesne
Copy link
Contributor

@educhesne educhesne commented Mar 22, 2025

Description

Move the jax rng splits from the jax_sample_fn of each RandomVariable dispatch to jax_funcify_RandomVariable

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify): refacto

📚 Documentation preview 📚: https://pytensor--1315.org.readthedocs.build/en/1315/

Copy link

codecov bot commented Mar 23, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.98%. Comparing base (95ce102) to head (98443f1).
Report is 1 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1315      +/-   ##
==========================================
- Coverage   81.99%   81.98%   -0.02%     
==========================================
  Files         188      188              
  Lines       48508    48474      -34     
  Branches     8672     8672              
==========================================
- Hits        39773    39739      -34     
  Misses       6583     6583              
  Partials     2152     2152              
Files with missing lines Coverage Δ
pytensor/link/jax/dispatch/random.py 92.72% <100.00%> (-0.98%) ⬇️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 merged commit 9ab8df5 into pymc-devs:main Mar 24, 2025
74 checks passed
@ricardoV94
Copy link
Member

Thanks @educhesne, much cleaner now

@educhesne educhesne deleted the rng_jax_dispatch branch March 24, 2025 09:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Simplify dispatch of JAX random variables by handling rng split automatically
2 participants