Skip to content

Commit 5be6c57

Browse files
committed
Make create_diagonal support broadcasting
1 parent 27b0bf2 commit 5be6c57

File tree

2 files changed

+40
-16
lines changed

2 files changed

+40
-16
lines changed

src/array_api_extra/_lib/_funcs.py

Lines changed: 28 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import math
77
import warnings
8-
from collections.abc import Sequence
8+
from collections.abc import Generator, Sequence
99
from types import ModuleType
1010
from typing import cast
1111

@@ -163,6 +163,16 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array:
163163
return xp.squeeze(c, axis=axes)
164164

165165

166+
def ndindex(*x: int) -> Generator[tuple[int, ...]]:
167+
if not x:
168+
yield ()
169+
return
170+
indices = list(ndindex(*x[1:]))
171+
for i in range(x[0]):
172+
for j in indices:
173+
yield i, *j
174+
175+
166176
def create_diagonal(
167177
x: Array, /, *, offset: int = 0, xp: ModuleType | None = None
168178
) -> Array:
@@ -172,7 +182,7 @@ def create_diagonal(
172182
Parameters
173183
----------
174184
x : array
175-
A 1-D array.
185+
An array having shape (*broadcast_dims, k).
176186
offset : int, optional
177187
Offset from the leading diagonal (default is ``0``).
178188
Use positive ints for diagonals above the leading diagonal,
@@ -183,7 +193,8 @@ def create_diagonal(
183193
Returns
184194
-------
185195
array
186-
A 2-D array with `x` on the diagonal (offset by `offset`).
196+
An array having shape (*broadcast_dims, k+abs(offset), k+abs(offset)) with `x`
197+
on the diagonal (offset by `offset`).
187198
188199
Examples
189200
--------
@@ -206,18 +217,21 @@ def create_diagonal(
206217
if xp is None:
207218
xp = array_namespace(x)
208219

209-
if x.ndim != 1:
210-
err_msg = "`x` must be 1-dimensional."
220+
if x.ndim == 0:
221+
err_msg = "`x` must be at least 1-dimensional."
211222
raise ValueError(err_msg)
212-
n = x.shape[0] + abs(offset)
213-
diag = xp.zeros(n**2, dtype=x.dtype, device=_compat.device(x))
214-
215-
start = offset if offset >= 0 else abs(offset) * n
216-
stop = min(n * (n - offset), diag.shape[0])
217-
step = n + 1
218-
diag = at(diag)[start:stop:step].set(x)
219-
220-
return xp.reshape(diag, (n, n))
223+
pre = x.shape[:-1]
224+
n = x.shape[-1] + abs(offset)
225+
diag = xp.zeros((*pre, n**2), dtype=x.dtype, device=_compat.device(x))
226+
227+
target_slice = slice(
228+
offset if offset >= 0 else abs(offset) * n,
229+
min(n * (n - offset), diag.shape[-1]),
230+
n + 1,
231+
)
232+
for index in ndindex(*pre):
233+
diag = at(diag)[(*index, target_slice)].set(x[*index, :])
234+
return xp.reshape(diag, (*pre, n, n))
221235

222236

223237
def expand_dims(

tests/test_funcs.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import contextlib
22
import warnings
33
from types import ModuleType
4+
from math import prod
45

56
import numpy as np
67
import pytest
@@ -19,6 +20,7 @@
1920
sinc,
2021
)
2122
from array_api_extra._lib import Backend
23+
from array_api_extra._lib._funcs import ndindex
2224
from array_api_extra._lib._testing import xp_assert_close, xp_assert_equal
2325
from array_api_extra._lib._utils._compat import device as get_device
2426
from array_api_extra._lib._utils._typing import Array, Device
@@ -193,9 +195,17 @@ def test_0d(self, xp: ModuleType):
193195
with pytest.raises(ValueError, match="1-dimensional"):
194196
create_diagonal(xp.asarray(1))
195197

198+
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
196199
def test_2d(self, xp: ModuleType):
197-
with pytest.raises(ValueError, match="1-dimensional"):
198-
create_diagonal(xp.asarray([[1]]))
200+
result = create_diagonal(xp.asarray([[1]]))
201+
xp_assert_equal(result, xp.asarray([[[1]]]))
202+
b = xp.zeros((3, 2, 4, 5), dtype=xp.int64)
203+
for i in ndindex(*b.shape):
204+
b = at(b)[*i].set(prod(i))
205+
c = create_diagonal(b)
206+
zero = xp.zeros((), dtype=xp.int64)
207+
for i in ndindex(*c.shape):
208+
xp_assert_equal(c[*i], b[*(i[:-1])] if i[-2] == i[-1] else zero)
199209

200210
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no device kwarg in zeros()")
201211
def test_device(self, xp: ModuleType, device: Device):

0 commit comments

Comments
 (0)