ENH/API: xp-bound namespaces, array-api-compat #6

Closed
@lucascolley


Currently, every function in this package requires the caller to pass a standard-compatible namespace as xp=xp. This works, but it has been suggested that it would be nicer to avoid this requirement. There are at least three ways we could go about this.
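For reference, here is a minimal sketch of the current explicit-xp calling convention. The atleast_nd body is a simplified stand-in written for illustration (it is not copied from the package), and NumPy is assumed as the standard-compatible namespace purely so the example runs.

```python
import numpy as np


def atleast_nd(x, /, *, ndim, xp):
    """Prepend length-1 axes to x until it has at least ndim dimensions.

    Simplified stand-in for illustration; xp is the array namespace the
    caller has already obtained (e.g. via array_namespace).
    """
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x


xp = np  # assumption: NumPy stands in for a standard-compatible namespace
x = xp.asarray([1, 2, 3])
y = atleast_nd(x, ndim=3, xp=xp)  # the namespace must be passed explicitly
print(y.shape)  # (1, 1, 3)
```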

(1) xpx.bind_namespace

Usage:

import array_api_extra as xpx
...
xp = array_namespace(x)
xpx = xpx.bind_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xp.sum(x)
z = xpx.some_func(y)

A potential implementation:

import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd, ...}

def bind_namespace(xp: ModuleType) -> ModuleType:
    class BoundNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                raise AttributeError(...)

    return BoundNamespace()

I like this idea, but I think we can leave it out until it is needed. If a library wants to use several xpx functions in the same local scope and finds the xp=xp pattern too cumbersome, we should add it.
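To make the "several xpx functions in one scope" case concrete, here is a sketch comparing the two call styles. It swaps the lazy __getattr__ lookup for an eager alternative (binding every extra function into a types.SimpleNamespace up front), which assumes the set of extra functions is fixed at import time; unknown names then raise AttributeError on their own, and dir() lists the bound functions. atleast_nd is a simplified stand-in and expand_last is a hypothetical second function added only for this illustration.

```python
import functools
import types

import numpy as np


def atleast_nd(x, /, *, ndim, xp):
    # Simplified stand-in for an array-api-extra function (illustration only).
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x


def expand_last(x, /, *, xp):
    # Hypothetical second extra function, only here to show several calls in one scope.
    return xp.expand_dims(x, axis=-1)


extra_funcs = {"atleast_nd": atleast_nd, "expand_last": expand_last}


def bind_namespace(xp):
    # Eagerly bind xp into every extra function. Names not in extra_funcs
    # raise AttributeError automatically; standard functions are NOT forwarded.
    return types.SimpleNamespace(
        **{name: functools.partial(func, xp=xp) for name, func in extra_funcs.items()}
    )


xp = np
x = xp.asarray([1, 2, 3])

# Current style: xp=xp repeated at every call site.
a = expand_last(atleast_nd(x, ndim=2, xp=xp), xp=xp)

# Bound style: xp is supplied once.
xpx = bind_namespace(xp)
b = xpx.expand_last(xpx.atleast_nd(x, ndim=2))

print(a.shape, b.shape)  # (1, 3, 1) (1, 3, 1)
```

The eager version trades laziness for discoverability; with only a handful of extra functions, building the partials up front costs essentially nothing.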

(2) xpx.extra_namespace

Usage:

import array_api_extra as xpx
...
xp = array_namespace(x)
xpx = xpx.extra_namespace(xp)
x = xpx.atleast_nd(x, ndim=2)
y = xpx.sum(x)  # XXX: xpx instead of xp
z = xpx.some_func(y)

A potential implementation:

import functools
from types import ModuleType

extra_funcs = {'atleast_nd': atleast_nd, ...}

def extra_namespace(xp: ModuleType) -> ModuleType:
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            else:
                return getattr(xp, name)  # XXX: delegate to xp instead of error

    return ExtraNamespace()

I would not want to add this yet. I think we should keep the standard namespace and the 'extra' namespace separate, at least until this library matures.
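The loss of separation is easy to see in a runnable version of the delegating namespace. This sketch follows the implementation above, with a simplified stand-in for atleast_nd and NumPy assumed as xp for illustration: once delegation is in place, standard and extra functions are indistinguishable at the call site.

```python
import functools
from types import ModuleType

import numpy as np


def atleast_nd(x, /, *, ndim, xp):
    # Simplified stand-in for an array-api-extra function (illustration only).
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x


extra_funcs = {"atleast_nd": atleast_nd}


def extra_namespace(xp: ModuleType):
    class ExtraNamespace:
        def __getattr__(self, name: str):
            if name in extra_funcs:
                return functools.partial(extra_funcs[name], xp=xp)
            # Delegation: any other name is looked up on the standard namespace.
            return getattr(xp, name)

    return ExtraNamespace()


xp = np
xpx = extra_namespace(xp)
x = xpx.atleast_nd(xp.asarray([1, 2, 3]), ndim=2)  # extra function
total = xpx.sum(x)  # standard function, forwarded to xp.sum

# From the call site alone, nothing distinguishes standard from extra functions:
print(xpx.sum is xp.sum)  # True
print(x.shape, int(total))  # (1, 3) 6
```

Because extra_funcs is checked first, an extra function that shares a name with a standard one would also silently shadow the standard version.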

(3) Use array_api_compat.array_namespace internally

This would provide the most flexible API and require the fewest lines of code at call sites. One could call xpx functions on standard-incompatible arrays without passing an xp argument, leaving array-api-compat to handle compatibility.

We don't yet have a use case where being able to pass standard-incompatible arrays is clearly beneficial. Consumer libraries using array-api-extra would already be computing with standard-compatible arrays internally. I don't see the need to support the following use case:

import torch
import array_api_extra as xpx
...
x = torch.asarray([1, 2, 3])
xpx.some_func(x)             # works
torch.some_standard_func(x)  # does not work

Another complication is that consumer libraries like SciPy wrap array_namespace to provide custom behaviour for scalars and other types. We would want the internal array_namespace to be the consumer library's wrapped version rather than the base one from array-api-compat.

I'm also not sure that saving one line of code over option (1) for standard-compatible arrays is worth introducing a dependency on array-api-compat.

Overall, this would greatly complicate situations where array-api-compat and array-api-extra are co-vendored, which is currently the primary use case for this library. It might become a better idea in the future if a need to handle standard-incompatible arrays arises (for example, if one wants to use xpx functions with just a single array library).
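For comparison, option (3) could look roughly like this sketch, in which xp becomes optional and is resolved from the input when omitted. It assumes array_api_compat.array_namespace as the resolver; importing it lazily inside the function is an illustrative choice so that callers who still pass xp do not need the dependency. atleast_nd is again a simplified stand-in.

```python
import numpy as np


def atleast_nd(x, /, *, ndim, xp=None):
    """Simplified stand-in with an optional xp (option 3 sketch)."""
    if xp is None:
        # Resolve the namespace from the input array. Imported lazily so that
        # callers who pass xp explicitly don't need array-api-compat installed.
        from array_api_compat import array_namespace

        xp = array_namespace(x)
    while x.ndim < ndim:
        x = xp.expand_dims(x, axis=0)
    return x


y = atleast_nd(np.asarray([1, 2, 3]), ndim=2, xp=np)  # explicit xp still works
print(y.shape)  # (1, 3)
# With array-api-compat installed, xp could be dropped:
# atleast_nd(np.asarray([1, 2, 3]), ndim=2)
```

Note that this hard-codes array-api-compat's own array_namespace, which is exactly the complication for consumer libraries like SciPy that wrap it with custom handling for scalars and other types.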
