Status: closed

Description

Consider the following minimal reproducible example (MRE), using the PyTorch array namespace:
In [1]: import torch
In [2]: from array_api_compat import array_namespace
In [3]: xp = array_namespace(torch.ones(3))
In [4]: m, n = 7, 4.0
In [5]: import array_api_extra as xpx
In [6]: xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 xpx.sinc(2. * xp.arange(n, m, dtype=xp.float64) / (m - 1) - 1.0, xp=xp)
File ~/.conda/envs/scipy-dev/lib/python3.11/site-packages/array_api_extra/_funcs.py:518, in sinc(x, xp)
516 raise ValueError(err_msg)
517 # no scalars in `where` - array-api#807
--> 518 y = xp.pi * xp.where(
519 x, x, xp.asarray(xp.finfo(x.dtype).smallest_normal, dtype=x.dtype)
520 )
521 return xp.sin(y) / y
File ~/repos/array-api-compat/array_api_compat/torch/_aliases.py:503, in where(condition, x1, x2)
501 def where(condition: array, x1: array, x2: array, /) -> array:
502 x1, x2 = _fix_promotion(x1, x2)
--> 503 return torch.where(condition, x1, x2)
RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Double
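The same MRE runs without error with NumPy and jax.numpy; only the PyTorch namespace fails. The cause is that `sinc` passes the floating-point array `x` itself as the `condition` argument of `xp.where`. NumPy and JAX accept this and treat nonzero values as true. `torch.where`, however, requires a boolean condition tensor. The portable spelling is an explicit comparison such as `xp.where(x != 0, x, ...)`.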