Rewrite products of exponents as exponent of sum #186
This looks great. I left some comments below
The decorators convert the function into a "proper rewriter", something you can call:
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.rewriting.math import local_log1p
x = pt.scalar("x")
y = pt.log(1 + x)
new_y = local_log1p.transform(None, y.owner)
pytensor.dprint(new_y)
# Elemwise{second,no_inplace} [id A]
#  |TensorConstant{1} [id B]
#  |Elemwise{log1p,no_inplace} [id C]
#    |x [id D]
I think it is fine just as a specialization. The canonicalize database is supposed to include rewrites that convert different yet equivalent graphs into a common or "canonical" form, so that other rewrites can build on top of these forms. The specialize database is supposed to include rewrites that convert a graph into more efficient forms. It's not always obvious which category a rewrite should go into.
I realize now you were asking only about the canonicalize and specialize decorators. Those simply register the rewrite in the respective databases. There is some documentation here: https://pytensor.readthedocs.io/en/latest/extending/graph_rewriting.html#registering-a-noderewriter
Hi, I've implemented what we discussed above re the multiplication of more than two factors, and I added a similar node rewriter for the pow() op too. Please review. Thanks!
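The multi-factor case can be sketched without PyTensor at all. The snippet below is a toy model, not the code from this PR: expressions are plain tuples, and `fuse_exp_factors` and `evaluate` are hypothetical helpers introduced only for illustration. The idea is to collect the exp factors of a product, sum their exponents into a single exp, and leave the remaining factors untouched.

```python
import math

# Toy expression encoding (an assumption for illustration, not PyTensor's graph):
#   ("exp", arg)        -> e**arg
#   ("add", a, b, ...)  -> a + b + ...
#   ("mul", a, b, ...)  -> a * b * ...
#   a bare string       -> a named variable


def is_exp(factor):
    return isinstance(factor, tuple) and factor[0] == "exp"


def fuse_exp_factors(factors):
    """Fuse all exp factors of a product into one exp of a sum.

    Returns None when fewer than two exp factors are present, mirroring
    how a node rewriter signals that it has nothing to change.
    """
    exponents = [f[1] for f in factors if is_exp(f)]
    others = [f for f in factors if not is_exp(f)]
    if len(exponents) < 2:
        return None
    fused = ("exp", ("add", *exponents))
    return ("mul", fused, *others) if others else fused


def evaluate(expr, env):
    """Numerically evaluate a toy expression given variable values."""
    if isinstance(expr, str):
        return env[expr]
    op, *args = expr
    vals = [evaluate(a, env) for a in args]
    if op == "exp":
        return math.exp(vals[0])
    if op == "add":
        return sum(vals)
    if op == "mul":
        return math.prod(vals)
    raise ValueError(f"unknown op: {op}")


# exp(x) * z * exp(y) * exp(w): three exp factors and one plain factor.
factors = [("exp", "x"), "z", ("exp", "y"), ("exp", "w")]
rewritten = fuse_exp_factors(factors)

env = {"x": 0.5, "y": -1.25, "z": 3.0, "w": 2.0}
before = evaluate(("mul", *factors), env)
after = evaluate(rewritten, env)
```

Here `rewritten` keeps `z` as a separate factor next to the single fused exp, and `before` and `after` agree up to floating-point error.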
This looks great, left just some small nitpicks in my review
…e e^x*e^y to e^(x+y), e^x/e^y to e^(x-y).
…than two factors with some of which may not be an exp
…d a redundant check. Moved import statement to top of file.
0b8670e to 426f0c0
I rebased your PR so that the tests run. They were failing due to #195
One of the new tests seems to have a bug: https://github.com/pymc-devs/pytensor/actions/runs/3912986757/jobs/6688361568
Can you reproduce this test failure somehow? On my PC it runs ok (though I only tried it on the original branch). On my local run, the value of graph is [Elemwise{Composite{exp((i0 + i1 + i2 + i3))}}(x, y, z, w)], whereas on the failed test run it says: AttributeError: 'Exp' object has no attribute 'fgraph'. Without being able to reproduce the issue and properly understand how fgraphs are composed, I'm only guessing here, but I'm wondering if it's possible that the issue is that for some reason the Composite element was missing on the failed run?
The failing tests are running in FAST_COMPILE mode. You can reproduce by setting
import pytensor
pytensor.config.mode = "FAST_COMPILE"
at the top of the test file. Other tests usually tweak the mode to include/exclude fusion rewrites, so you can do the same.
@@ -423,6 +424,100 @@ def local_sumsqr2dot(fgraph, node):
    return [new_out]


@register_specialize
Rethinking this, it may make sense to register these new rewrites in canonicalize as well.
Hi, after looking at how this is handled elsewhere and experimenting with different options, I've ended up adding the following line at the beginning of the test function, which seems to have resolved the issue on my local runs both with and without the FAST_COMPILE flag.
mode = get_mode("FAST_RUN") if config.mode == "FAST_COMPILE" else get_default_mode()
(Obviously, with passing the mode parameter to all function() calls)
If you're ok with this, on which branch do you want me to commit this change? The original one or the one you rebased earlier?
Also, do you want me to add back the @register_canonicalize decorators?
Thanks!
I think a cleaner solution is to use mode=get_default_mode().excluding("fusion")
Then the tests will also be more straightforward because you don't need to check the graph inside the Composites
And yes, let's try to add the register_canonicalize to these rewrites and see if it doesn't break anything.
An interesting side effect of making the rewrites canonical is that one of the test cases in test_math.py::TestSigmoidRewrites::test_local_sigm_times_exp breaks, as my rewrites now take precedence over the sigma_times_exp rewrites. More specifically,
sigma(x) * sigma(-y) * (-exp(-x)) * exp(xy) * exp(y)
is supposed to become
(-1) * sigma(-x) * sigma(y) * exp(xy)
but my rewrites contract the exp-s before the sigma*exp rewrite can take action, and the result becomes
(-1) * exp(xy+y-x) * sigma(x) * sigma(-y)
So, I guess, this is a trade-off, how should we resolve this? (Having run all tests, this is the only one the canonicalisation broke.)
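For reference, all three expressions above are numerically the same function, so the broken test reflects a difference in graph form rather than a wrong result. A quick check in plain Python, assuming `exp(xy)` means exp(x * y) and `sigma` is the logistic sigmoid (the sample values are arbitrary):

```python
import math


def sigma(t):
    """Logistic sigmoid."""
    return 1.0 / (1.0 + math.exp(-t))


x, y = 0.7, -1.3

# The input expression from the test case.
original = sigma(x) * sigma(-y) * (-math.exp(-x)) * math.exp(x * y) * math.exp(y)

# What the sigma_times_exp rewrite is expected to produce,
# using sigma(t) * exp(-t) == sigma(-t).
expected = (-1) * sigma(-x) * sigma(y) * math.exp(x * y)

# What results when the exp factors are contracted first.
contracted = (-1) * math.exp(x * y + y - x) * sigma(x) * sigma(-y)
```

The contracted form is correct but no longer exposes a bare `exp(-x)` factor next to `sigma(x)`, which is the pattern the sigmoid rewrite looks for.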
We could maybe extend the sigma_times_exp rewrite to check if the redundant term is somewhere inside the exponentiation (instead of being all there is)?
exp(y - x) * sigmoid(x) -> exp(y) * sigmoid(-x)
Does that make sense? If the canonicalize doesn't do it, we might need to represent the contents of the exponent as a flat series of additions (with negated terms instead of subtraction) so that we can match them more easily.
Might have other ramifications I am not seeing.
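The "flat series of additions" idea can be illustrated with a toy sketch in plain Python. It does not use PyTensor's graph objects: expressions are nested tuples, and `flatten_signed_terms` is a hypothetical helper. It turns nested add/sub/neg nodes into a flat list of signed terms, after which looking for a `-x` term is a simple membership test.

```python
def flatten_signed_terms(expr, sign=1):
    """Return a flat list of (sign, term) pairs for nested add/sub/neg tuples.

    Anything that is not an add, sub, or neg node is treated as an opaque term.
    """
    if isinstance(expr, tuple):
        op, *args = expr
        if op == "add":
            return [t for a in args for t in flatten_signed_terms(a, sign)]
        if op == "sub":
            left, right = args
            return flatten_signed_terms(left, sign) + flatten_signed_terms(right, -sign)
        if op == "neg":
            return flatten_signed_terms(args[0], -sign)
    return [(sign, expr)]


# (x*y + y) - x, i.e. the exponent xy + y - x from the discussion above.
exponent = ("sub", ("add", ("mul", "x", "y"), "y"), "x")
terms = flatten_signed_terms(exponent)

# A rewrite could now ask directly whether the exponent contains -x.
has_negated_x = (-1, "x") in terms
```

Whether such a flattened form should live in canonicalization or inside the sigmoid rewrite itself is the open design question raised here; the sketch only shows why the flat form makes matching easier.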
Looking at the rewrite, such a change does not seem trivial, so let's leave it as is: don't canonicalize.
We can open a new issue to investigate if it's worth it.
Let's just change the mode thing in the tests.
I'm wondering how frequent / marginal such use cases are. Are there any kind of statistics available on how often real world use cases trigger which rewrites?
Meanwhile, I've just committed the changes for the tests, hope they'll be fine this time.
Thanks!
Unfortunately, we have no telemetry to study the types of graphs users create :D
…p_to_exp_add and local_mul_pow_to_pow_add, so that the checks in the test cases also work in FAST_COMPILE mode
…cit downcast of testing constants
…that the FAST_COMPILE test runs also work properly
I suppose this can be merged now that all the tests are green?
Thanks @tamastokes, great work 👍
Hi,
Could someone please check whether I'm on the right track with this attempt to solve #54?
This commit adds rewrites for e^x * e^y as e^(x+y) and e^x / e^y as e^(x-y). If you confirm that this is more or less the appropriate way, I'm happy to continue with power(base, x) * power(base, y) => power(base, x + y)
Apart from the general request for comments, I have a specific question too: I don't yet really understand how exactly the added decorators (@register_canonicalize and @register_specialize) change behaviour. I'd naively presume this particular node rewriter should be a genuine specializer, not sure about the canonicalize part though. Is there any document somewhere that explains this? Unfortunately, I haven't come across anything related to this so far.
Also, what else do you think would be useful to test here?
Thanks,
Tamás