Skip to content

Commit d2dc56f

Browse files
authored
REF: de-duplicate NDFrame.take, remove Manager.take keyword (#51482)
1 parent c82ff3f commit d2dc56f

File tree

6 files changed

+30
-61
lines changed

6 files changed

+30
-61
lines changed

pandas/core/generic.py

Lines changed: 13 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -3816,6 +3816,7 @@ def _clear_item_cache(self) -> None:
38163816
# ----------------------------------------------------------------------
38173817
# Indexing Methods
38183818

3819+
@final
38193820
def take(self: NDFrameT, indices, axis: Axis = 0, **kwargs) -> NDFrameT:
38203821
"""
38213822
Return the elements in the given *positional* indices along an axis.
@@ -3893,20 +3894,6 @@ class max_speed
38933894

38943895
nv.validate_take((), kwargs)
38953896

3896-
return self._take(indices, axis)
3897-
3898-
@final
3899-
def _take(
3900-
self: NDFrameT,
3901-
indices,
3902-
axis: Axis = 0,
3903-
convert_indices: bool_t = True,
3904-
) -> NDFrameT:
3905-
"""
3906-
Internal version of the `take` allowing specification of additional args.
3907-
3908-
See the docstring of `take` for full explanation of the parameters.
3909-
"""
39103897
if not isinstance(indices, slice):
39113898
indices = np.asarray(indices, dtype=np.intp)
39123899
if (
@@ -3916,8 +3903,14 @@ def _take(
39163903
and is_range_indexer(indices, len(self))
39173904
):
39183905
return self.copy(deep=None)
3906+
elif self.ndim == 1:
3907+
# TODO: be consistent here for DataFrame vs Series
3908+
raise TypeError(
3909+
f"{type(self).__name__}.take requires a sequence of integers, "
3910+
"not slice."
3911+
)
39193912
else:
3920-
# We can get here with a slice via DataFrame.__geittem__
3913+
# We can get here with a slice via DataFrame.__getitem__
39213914
indices = np.arange(
39223915
indices.start, indices.stop, indices.step, dtype=np.intp
39233916
)
@@ -3926,21 +3919,23 @@ def _take(
39263919
indices,
39273920
axis=self._get_block_manager_axis(axis),
39283921
verify=True,
3929-
convert_indices=convert_indices,
39303922
)
39313923
return self._constructor(new_data).__finalize__(self, method="take")
39323924

3925+
@final
39333926
def _take_with_is_copy(self: NDFrameT, indices, axis: Axis = 0) -> NDFrameT:
39343927
"""
39353928
Internal version of the `take` method that sets the `_is_copy`
39363929
attribute to keep track of the parent dataframe (used in indexing
39373930
for the SettingWithCopyWarning).
39383931
3932+
For Series this does the same as the public take (it never sets `_is_copy`).
3933+
39393934
See the docstring of `take` for full explanation of the parameters.
39403935
"""
3941-
result = self._take(indices=indices, axis=axis)
3936+
result = self.take(indices=indices, axis=axis)
39423937
# Maybe set copy if we didn't actually change the index.
3943-
if not result._get_axis(axis).equals(self._get_axis(axis)):
3938+
if self.ndim == 2 and not result._get_axis(axis).equals(self._get_axis(axis)):
39443939
result._set_is_copy(self)
39453940
return result
39463941

pandas/core/groupby/groupby.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1572,7 +1572,10 @@ def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:
15721572
# GH#46209
15731573
# Don't convert indices: negative indices need to give rise
15741574
# to null values in the result
1575-
output = result._take(ids, axis=axis, convert_indices=False)
1575+
new_ax = result.axes[axis].take(ids)
1576+
output = result._reindex_with_indexers(
1577+
{axis: (new_ax, ids)}, allow_dups=True, copy=False
1578+
)
15761579
output = output.set_axis(obj._get_axis(self.axis), axis=axis)
15771580
return output
15781581

pandas/core/internals/array_manager.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -633,7 +633,6 @@ def take(
633633
indexer: npt.NDArray[np.intp],
634634
axis: AxisInt = 1,
635635
verify: bool = True,
636-
convert_indices: bool = True,
637636
) -> T:
638637
"""
639638
Take items along any axis.
@@ -647,8 +646,7 @@ def take(
647646
raise ValueError("indexer should be 1-dimensional")
648647

649648
n = self.shape_proper[axis]
650-
if convert_indices:
651-
indexer = maybe_convert_indices(indexer, n, verify=verify)
649+
indexer = maybe_convert_indices(indexer, n, verify=verify)
652650

653651
new_labels = self._axes[axis].take(indexer)
654652
return self._reindex_indexer(

pandas/core/internals/managers.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -910,7 +910,6 @@ def take(
910910
indexer: npt.NDArray[np.intp],
911911
axis: AxisInt = 1,
912912
verify: bool = True,
913-
convert_indices: bool = True,
914913
) -> T:
915914
"""
916915
Take items along any axis.
@@ -920,8 +919,6 @@ def take(
920919
verify : bool, default True
921920
Check that all entries are between 0 and len(self) - 1, inclusive.
922921
Pass verify=False if this check has been done by the caller.
923-
convert_indices : bool, default True
924-
Whether to attempt to convert indices to positive values.
925922
926923
Returns
927924
-------
@@ -931,8 +928,7 @@ def take(
931928
assert indexer.dtype == np.intp, indexer.dtype
932929

933930
n = self.shape[axis]
934-
if convert_indices:
935-
indexer = maybe_convert_indices(indexer, n, verify=verify)
931+
indexer = maybe_convert_indices(indexer, n, verify=verify)
936932

937933
new_labels = self.axes[axis].take(indexer)
938934
return self.reindex_indexer(

pandas/core/series.py

Lines changed: 0 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,6 @@
9696
maybe_cast_pointwise_result,
9797
)
9898
from pandas.core.dtypes.common import (
99-
ensure_platform_int,
10099
is_dict_like,
101100
is_extension_array_dtype,
102101
is_integer,
@@ -908,36 +907,6 @@ def axes(self) -> list[Index]:
908907
# ----------------------------------------------------------------------
909908
# Indexing Methods
910909

911-
@Appender(NDFrame.take.__doc__)
912-
def take(self, indices, axis: Axis = 0, **kwargs) -> Series:
913-
nv.validate_take((), kwargs)
914-
915-
indices = ensure_platform_int(indices)
916-
917-
if (
918-
indices.ndim == 1
919-
and using_copy_on_write()
920-
and is_range_indexer(indices, len(self))
921-
):
922-
return self.copy(deep=None)
923-
924-
new_index = self.index.take(indices)
925-
new_values = self._values.take(indices)
926-
927-
result = self._constructor(new_values, index=new_index, fastpath=True)
928-
return result.__finalize__(self, method="take")
929-
930-
def _take_with_is_copy(self, indices, axis: Axis = 0) -> Series:
931-
"""
932-
Internal version of the `take` method that sets the `_is_copy`
933-
attribute to keep track of the parent dataframe (using in indexing
934-
for the SettingWithCopyWarning). For Series this does the same
935-
as the public take (it never sets `_is_copy`).
936-
937-
See the docstring of `take` for full explanation of the parameters.
938-
"""
939-
return self.take(indices=indices, axis=axis)
940-
941910
def _ixs(self, i: int, axis: AxisInt = 0) -> Any:
942911
"""
943912
Return the i-th value or values in the Series by location.

pandas/tests/series/indexing/test_take.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,10 +16,10 @@ def test_take():
1616
expected = Series([4, 2, 4], index=[4, 3, 4])
1717
tm.assert_series_equal(actual, expected)
1818

19-
msg = lambda x: f"index {x} is out of bounds for( axis 0 with)? size 5"
20-
with pytest.raises(IndexError, match=msg(10)):
19+
msg = "indices are out-of-bounds"
20+
with pytest.raises(IndexError, match=msg):
2121
ser.take([1, 10])
22-
with pytest.raises(IndexError, match=msg(5)):
22+
with pytest.raises(IndexError, match=msg):
2323
ser.take([2, 5])
2424

2525

@@ -31,3 +31,11 @@ def test_take_categorical():
3131
pd.Categorical(["b", "b", "a"], categories=["a", "b", "c"]), index=[1, 1, 0]
3232
)
3333
tm.assert_series_equal(result, expected)
34+
35+
36+
def test_take_slice_raises():
37+
ser = Series([-1, 5, 6, 2, 4])
38+
39+
msg = "Series.take requires a sequence of integers, not slice"
40+
with pytest.raises(TypeError, match=msg):
41+
ser.take(slice(0, 3, 1))

0 commit comments

Comments
 (0)