1
1
from __future__ import annotations
2
2
3
3
from typing import Any
4
+ import warnings
5
+
6
+ from pandas .util ._exceptions import find_stack_level
4
7
5
8
from pandas .core .dtypes .common import is_list_like
6
9
7
10
import pandas as pd
8
11
import pandas .core .common as com
9
12
10
13
14
+ def warn_and_override_index (series , series_type , index ):
15
+ warnings .warn (
16
+ f"Series { series_type } will be reindexed to match obj index." ,
17
+ UserWarning ,
18
+ stacklevel = find_stack_level (),
19
+ )
20
+ return pd .Series (series .values , index = index )
21
+
22
+
11
23
def case_when (obj : pd .DataFrame | pd .Series , * args , default : Any ) -> pd .Series :
12
24
"""
13
25
Returns a Series based on multiple conditions assignment.
14
26
15
27
This is useful when you want to assign a column based on multiple conditions.
16
28
Uses `Series.mask` to perform the assignment.
17
29
18
- The returned Series will always have a new index (reset) .
30
+ The returned Series have the same index as `obj` .
19
31
20
32
Parameters
21
33
----------
@@ -105,7 +117,7 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
105
117
2 -1
106
118
Name: a, dtype: int64
107
119
108
- The index is not maintained . For example:
120
+ The index will always follow that of `obj` . For example:
109
121
>>> df = pd.DataFrame(
110
122
... dict(a=[1, 2, 3], b=[4, 5, 6]),
111
123
... index=['index 1', 'index 2', 'index 3']
@@ -122,9 +134,9 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
122
134
... df.b,
123
135
... default=0,
124
136
... )
125
- 0 4
126
- 1 0
127
- 2 0
137
+ index 1 4
138
+ index 2 0
139
+ index 3 0
128
140
dtype: int64
129
141
"""
130
142
len_args = len (args )
@@ -154,6 +166,18 @@ def case_when(obj: pd.DataFrame | pd.Series, *args, default: Any) -> pd.Series:
154
166
# get replacements
155
167
replacements = args [i + 1 ]
156
168
169
+ if isinstance (replacements , pd .Series ) and not replacements .index .equals (
170
+ obj .index
171
+ ):
172
+ replacements = warn_and_override_index (
173
+ replacements , f"(in args[{ i + 1 } ])" , obj .index
174
+ )
175
+
176
+ if isinstance (conditions , pd .Series ) and not conditions .index .equals (obj .index ):
177
+ conditions = warn_and_override_index (
178
+ conditions , f"(in args[{ i } ])" , obj .index
179
+ )
180
+
157
181
# `Series.mask` call
158
182
series = series .mask (conditions , replacements )
159
183
0 commit comments