Skip to content

Commit 47c6d16

Browse files
authored
REG: DataFrame/Series.transform with list and non-list dict values (#40090)
1 parent 11afc76 commit 47c6d16

File tree

4 files changed

+52
-23
lines changed

4 files changed

+52
-23
lines changed

doc/source/whatsnew/v1.2.3.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Fixed regressions
2424
Passing ``ascending=None`` is still considered invalid,
2525
and the new error message suggests a proper usage
2626
(``ascending`` must be a boolean or a list-like boolean).
27+
- Fixed regression in :meth:`DataFrame.transform` and :meth:`Series.transform` giving incorrect column labels when passed a dictionary with a mix of list and non-list values (:issue:`40018`)
28+
-
2729

2830
.. ---------------------------------------------------------------------------
2931

pandas/core/apply.py

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def transform_dict_like(self, func):
264264
if len(func) == 0:
265265
raise ValueError("No transform functions were provided")
266266

267-
self.validate_dictlike_arg("transform", obj, func)
267+
func = self.normalize_dictlike_arg("transform", obj, func)
268268

269269
results: Dict[Hashable, FrameOrSeriesUnion] = {}
270270
for name, how in func.items():
@@ -405,32 +405,17 @@ def agg_dict_like(self, _axis: int) -> FrameOrSeriesUnion:
405405
-------
406406
Result of aggregation.
407407
"""
408+
from pandas.core.reshape.concat import concat
409+
408410
obj = self.obj
409411
arg = cast(AggFuncTypeDict, self.f)
410412

411-
is_aggregator = lambda x: isinstance(x, (list, tuple, dict))
412-
413413
if _axis != 0: # pragma: no cover
414414
raise ValueError("Can only pass dict with axis=0")
415415

416416
selected_obj = obj._selected_obj
417417

418-
self.validate_dictlike_arg("agg", selected_obj, arg)
419-
420-
# if we have a dict of any non-scalars
421-
# eg. {'A' : ['mean']}, normalize all to
422-
# be list-likes
423-
# Cannot use arg.values() because arg may be a Series
424-
if any(is_aggregator(x) for _, x in arg.items()):
425-
new_arg: AggFuncTypeDict = {}
426-
for k, v in arg.items():
427-
if not isinstance(v, (tuple, list, dict)):
428-
new_arg[k] = [v]
429-
else:
430-
new_arg[k] = v
431-
arg = new_arg
432-
433-
from pandas.core.reshape.concat import concat
418+
arg = self.normalize_dictlike_arg("agg", selected_obj, arg)
434419

435420
if selected_obj.ndim == 1:
436421
# key only used for output
@@ -524,14 +509,15 @@ def maybe_apply_multiple(self) -> Optional[FrameOrSeriesUnion]:
524509
return None
525510
return self.obj.aggregate(self.f, self.axis, *self.args, **self.kwargs)
526511

527-
def validate_dictlike_arg(
512+
def normalize_dictlike_arg(
528513
self, how: str, obj: FrameOrSeriesUnion, func: AggFuncTypeDict
529-
) -> None:
514+
) -> AggFuncTypeDict:
530515
"""
531-
Raise if dict-like argument is invalid.
516+
Handler for dict-like argument.
532517
533518
Ensures that necessary columns exist if obj is a DataFrame, and
534-
that a nested renamer is not passed.
519+
that a nested renamer is not passed. Also normalizes to all lists
520+
when values consists of a mix of list and non-lists.
535521
"""
536522
assert how in ("apply", "agg", "transform")
537523

@@ -551,6 +537,23 @@ def validate_dictlike_arg(
551537
cols_sorted = list(safe_sort(list(cols)))
552538
raise KeyError(f"Column(s) {cols_sorted} do not exist")
553539

540+
is_aggregator = lambda x: isinstance(x, (list, tuple, dict))
541+
542+
# if we have a dict of any non-scalars
543+
# eg. {'A' : ['mean']}, normalize all to
544+
# be list-likes
545+
# Cannot use func.values() because arg may be a Series
546+
if any(is_aggregator(x) for _, x in func.items()):
547+
new_func: AggFuncTypeDict = {}
548+
for k, v in func.items():
549+
if not is_aggregator(v):
550+
# mypy can't realize v is not a list here
551+
new_func[k] = [v] # type:ignore[list-item]
552+
else:
553+
new_func[k] = v
554+
func = new_func
555+
return func
556+
554557

555558
class FrameApply(Apply):
556559
obj: DataFrame

pandas/tests/apply/test_frame_transform.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,17 @@ def test_transform_dictlike(axis, float_frame, box):
103103
tm.assert_frame_equal(result, expected)
104104

105105

106+
def test_transform_dictlike_mixed():
107+
# GH 40018 - mix of lists and non-lists in values of a dictionary
108+
df = DataFrame({"a": [1, 2], "b": [1, 4], "c": [1, 4]})
109+
result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"})
110+
expected = DataFrame(
111+
[[1.0, 1, 1.0], [2.0, 4, 2.0]],
112+
columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]),
113+
)
114+
tm.assert_frame_equal(result, expected)
115+
116+
106117
@pytest.mark.parametrize(
107118
"ops",
108119
[

pandas/tests/apply/test_series_transform.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import pytest
33

44
from pandas import (
5+
DataFrame,
6+
MultiIndex,
57
Series,
68
concat,
79
)
@@ -55,6 +57,17 @@ def test_transform_dictlike(string_series, box):
5557
tm.assert_frame_equal(result, expected)
5658

5759

60+
def test_transform_dictlike_mixed():
61+
# GH 40018 - mix of lists and non-lists in values of a dictionary
62+
df = Series([1, 4])
63+
result = df.transform({"b": ["sqrt", "abs"], "c": "sqrt"})
64+
expected = DataFrame(
65+
[[1.0, 1, 1.0], [2.0, 4, 2.0]],
66+
columns=MultiIndex([("b", "c"), ("sqrt", "abs")], [(0, 0, 1), (0, 1, 0)]),
67+
)
68+
tm.assert_frame_equal(result, expected)
69+
70+
5871
def test_transform_wont_agg(string_series):
5972
# GH 35964
6073
# we are trying to transform with an aggregator

0 commit comments

Comments
 (0)