-
Notifications
You must be signed in to change notification settings - Fork 132
Constant fold branches of variadic add/mul #1422
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Constant fold branches of variadic add/mul #1422
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don’t know why the test fails but it looks like the fusion rewrite is applied only once. Maybe the equilibrium rewrite that you took out should be added back in?
@ricardoV94, I just went through your branch's code and found that the error is coming from the fact that the |
ca12b58
to
082e1b7
Compare
@@ -238,6 +238,7 @@ class TestFusion: | |||
include=[ | |||
"canonicalize", | |||
"fusion", | |||
"add_mul_flat", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was the change needed to get the test to pass @lucianopaz
@lucianopaz I came to the same conclusion, I just added the rewrite explicitly. Mentioned in an inline comment above |
082e1b7
to
70db72e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks @ricardoV94 !
Codecov ReportAttention: Patch coverage is
❌ Your patch status has failed because the patch coverage (95.34%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## main #1422 +/- ##
==========================================
+ Coverage 82.11% 82.13% +0.01%
==========================================
Files 211 211
Lines 49743 49773 +30
Branches 8824 8830 +6
==========================================
+ Hits 40847 40879 +32
+ Misses 6715 6714 -1
+ Partials 2181 2180 -1
🚀 New features to boost your workflow:
|
Refactoring and renaming:
local_add_mul_fusion
function toflatten_nested_add_mul
to more precisely reflect how it works (one could also fuse non-nested add/mul, like the FusionOptimizer does). The function now explicitly tracksadd
andmul
operations instead of relying on genericElemwise
checks. [1] [2] [3]New optimization for constant folding:
constant_fold_branches_of_add_mul
, which folds constants in add/mul operations when it does not result in higher intermediate memory usage. This optimization is registered in a new sequence database,add_mul_flat_seqopt
, which runs before generic elementwise fusion.The two rewrites are pulled out to a separate database so it's included in JAX rewrites (JAX does not include fusion rewrites). We've found this could help avoding XLA constant fold (CC @lucianopaz)
📚 Documentation preview 📚: https://pytensor--1422.org.readthedocs.build/en/1422/