Skip to content

Commit 0fcbd99

Browse files
committed
Fixed failing test; list comp for _fill method
1 parent aba1467 commit 0fcbd99

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pandas/core/groupby.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1882,10 +1882,10 @@ def _get_cythonized_result(self, how, grouper, needs_mask=False,
18821882
base_func = getattr(libgroupby, how)
18831883

18841884
for name, obj in self._iterate_slices():
1885-
indexer = np.zeros_like(labels)
1885+
indexer = np.zeros_like(labels, dtype=np.int64)
18861886
func = partial(base_func, indexer, labels)
18871887
if needs_mask:
1888-
mask = isnull(obj.values).astype(np.uint8, copy=False)
1888+
mask = isnull(obj.values).view(np.uint8)
18891889
func = partial(func, mask)
18901890

18911891
if needs_ngroups:
@@ -4633,12 +4633,11 @@ def _apply_to_column_groupbys(self, func):
46334633
keys=self._selected_obj.columns, axis=1)
46344634

46354635
def _fill(self, direction, limit=None):
4636-
"""Overriden method to concat grouped columns in output"""
4636+
"""Overriden method to join grouped columns in output"""
46374637
res = super()._fill(direction, limit=limit)
4638-
output = collections.OrderedDict()
4639-
for grp in self.grouper.groupings:
4640-
ser = grp.group_index.take(grp.labels)
4641-
output[ser.name] = ser.values
4638+
output = collections.OrderedDict(
4639+
(grp.name, grp.group_index.take(grp.labels)) for grp in
4640+
self.grouper.groupings)
46424641

46434642
return self._wrap_transformed_output(output).join(res)
46444643

pandas/tests/groupby/test_transform.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,9 @@ def test_cython_transform_frame(self, op, args, targop):
552552
tm.assert_frame_equal(expected,
553553
gb.transform(op, *args).sort_index(
554554
axis=1))
555-
tm.assert_frame_equal(expected, getattr(gb, op)(*args))
555+
tm.assert_frame_equal(expected,
556+
getattr(gb, op)(*args).sort_index(axis=1)
557+
)
556558
# individual columns
557559
for c in df:
558560
if c not in ['float', 'int', 'float_missing'

0 commit comments

Comments
 (0)