Description
On a few occasions while using this library, I've run into the issue of having to set array values using advanced indexing. Here is an example:
```python
import numpy as np
import numpy.array_api as anp


def one_hot_np(array, num_classes):
    n = array.shape[0]
    categorical = np.zeros((n, num_classes))
    # NumPy allows integer-array (advanced) indexing on assignment
    categorical[np.arange(n), array] = 1
    return categorical


def one_hot_anp(array, num_classes):
    one_hot = anp.zeros((array.shape[0], num_classes))
    # Pair each row index with its label: shape (n, 2)
    indices = anp.stack(
        (anp.arange(array.shape[0]), anp.reshape(array, (-1,))), axis=-1
    )
    indices = anp.reshape(indices, shape=(-1, indices.shape[-1]))
    # No integer-array indexing here, so set one element per iteration
    for idx in range(indices.shape[0]):
        one_hot[tuple(indices[idx, ...])] = 1
    return one_hot
```
I'm using the `numpy.array_api` namespace because it follows the API standard closely.
Is there a different (better) way of setting values of an array using integer (array) indices that adheres to the 2021.12 version of the array API standard?
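One loop-free pattern that appears to stay within 2021.12 is to turn the integer indices into a boolean mask by comparing them against `arange`, then use `where` to choose between the new value and the original array. Below is a minimal sketch for the 1-D case, assuming non-negative, in-bounds indices. `set_at_indices` and its `xp` parameter are hypothetical names, and plain NumPy stands in for a compliant namespace:

```python
import numpy as np


def set_at_indices(x, indices, value, xp=np):
    """Return a copy of 1-D `x` with x[indices] replaced by `value`.

    Hypothetical helper: emulates `x[indices] = value` using only
    broadcasting, any() and where(), so neither integer-array indexing
    nor in-place assignment is needed. Assumes 0 <= indices < len(x).
    """
    positions = xp.arange(x.shape[0])          # shape (n,)
    idx = xp.reshape(indices, (-1, 1))         # shape (m, 1)
    # mask[j] is True when position j appears anywhere in `indices`
    mask = xp.any(idx == positions, axis=0)    # (m, n) reduced to (n,)
    fill = xp.full(x.shape, value, dtype=x.dtype)
    return xp.where(mask, fill, x)


x = np.zeros(5)
y = set_at_indices(x, np.asarray([1, 3]), 7.0)
```

The intermediate comparison has shape `(m, n)` for `m` indices into `n` elements, so this trades the Python loop for extra memory, and it returns a new array rather than writing in place.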
For the example I gave, I'm aware that I can do something like this (though not with the `numpy.array_api` namespace, since it only supports v2021.12 and `take` was only added in v2022.12):
```python
import numpy as np
import numpy.array_api as anp


def one_hot(array, num_classes):
    id_arr = anp.eye(num_classes)
    # take() isn't in numpy.array_api (v2021.12), so fall back to np.take
    return np.take(id_arr, array, axis=0)
```
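Without `take`, the same one-hot result can likely be built from 2021.12 functions alone: comparing a column of labels against a row of class indices broadcasts to an `(n, num_classes)` boolean mask, which `where` then converts to floats. This is a sketch assuming integer labels in `[0, num_classes)`; `one_hot_broadcast` and `xp` are hypothetical names, with NumPy standing in for a compliant namespace:

```python
import numpy as np


def one_hot_broadcast(array, num_classes, xp=np):
    """One-hot encode integer labels without any item assignment.

    Hypothetical helper; `xp` is any namespace providing reshape, arange,
    where, ones and zeros. NumPy is used here as a stand-in.
    """
    labels = xp.reshape(array, (-1, 1))      # shape (n, 1)
    classes = xp.arange(num_classes)         # shape (num_classes,)
    mask = labels == classes                 # broadcasts to (n, num_classes)
    shape = (labels.shape[0], num_classes)
    # Array arguments to where() avoid relying on Python-scalar support,
    # which 2021.12 does not guarantee for this function
    return xp.where(mask, xp.ones(shape), xp.zeros(shape))


encoded = one_hot_broadcast(np.asarray([0, 2, 1]), 3)
# rows: [1, 0, 0], [0, 0, 1], [0, 1, 0]
```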
But I have other cases in my codebase that follow the first pattern: looping through array indices and using basic indexing to set array values. For example, using the indices from `xp.argsort` to mark the top-k values. Is there a better way than looping through the indices?
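For the top-k case, the scatter may be avoidable entirely: applying `argsort` to the permutation that `argsort` returns gives each element's rank, and comparing ranks against `n - k` yields the mask directly. This is a sketch assuming a boolean mask along the last axis is what's needed; `top_k_mask` and `xp` are hypothetical names, with NumPy standing in for a 2021.12 namespace:

```python
import numpy as np


def top_k_mask(x, k, xp=np):
    """Boolean mask marking the k largest entries along the last axis.

    Hypothetical helper using the "argsort of argsort" trick: the second
    argsort inverts the sorting permutation, giving each element's rank
    (0 = smallest), so no scatter or item assignment is required.
    Ties are broken by whatever order argsort produces. Assumes 0 <= k <= n.
    """
    n = x.shape[-1]
    order = xp.argsort(x, axis=-1)        # indices in ascending value order
    ranks = xp.argsort(order, axis=-1)    # rank of each original element
    return ranks >= n - k                 # the k largest hold the top ranks


x = np.asarray([[0.3, 0.9, 0.1, 0.5],
                [4.0, 1.0, 3.0, 2.0]])
mask = top_k_mask(x, 2)
```

Because the ranks in each row are a permutation of `0..n-1`, exactly `k` entries per row are marked. The mask can then be fed to `where`, e.g. `xp.where(mask, new_values, x)`, to set values without a loop, at the cost of two sorts.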