Skip to content

simple refactor #45

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 1 addition & 21 deletions src/array_api_extra/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
m = atleast_nd(m, ndim=2, xp=xp)
m = xp.astype(m, dtype)

avg = _mean(m, axis=1, xp=xp)
avg = _utils.mean(m, axis=1, xp=xp)
fact = m.shape[1] - 1

if fact <= 0:
Expand Down Expand Up @@ -199,26 +199,6 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
return xp.reshape(diag, (n, n))


def _mean(
x: Array,
/,
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
xp: ModuleType,
) -> Array:
"""
Complex mean, https://github.com/data-apis/array-api/issues/846.
"""
if xp.isdtype(x.dtype, "complex floating"):
x_real = xp.real(x)
x_imag = xp.imag(x)
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
return mean_real + (mean_imag * xp.asarray(1j))
return xp.mean(x, axis=axis, keepdims=keepdims)


def expand_dims(
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
) -> Array:
Expand Down
22 changes: 21 additions & 1 deletion src/array_api_extra/_lib/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import _compat

__all__ = ["in1d"]
__all__ = ["in1d", "mean"]


def in1d(
Expand Down Expand Up @@ -63,3 +63,23 @@ def in1d(
if assume_unique:
return ret[: x1.shape[0]]
return xp.take(ret, rev_idx, axis=0)


def mean(
x: Array,
/,
*,
axis: int | tuple[int, ...] | None = None,
keepdims: bool = False,
xp: ModuleType,
) -> Array:
"""
Complex mean, https://github.com/data-apis/array-api/issues/846.
"""
if xp.isdtype(x.dtype, "complex floating"):
x_real = xp.real(x)
x_imag = xp.imag(x)
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
return mean_real + (mean_imag * xp.asarray(1j))
return xp.mean(x, axis=axis, keepdims=keepdims)
98 changes: 49 additions & 49 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,55 @@ def test_2d(self):
create_diagonal(xp.asarray([[1]]), xp=xp)


class TestExpandDims:
def test_functionality(self):
def _squeeze_all(b: Array) -> Array:
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
for axis in range(b.ndim):
with contextlib.suppress(ValueError):
b = xp.squeeze(b, axis=axis)
return b

s = (2, 3, 4, 5)
a = xp.empty(s)
for axis in range(-5, 4):
b = expand_dims(a, axis=axis, xp=xp)
assert b.shape[axis] == 1
assert _squeeze_all(b).shape == s

def test_axis_tuple(self):
a = xp.empty((3, 3, 3))
assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)

def test_axis_out_of_range(self):
s = (2, 3, 4, 5)
a = xp.empty(s)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=-6, xp=xp)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=5, xp=xp)

a = xp.empty((3, 3, 3))
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=(0, -6), xp=xp)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=(0, 5), xp=xp)

def test_repeated_axis(self):
a = xp.empty((3, 3, 3))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(1, 1), xp=xp)

def test_positive_negative_repeated(self):
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
a = xp.empty((2, 3, 4, 5))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3), xp=xp)


class TestKron:
def test_basic(self):
# Using 0-dimensional array
Expand Down Expand Up @@ -222,55 +271,6 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")


class TestExpandDims:
def test_functionality(self):
def _squeeze_all(b: Array) -> Array:
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
for axis in range(b.ndim):
with contextlib.suppress(ValueError):
b = xp.squeeze(b, axis=axis)
return b

s = (2, 3, 4, 5)
a = xp.empty(s)
for axis in range(-5, 4):
b = expand_dims(a, axis=axis, xp=xp)
assert b.shape[axis] == 1
assert _squeeze_all(b).shape == s

def test_axis_tuple(self):
a = xp.empty((3, 3, 3))
assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)

def test_axis_out_of_range(self):
s = (2, 3, 4, 5)
a = xp.empty(s)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=-6, xp=xp)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=5, xp=xp)

a = xp.empty((3, 3, 3))
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=(0, -6), xp=xp)
with pytest.raises(IndexError, match="out of bounds"):
expand_dims(a, axis=(0, 5), xp=xp)

def test_repeated_axis(self):
a = xp.empty((3, 3, 3))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(1, 1), xp=xp)

def test_positive_negative_repeated(self):
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
a = xp.empty((2, 3, 4, 5))
with pytest.raises(ValueError, match="Duplicate dimensions"):
expand_dims(a, axis=(3, -3), xp=xp)


class TestSetDiff1D:
def test_setdiff1d(self):
x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])
Expand Down