-
Notifications
You must be signed in to change notification settings - Fork 131
Add support for negative axis in specify_broadcastable
#710
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
scripts/mypy-failing.txt
Outdated
@@ -25,6 +25,7 @@ pytensor/tensor/random/basic.py | |||
pytensor/tensor/random/op.py | |||
pytensor/tensor/random/utils.py | |||
pytensor/tensor/rewriting/basic.py | |||
pytensor/tensor/shape.py |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This shouldn't be done. Over time we want less files, not more. What's mypy complaining about?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mypy complains about the line from numpy.core.numeric import normalize_axis_tuple
. The message is that numpy.core.numeric
does not have normalize_axis_tuple
.
This already occurs in other places as well and are ignored in the mypy-failing-list. For eg:
pytensor/pytensor/tensor/basic.py
Line 18 in 5e612ab
from numpy.core.numeric import normalize_axis_tuple |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can do the same. Better to ignore a single line than a whole file.
Going forward: In numpy 2.0 at least they moved this to a user facing location so it should be legal to import then. I don't know if it's already there in the current releases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It already exists. Can you check if it's fine to import from here:
https://numpy.org/devdocs/reference/generated/numpy.lib.array_utils.normalize_axis_tuple.html#:~:text=Normalizes%20an%20axis%20argument%20into,from%20being%20specified%20multiple%20times.
If so we should import like this in the other places as remove the type ignore
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll just try this out 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So I checked I am currently using numpy version 1.26.4 and I get the error AttributeError: module 'numpy.lib' has no attribute 'array_utils'
. I checked the site-packages and it seems like this feature is introduced in numpy 2.0 and greater. Here's the reference numpy.lib
_init_.py
that matches with what I have locally:
We are yet to switch to numpy2.0 right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Strange because the docs say it's new from v1.13.
Anyway that answers it for now. No we're not switching to numpy 2.0 for a while, so let's keep importing and telling mypy to ignore it
Hey is there something left to do here in this PR? |
Revert the mypy failing the whole file and just add type: ignore on the imports |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #710 +/- ##
=======================================
Coverage 80.76% 80.76%
=======================================
Files 162 162
Lines 46698 46707 +9
Branches 11421 11422 +1
=======================================
+ Hits 37715 37723 +8
- Misses 6734 6735 +1
Partials 2249 2249
|
specify_broadcastable
@Dhruvanshu-Joshi pro tip: next time you can try and squash all your commits locally and force-push. |
Description
Normalized the axes argument in
specify_broadcastable
usingnormalize_axis_tuple
from numpy. Also added tensor/shape.py to mypy-failing-list as import statements cause expected failure.Related Issue
pt.specify_broadcastable
does not work with negative axis values #698Checklist
Type of change