Skip to content

Improve performance of dpnp.nanmedian #2240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged: 1 commit was merged on Dec 18, 2024.
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 14 additions & 21 deletions dpnp/dpnp_utils/dpnp_utils_statistics.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -57,10 +57,10 @@ def _calc_median(a, axis, out=None):
return res


def _calc_nanmedian(a, axis, out=None):
def _calc_nanmedian(a, out=None):
"""Compute the median of an array along a specified axis, ignoring NaNs."""
mask = dpnp.isnan(a)
valid_counts = dpnp.sum(~mask, axis=axis)
valid_counts = dpnp.sum(~mask, axis=-1)
if out is None:
res = dpnp.empty_like(valid_counts, dtype=a.dtype)
else:
Expand All @@ -76,27 +76,19 @@ def _calc_nanmedian(a, axis, out=None):
)
res = out

# Iterate over all indices of the output shape
for idx in dpnp.ndindex(res.shape):
current_valid_counts = valid_counts[idx]
left = (valid_counts - 1) // 2
right = valid_counts // 2

if current_valid_counts > 0:
# Extract the corresponding slice from the last axis of `a`
data = a[idx][:current_valid_counts]
left = (current_valid_counts - 1) // 2
right = current_valid_counts // 2
left_data = dpnp.take_along_axis(a, left[..., None], axis=-1)
right_data = dpnp.take_along_axis(a, right[..., None], axis=-1)
res = dpnp.where(
valid_counts[..., None] > 0, (left_data + right_data) / 2.0, dpnp.nan
)

if left == right:
res[idx] = data[left]
else:
res[idx] = (data[left] + data[right]) / 2.0
else:
warnings.warn(
"All-NaN slice encountered", RuntimeWarning, stacklevel=6
)
res[idx] = dpnp.nan
if mask.all(axis=-1).any():
warnings.warn("All-NaN slice encountered", RuntimeWarning, stacklevel=6)

return res
return dpnp.squeeze(res)


def _flatten_array_along_axes(a, axes_to_flatten, overwrite_input):
Expand Down Expand Up @@ -232,7 +224,8 @@ def dpnp_median(

if ignore_nan:
# sorting puts NaNs at the end
res = _calc_nanmedian(a_sorted, axis=axis, out=out)
assert axis == -1
res = _calc_nanmedian(a_sorted, out=out)
else:
# We can't pass keepdims and use it in dpnp.mean and dpnp.any
# because of the reshape hack that might have been used in
Expand Down
Loading