Skip to content

Commit d01d6e8

Browse files
committed
fixup! Add case_when API * Used to support conditional assignment operation.
1 parent b897b8e commit d01d6e8

File tree

1 file changed

+29
-5
lines changed

1 file changed

+29
-5
lines changed

pandas/core/case_when.py

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
from __future__ import annotations
22

33
from typing import Any
4+
import warnings
5+
6+
from pandas.util._exceptions import find_stack_level
47

58
from pandas.core.dtypes.common import is_list_like
69

710
import pandas as pd
811
import pandas.core.common as com
912

1013

14+
def warn_and_override_index(series, series_type, index):
15+
warnings.warn(
16+
f"Series {series_type} will be reindexed to match obj index.",
17+
UserWarning,
18+
stacklevel=find_stack_level(),
19+
)
20+
return pd.Series(series.values, index=index)
21+
22+
1123
def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
1224
"""
1325
Returns a Series based on multiple conditions assignment.
1426
1527
This is useful when you want to assign a column based on multiple conditions.
1628
Uses `Series.mask` to perform the assignment.
1729
18-
The returned Series will always have a new index (reset).
30+
The returned Series have the same index as `obj`.
1931
2032
Parameters
2133
----------
@@ -105,7 +117,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
105117
2 -1
106118
Name: a, dtype: int64
107119
108-
The index is not maintained. For example:
120+
The index will always follow that of `obj`. For example:
109121
>>> df = pd.DataFrame(
110122
... dict(a=[1, 2, 3], b=[4, 5, 6]),
111123
... index=['index 1', 'index 2', 'index 3']
@@ -122,9 +134,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
122134
... df.b,
123135
... default=0,
124136
... )
125-
0 4
126-
1 0
127-
2 0
137+
index 1 4
138+
index 2 0
139+
index 3 0
128140
dtype: int64
129141
"""
130142
len_args = len(args)
@@ -154,6 +166,18 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
154166
# get replacements
155167
replacements = args[i + 1]
156168

169+
if isinstance(replacements, pd.Series) and not replacements.index.equals(
170+
obj.index
171+
):
172+
replacements = warn_and_override_index(
173+
replacements, f"(in args[{i+1}])", obj.index
174+
)
175+
176+
if isinstance(conditions, pd.Series) and not conditions.index.equals(obj.index):
177+
conditions = warn_and_override_index(
178+
conditions, f"(in args[{i}])", obj.index
179+
)
180+
157181
# `Series.mask` call
158182
series = series.mask(conditions, replacements)
159183

0 commit comments

Comments
 (0)