BUG: sinc: broken on torch #49

Closed
@ev-br

Description

Consider the following session, which calls `sinc` with the torch namespace:

In [1]: import torch

In [2]: from array_api_compat import array_namespace

In [3]: xp = array_namespace(torch.ones(3))

In [4]: m, n = 7, 4.0

In [5]: import array_api_extra as xpx

In [6]: xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)

File ~/.conda/envs/scipy-dev/lib/python3.11/site-packages/array_api_extra/_funcs.py:518, in sinc(x, xp)
    516     raise ValueError(err_msg)
    517 # no scalars in `where` - array-api#807
--> 518 y = xp.pi * xp.where(
    519     x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
    520 )
    521 return xp.sin(y) / y

File ~/repos/array-api-compat/array_api_compat/torch/_aliases.py:503, in where(condition, x1, x2)
    501 def where(condition: array, x1: array, x2: array, /) -> array:
    502     x1, x2 = _fix_promotion(x1, x2)
--> 503     return torch.where(condition, x1, x2)

RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Double

The same MRE runs without error on numpy and jax.numpy.
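
A possible workaround, sketched on the assumption that the failure comes only from passing the floating-point `x` as the `where` condition (as the traceback suggests): build an explicit boolean mask with `x != 0`. The helper name `sinc_sketch` is hypothetical, and this is not necessarily the fix that array_api_extra should adopt. NumPy is used here only to keep the example lightweight.

```python
import numpy as np


def sinc_sketch(x, xp):
    """Normalized sinc, sin(pi*x) / (pi*x), using an explicit boolean mask.

    Hypothetical helper for illustration; not the array_api_extra implementation.
    """
    # Smallest positive normal float of x's dtype, substituted for exact zeros
    # so the division below never evaluates 0 / 0.
    tiny = xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
    # `x != 0` is a boolean array, which is what torch.where requires;
    # passing the float array `x` itself relies on implicit truthiness.
    y = xp.pi * xp.where(x != 0, x, tiny)
    return xp.sin(y) / y


x = np.asarray([0.0, 0.5, 1.0], dtype=np.float64)
result = sinc_sketch(x, xp=np)
print(result)
```

With a torch namespace, the same call would hand `torch.where` a boolean condition, which should avoid the `RuntimeError` above; that has not been checked here against any particular torch version.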
