-
Notifications
You must be signed in to change notification settings - Fork 11
POC: appease linter for gh-53 #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
806f037
aa0d364
de00bde
bea4427
7ae5766
55a039a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,7 @@ | |
:nosignatures: | ||
:toctree: generated | ||
|
||
at | ||
atleast_nd | ||
cov | ||
create_diagonal | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,26 @@ | ||
from __future__ import annotations # https://github.com/pylint-dev/pylint/pull/9990 | ||
|
||
import operator | ||
import typing | ||
import warnings | ||
|
||
if typing.TYPE_CHECKING: | ||
from ._lib._typing import Array, ModuleType | ||
# https://github.com/pylint-dev/pylint/issues/10112 | ||
from collections.abc import Callable # pylint: disable=import-error | ||
from typing import ClassVar, Literal | ||
|
||
from ._lib import _utils | ||
from ._lib._compat import array_namespace | ||
from ._lib._compat import ( | ||
array_namespace, | ||
is_array_api_obj, | ||
is_dask_array, | ||
is_writeable_array, | ||
) | ||
|
||
if typing.TYPE_CHECKING: | ||
from ._lib._typing import Array, Index, ModuleType, Untyped | ||
|
||
__all__ = [ | ||
"at", | ||
"atleast_nd", | ||
"cov", | ||
"create_diagonal", | ||
|
@@ -548,3 +559,278 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: | |
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=x.device), | ||
) | ||
return xp.sin(y) / y | ||
|
||
|
||
_undef = object() | ||
|
||
|
||
class at: # pylint: disable=invalid-name | ||
""" | ||
Update operations for read-only arrays. | ||
|
||
This implements ``jax.numpy.ndarray.at`` for all backends. | ||
|
||
Parameters | ||
---------- | ||
x : array | ||
Input array. | ||
idx : index, optional | ||
You may use two alternate syntaxes:: | ||
|
||
at(x, idx).set(value) # or get(), add(), etc. | ||
at(x)[idx].set(value) | ||
|
||
copy : bool, optional | ||
True (default) | ||
Ensure that the inputs are not modified. | ||
False | ||
Ensure that the update operation writes back to the input. | ||
Raise ValueError if a copy cannot be avoided. | ||
None | ||
The array parameter *may* be modified in place if it is possible and | ||
beneficial for performance. | ||
You should not reuse it after calling this function. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `x`. Default: infer | ||
|
||
**kwargs: | ||
If the backend supports an `at` method, any additional keyword | ||
arguments are passed to it verbatim; e.g. this allows passing | ||
``indices_are_sorted=True`` to JAX. | ||
|
||
Returns | ||
------- | ||
Updated input array. | ||
|
||
Examples | ||
-------- | ||
Given either of these equivalent expressions:: | ||
|
||
x = at(x)[1].add(2, copy=None) | ||
x = at(x, 1).add(2, copy=None) | ||
|
||
If x is a JAX array, they are the same as:: | ||
|
||
x = x.at[1].add(2) | ||
|
||
If x is a read-only numpy array, they are the same as:: | ||
|
||
x = x.copy() | ||
x[1] += 2 | ||
|
||
Otherwise, they are the same as:: | ||
|
||
x[1] += 2 | ||
|
||
Warning | ||
------- | ||
When you use copy=None, you should always immediately overwrite | ||
the parameter array:: | ||
|
||
x = at(x, 0).set(2, copy=None) | ||
|
||
The anti-pattern below must be avoided, as it will result in different behaviour | ||
on read-only versus writeable arrays:: | ||
|
||
x = xp.asarray([0, 0, 0]) | ||
y = at(x, 0).set(2, copy=None) | ||
z = at(x, 1).set(3, copy=None) | ||
|
||
In the above example, ``x == [0, 0, 0]``, ``y == [2, 0, 0]`` and z == ``[0, 3, 0]`` | ||
when x is read-only, whereas ``x == y == z == [2, 3, 0]`` when x is writeable! | ||
|
||
Warning | ||
------- | ||
The array API standard does not support integer array indices. | ||
The behaviour of update methods when the index is an array of integers | ||
is undefined; this is particularly true when the index contains multiple | ||
occurrences of the same index, e.g. ``at(x, [0, 0]).set(2)``. | ||
|
||
Note | ||
---- | ||
`sparse <https://sparse.pydata.org/>`_ is not supported by update methods yet. | ||
|
||
See Also | ||
-------- | ||
`jax.numpy.ndarray.at <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html>`_ | ||
""" | ||
|
||
x: Array | ||
idx: Index | ||
__slots__: ClassVar[tuple[str, str]] = ("idx", "x") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMHO the linter should not force me to define the type of There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like this is no longer needed EDIT: only if |
||
|
||
def __init__(self, x: Array, idx: Index = _undef, /) -> None: | ||
self.x = x | ||
self.idx = idx | ||
|
||
def __getitem__(self, idx: Index, /) -> at: | ||
"""Allow for the alternate syntax ``at(x)[start:stop:step]``, | ||
which looks prettier than ``at(x, slice(start, stop, step))`` | ||
and feels more intuitive coming from the JAX documentation. | ||
""" | ||
if self.idx is not _undef: | ||
msg = "Index has already been set" | ||
raise ValueError(msg) | ||
self.idx = idx | ||
return self | ||
|
||
def _common( | ||
self, | ||
at_op: str, | ||
y: Array = _undef, | ||
/, | ||
copy: bool | None = True, | ||
xp: ModuleType | None = None, | ||
_is_update: bool = True, | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
**kwargs: Untyped, | ||
) -> tuple[Untyped, None] | tuple[None, Array]: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it perhaps possible to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It isn't because the return type depends on a duck-type test on x. And I'm definitely unwilling to explore writing a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I am willing, so here you go: class _CanAt(Protocol):
@property
def at(self) -> Mapping[Index, Untyped] ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. making this change - could you clarify how to use |
||
"""Perform common prepocessing. | ||
|
||
Returns | ||
------- | ||
If the operation can be resolved by at[], (return value, None) | ||
Otherwise, (None, preprocessed x) | ||
""" | ||
if self.idx is _undef: | ||
msg = ( | ||
"Index has not been set.\n" | ||
"Usage: either\n" | ||
" at(x, idx).set(value)\n" | ||
"or\n" | ||
" at(x)[idx].set(value)\n" | ||
"(same for all other methods)." | ||
) | ||
raise TypeError(msg) | ||
|
||
x = self.x | ||
|
||
if copy is None: | ||
writeable = is_writeable_array(x) | ||
copy = _is_update and not writeable | ||
elif copy: | ||
writeable = None | ||
else: | ||
writeable = is_writeable_array(x) | ||
if not writeable: | ||
msg = "Cannot modify parameter in place" | ||
raise ValueError(msg) | ||
|
||
if copy: | ||
try: | ||
at_ = x.at | ||
except AttributeError: | ||
# Emulate at[] behaviour for non-JAX arrays | ||
# with a copy followed by an update | ||
if xp is None: | ||
xp = array_namespace(x) | ||
# Create writeable copy of read-only numpy array | ||
x = xp.asarray(x, copy=True) | ||
if writeable is False: | ||
# A copy of a read-only numpy array is writeable | ||
writeable = None | ||
else: | ||
# Use JAX's at[] or other library that with the same duck-type API | ||
args = (y,) if y is not _undef else () | ||
return getattr(at_[self.idx], at_op)(*args, **kwargs), None | ||
|
||
if _is_update: | ||
if writeable is None: | ||
writeable = is_writeable_array(x) | ||
if not writeable: | ||
# sparse crashes here | ||
msg = f"Array {x} has no `at` method and is read-only" | ||
raise ValueError(msg) | ||
|
||
return None, x | ||
|
||
def get(self, **kwargs: Untyped) -> Untyped: | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Return ``x[idx]``. In addition to plain ``__getitem__``, this allows ensuring | ||
that the output is either a copy or a view; it also allows passing | ||
keyword arguments to the backend. | ||
""" | ||
if kwargs.get("copy") is False: | ||
if is_array_api_obj(self.idx): | ||
# Boolean index. Note that the array API spec | ||
# https://data-apis.org/array-api/latest/API_specification/indexing.html | ||
# does not allow for list, tuple, and tuples of slices plus one or more | ||
# one-dimensional array indices, although many backends support them. | ||
# So this check will encounter a lot of false negatives in real life, | ||
# which can be caught by testing the user code vs. array-api-strict. | ||
msg = "get() with an array index always returns a copy" | ||
raise ValueError(msg) | ||
if is_dask_array(self.x): | ||
msg = "get() on Dask arrays always returns a copy" | ||
raise ValueError(msg) | ||
|
||
res, x = self._common("get", _is_update=False, **kwargs) | ||
if res is not None: | ||
return res | ||
assert x is not None | ||
return x[self.idx] | ||
|
||
def set(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] = y`` and return the update array""" | ||
res, x = self._common("set", y, **kwargs) | ||
if res is not None: | ||
return res | ||
assert x is not None | ||
x[self.idx] = y | ||
return x | ||
|
||
def _iop( | ||
self, | ||
at_op: Literal[ | ||
"set", "add", "subtract", "multiply", "divide", "power", "min", "max" | ||
], | ||
elwise_op: Callable[[Array, Array], Array], | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
y: Array, | ||
/, | ||
**kwargs: Untyped, | ||
) -> Array: | ||
"""x[idx] += y or equivalent in-place operation on a subset of x | ||
|
||
which is the same as saying | ||
x[idx] = x[idx] + y | ||
Note that this is not the same as | ||
operator.iadd(x[idx], y) | ||
Consider for example when x is a numpy array and idx is a fancy index, which | ||
triggers a deep copy on __getitem__. | ||
""" | ||
res, x = self._common(at_op, y, **kwargs) | ||
if res is not None: | ||
return res | ||
assert x is not None | ||
x[self.idx] = elwise_op(x[self.idx], y) | ||
return x | ||
|
||
def add(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] += y`` and return the updated array""" | ||
return self._iop("add", operator.add, y, **kwargs) | ||
|
||
def subtract(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] -= y`` and return the updated array""" | ||
return self._iop("subtract", operator.sub, y, **kwargs) | ||
|
||
def multiply(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] *= y`` and return the updated array""" | ||
return self._iop("multiply", operator.mul, y, **kwargs) | ||
|
||
def divide(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] /= y`` and return the updated array""" | ||
return self._iop("divide", operator.truediv, y, **kwargs) | ||
|
||
def power(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] **= y`` and return the updated array""" | ||
return self._iop("power", operator.pow, y, **kwargs) | ||
|
||
def min(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] = minimum(x[idx], y)`` and return the updated array""" | ||
xp = array_namespace(self.x) | ||
y = xp.asarray(y) | ||
return self._iop("min", xp.minimum, y, **kwargs) | ||
|
||
def max(self, y: Array, /, **kwargs: Untyped) -> Array: | ||
"""Apply ``x[idx] = maximum(x[idx], y)`` and return the updated array""" | ||
xp = array_namespace(self.x) | ||
y = xp.asarray(y) | ||
return self._iop("max", xp.maximum, y, **kwargs) |
Uh oh!
There was an error while loading. Please reload this page.