Skip to content

Constant fold branches of variadic add/mul #1422

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented May 27, 2025

Refactoring and renaming:

  • Renamed the local_add_mul_fusion function to flatten_nested_add_mul to more precisely reflect how it works (one could also fuse non-nested add/mul, like the FusionOptimizer does). The function now explicitly tracks add and mul operations instead of relying on generic Elemwise checks. [1] [2] [3]

New optimization for constant folding:

  • Introduced a new rewrite function, constant_fold_branches_of_add_mul, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database, add_mul_flat_seqopt, which runs before generic elementwise fusion.

The two rewrites are pulled out to a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoding XLA constant fold (CC @lucianopaz)


📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/

Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?

@lucianopaz
Copy link
Member

lucianopaz commented May 29, 2025

@ricardoV94, I just went through your branch's code and found that the error is coming from the fact that the TestFusion class is including: "canonicalize", "fusion", and "inplace" rewrite databases. The add_mul flatten and fusion rewrites that you moved or added here are only included in "fast_run". My question then is whether your rewrites should also be added to fusion, or if you only want to add the fast_run database rewrites to the TestFusion includes?

@ricardoV94 ricardoV94 force-pushed the constant_fold_variadic_add_mul branch from ca12b58 to 082e1b7 Compare May 30, 2025 10:24
@@ -238,6 +238,7 @@ class TestFusion:
include=[
"canonicalize",
"fusion",
"add_mul_flat",
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was the change needed to get the test to pass @lucianopaz

@ricardoV94
Copy link
Member Author

@lucianopaz I came to the same conclusion, I just added the rewrite explicitly. Mentioned in an inline comment above

@ricardoV94 ricardoV94 requested a review from lucianopaz May 30, 2025 10:30
@ricardoV94 ricardoV94 force-pushed the constant_fold_variadic_add_mul branch from 082e1b7 to 70db72e Compare May 30, 2025 10:35
Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks @ricardoV94 !

Copy link

codecov bot commented May 30, 2025

Codecov Report

Attention: Patch coverage is 95.34884% with 2 lines in your changes missing coverage. Please review.

Project coverage is 82.13%. Comparing base (5a462e9) to head (70db72e).
Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/rewriting/elemwise.py 95.34% 1 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (95.34%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #1422      +/-   ##
==========================================
+ Coverage   82.11%   82.13%   +0.01%     
==========================================
  Files         211      211              
  Lines       49743    49773      +30     
  Branches     8824     8830       +6     
==========================================
+ Hits        40847    40879      +32     
+ Misses       6715     6714       -1     
+ Partials     2181     2180       -1     
Files with missing lines Coverage Δ
pytensor/tensor/rewriting/elemwise.py 91.77% <95.34%> (+0.74%) ⬆️
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 merged commit ff09268 into pymc-devs:main May 30, 2025
72 of 73 checks passed
@ricardoV94 ricardoV94 deleted the constant_fold_variadic_add_mul branch May 30, 2025 11:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants