Currently, functions in this package require passing a standard-compatible namespace as `xp=xp`. This works fine, but there have been suggestions that it would be nice to avoid this requirement. There are at least a few ways we could go about this:
(1) xpx.bind_namespace
Usage:
```python
import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.bind_namespace(xp)  # proposed API
x = xpx.atleast_nd(x, ndim=2)
y = xp.sum(x)
z = xpx.some_func(y)
```
A potential implementation:
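```python
import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other xpx functions

def bind_namespace(xp: ModuleType) -> ModuleType:
    class BoundNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                # Bind the namespace so callers can omit `xp=xp`.
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                # Raise (not return) so that the attribute lookup actually fails.
                raise AttributeError(...)
    # `xp` is captured by the closure, so the class takes no constructor arguments.
    return BoundNamespace()
```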
I like this idea. If we encounter use cases where a library wants to use multiple `xpx` functions in the same local scope and finds the `xp=xp` pattern too cumbersome, I think we should add this. Until that situation arises, I think we can leave it out.
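Here is a self-contained, runnable sketch of the `bind_namespace` idea. The function `double`, the registry contents, and the `SimpleNamespace` stand-in for a standard-compatible namespace are all hypothetical, chosen only so that the example runs without any array library installed.

```python
import functools
from types import SimpleNamespace
from typing import Any, Callable, Dict


def double(x: Any, *, xp: Any) -> Any:
    # Hypothetical xpx-style function: the namespace is a required keyword.
    return xp.multiply(x, 2)


# Hypothetical registry of functions that should receive a bound `xp`.
extra_funcs: Dict[str, Callable[..., Any]] = {"double": double}


def bind_namespace(xp: Any) -> Any:
    class BoundNamespace:
        def __getattr__(self, name: str) -> Any:
            # __getattr__ only runs when normal lookup fails, so every
            # non-dunder public name goes through the registry.
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            # Raising (not returning) keeps hasattr() and getattr(obj, name, default) working.
            raise AttributeError(f"bound namespace has no attribute {name!r}")

    return BoundNamespace()


# Toy stand-in for a standard-compatible namespace (assumption for the demo).
toy_xp = SimpleNamespace(multiply=lambda a, b: a * b, sum=sum)

xpx = bind_namespace(toy_xp)
result = xpx.double(21)  # equivalent to double(21, xp=toy_xp)
```

Raising `AttributeError` matters because `hasattr()` and three-argument `getattr()` rely on it; returning the exception object instead would silently hand back a non-callable value for every unknown name.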
(2) xpx.extra_namespace
Usage:
```python
import array_api_extra as xpx
from array_api_compat import array_namespace
...
xp = array_namespace(x)
xpx = xpx.extra_namespace(xp)  # proposed API
x = xpx.atleast_nd(x, ndim=2)
y = xpx.sum(x)  # XXX: xpx instead of xp
z = xpx.some_func(y)
```
A potential implementation:
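```python
import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd}  # ... plus the other xpx functions

def extra_namespace(xp: ModuleType) -> ModuleType:
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                return getattr(xp, name)  # XXX: delegate to xp instead of error
    # `xp` is captured by the closure, so the class takes no constructor arguments.
    return ExtraNamespace()
```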
I would not want to add this yet. I think we should keep the standard namespace and the 'extra' namespace separate, at least until this library matures.
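For comparison, a self-contained sketch of the delegating `extra_namespace` variant. As before, `double` and the `SimpleNamespace` stand-in for a standard-compatible namespace are hypothetical, used only to make the behaviour observable without an array library.

```python
import functools
from types import SimpleNamespace
from typing import Any, Callable, Dict


def double(x: Any, *, xp: Any) -> Any:
    # Hypothetical xpx-style function taking the namespace as a keyword.
    return xp.multiply(x, 2)


extra_funcs: Dict[str, Callable[..., Any]] = {"double": double}


def extra_namespace(xp: Any) -> Any:
    class ExtraNamespace:
        def __getattr__(self, name: str) -> Any:
            if name in extra_funcs:
                # Extra functions are checked first, so they shadow any
                # same-named attribute of xp.
                return functools.partial(extra_funcs[name], xp=xp)
            # Everything else is delegated; getattr() raises AttributeError
            # itself if xp lacks the name.
            return getattr(xp, name)

    return ExtraNamespace()


toy_xp = SimpleNamespace(multiply=lambda a, b: a * b, sum=sum)

xpx = extra_namespace(toy_xp)
total = xpx.sum([1, 2, 3])   # delegated to toy_xp.sum
doubled = xpx.double(total)  # extra function with xp bound
```

A single object now answers both standard and extra names, so a call such as `xpx.sum` no longer shows which namespace it comes from. That blurring is the separation concern raised above.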
(3) Use array_api_compat.array_namespace internally
This would provide the most flexible API and require the fewest lines of code to use. One could use `xpx` functions on standard-incompatible arrays and let array-api-compat handle the compatibility, without having to pass an `xp` argument.
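A minimal sketch of the "optional `xp`" shape this option implies. To keep it dependency-free, the hypothetical helper `_infer_namespace` falls back to the standard's `__array_namespace__()` method rather than calling `array_api_compat.array_namespace`; the latter is what this option actually proposes, and it would additionally cover arrays that do not implement `__array_namespace__`. `double` and `ToyArray` are also hypothetical.

```python
from types import SimpleNamespace
from typing import Any, List, Optional


def _infer_namespace(*arrays: Any) -> Any:
    # Hypothetical stand-in for array_api_compat.array_namespace: it only
    # understands arrays that implement the standard's __array_namespace__().
    namespaces: List[Any] = []
    for a in arrays:
        get_ns = getattr(a, "__array_namespace__", None)
        if get_ns is None:
            raise TypeError(f"{type(a).__name__} does not implement __array_namespace__")
        ns = get_ns()
        if all(ns is not seen for seen in namespaces):
            namespaces.append(ns)
    if len(namespaces) != 1:
        raise TypeError("expected arrays from exactly one namespace")
    return namespaces[0]


def double(x: Any, *, xp: Optional[Any] = None) -> Any:
    # Hypothetical xpx-style function: `xp` is optional and inferred if omitted.
    if xp is None:
        xp = _infer_namespace(x)
    return xp.multiply(x, 2)


class ToyArray:
    # Minimal standard-compatible "array" for the demo (assumption).
    def __init__(self, value: int) -> None:
        self.value = value

    def __array_namespace__(self, *, api_version: Optional[str] = None) -> Any:
        return toy_xp


toy_xp = SimpleNamespace(multiply=lambda a, b: ToyArray(a.value * b))

inferred = double(ToyArray(21))             # namespace inferred from the array
explicit = double(ToyArray(21), xp=toy_xp)  # explicit xp is still accepted
```

Keeping the `xp` keyword in this shape would leave room for a consumer to pass a namespace it resolved itself, for example with its own wrapped `array_namespace`.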
We don't yet have a use case where it is clearly beneficial to be able to pass standard-incompatible arrays. Consumer libraries using array-api-extra would already be computing with standard-compatible arrays internally. I don't see the need to support the following use case:
```python
import torch
import array_api_extra as xpx
...
x = torch.asarray([1, 2, 3])
xpx.some_func(x)  # works
torch.some_standard_func(x)  # does not work
```
Another complication is that consumer libraries like SciPy wrap `array_namespace` to provide custom behaviour for scalars and other types. We would want the internal `array_namespace` to be the consumer library's wrapped version rather than the base one from array-api-compat.
I'm also not sure that saving one line of code over option (1) for standard-compatible arrays is worth introducing a dependency on array-api-compat.
Overall, this would greatly complicate situations where array-api-compat and array-api-extra are co-vendored, which is currently the primary use case for the library. This might be a better idea in the future if a need for handling standard-incompatible arrays arises (for example, if one wants to use `xpx` functions with just a single library).