Skip to content

Commit bc68c37

Browse files
committed
BUG: Maintain column order with groupby.nth
1 parent 383d052 commit bc68c37

File tree

4 files changed

+40
-8
lines changed

4 files changed

+40
-8
lines changed

doc/source/whatsnew/v0.24.0.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1323,6 +1323,7 @@ Groupby/Resample/Rolling
13231323
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
13241324
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
13251325
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
1326+
- Bug in :meth:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`)
13261327

13271328
Reshaping
13281329
^^^^^^^^^

pandas/core/groupby/groupby.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,8 @@ def _set_group_selection(self):
493493

494494
if len(groupers):
495495
# GH12839 clear selected obj cache when group selection changes
496-
self._group_selection = ax.difference(Index(groupers)).tolist()
496+
self._group_selection = ax.difference(Index(groupers),
497+
sort=False).tolist()
497498
self._reset_cache('_selected_obj')
498499

499500
def _set_result_index_ordered(self, result):

pandas/core/indexes/base.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2910,17 +2910,20 @@ def intersection(self, other):
29102910
taken.name = None
29112911
return taken
29122912

2913-
def difference(self, other):
2913+
def difference(self, other, sort=True):
29142914
"""
29152915
Return a new Index with elements from the index that are not in
29162916
`other`.
29172917
29182918
This is the set difference of two Index objects.
2919-
It's sorted if sorting is possible.
29202919
29212920
Parameters
29222921
----------
29232922
other : Index or array-like
2923+
sort : bool, default True
2924+
Sort the resulting index if possible.
2925+
2926+
.. versionadded:: 0.24.0
29242927
29252928
Returns
29262929
-------
@@ -2929,10 +2932,12 @@ def difference(self, other):
29292932
Examples
29302933
--------
29312934
2932-
>>> idx1 = pd.Index([1, 2, 3, 4])
2935+
>>> idx1 = pd.Index([2, 1, 3, 4])
29332936
>>> idx2 = pd.Index([3, 4, 5, 6])
29342937
>>> idx1.difference(idx2)
29352938
Int64Index([1, 2], dtype='int64')
2939+
>>> idx1.difference(idx2, sort=False)
2940+
Int64Index([2, 1], dtype='int64')
29362941
29372942
"""
29382943
self._assert_can_do_setop(other)
@@ -2951,10 +2956,11 @@ def difference(self, other):
29512956
label_diff = np.setdiff1d(np.arange(this.size), indexer,
29522957
assume_unique=True)
29532958
the_diff = this.values.take(label_diff)
2954-
try:
2955-
the_diff = sorting.safe_sort(the_diff)
2956-
except TypeError:
2957-
pass
2959+
if sort:
2960+
try:
2961+
the_diff = sorting.safe_sort(the_diff)
2962+
except TypeError:
2963+
pass
29582964

29592965
return this._shallow_copy(the_diff, name=result_name, freq=None)
29602966

pandas/tests/groupby/test_nth.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,27 @@ def test_nth_empty():
390390
names=['a', 'b']),
391391
columns=['c'])
392392
assert_frame_equal(result, expected)
393+
394+
395+
def test_nth_column_order():
396+
# GH 20760
397+
# Check that nth preserves column order
398+
df = DataFrame([[1, 'b', 100],
399+
[1, 'a', 50],
400+
[1, 'a', np.nan],
401+
[2, 'c', 200],
402+
[2, 'd', 150]],
403+
columns=['A', 'C', 'B'])
404+
result = df.groupby('A').nth(0)
405+
expected = DataFrame([['b', 100.0],
406+
['c', 200.0]],
407+
columns=['C', 'B'],
408+
index=Index([1, 2], name='A'))
409+
assert_frame_equal(result, expected)
410+
411+
result = df.groupby('A').nth(-1, dropna='any')
412+
expected = DataFrame([['a', 50.0],
413+
['d', 150.0]],
414+
columns=['C', 'B'],
415+
index=Index([1, 2], name='A'))
416+
assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)