ENH: Add case_when method #56059
Merged: phofl merged 43 commits into pandas-dev:main from samukweku:samukweku/case_when_function on Jan 9, 2024
Changes from 11 commits
Commits
f48502f updates (samukweku)
40057c7 add test for default if Series (samukweku)
4a8be16 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
089bbe6 updates based on feedback (samukweku)
bcfd458 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
b95ce55 updates based on feedback (samukweku)
acc3fdb Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
8be4349 update typing hints for *args, based on feedback (samukweku)
8d08458 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
ec18086 update typehints; add caselist argument - based on feedback (samukweku)
0b72fbb cleanup docstrings (samukweku)
0085956 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
a441481 support method only for case_when (samukweku)
29ad697 minor update (samukweku)
bf740f9 fix test (samukweku)
264a675 remove redundant tests (samukweku)
2a3035e cleanup docs (samukweku)
5e33304 use singular version - common_dtype (samukweku)
5c7c287 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
8569cd1 fix doctest failure (samukweku)
bbb5887 fix for whatnew (samukweku)
e03e3dc Update doc/source/whatsnew/v2.2.0.rst (samukweku)
283488f Update v2.2.0.rst (phofl)
7a8694c Update v2.2.0.rst (phofl)
f6cf725 Merge remote-tracking branch 'upstream/main' into samukweku/case_when… (samukweku)
67dfcaa improve typing and add test for callable (samukweku)
3da7cf2 fix typing error (samukweku)
bdc54f6 Update pandas/core/series.py (samukweku)
649fb84 Merge branch 'main' into samukweku/case_when_function (rhshadrach)
b68d20e Update doc/source/whatsnew/v2.2.0.rst (samukweku)
b4de208 PERF: resolution, is_normalized (#56637) (jbrockmendel)
5966bfe TYP: more simple return types from ruff (#56628) (twoertwein)
3e404fa ENH: Update CFF with publication reference, Zenodo DOI, and other det… (cgobat)
21659bc DOC: Fixup CoW userguide (#56636) (phofl)
f6d8cd0 REF: check monotonicity inside _can_use_libjoin (#55342) (jbrockmendel)
becc626 DOC: Minor fixups for 2.2.0 whatsnew (#56632) (rhshadrach)
918a19e TYP: Fix some PythonParser and Plotting types (#56643) (twoertwein)
5744df2 BUG: Series.to_numpy raising for arrow floats to numpy floats (#56644) (phofl)
bc6ba0e updates based on feedback (samukweku)
a0f4797 add to API reference (samukweku)
cb7d6e3 fix whitespace (samukweku)
c8f0e2e updates (samukweku)
9679b9e Update series.py (samukweku)
@@ -0,0 +1,216 @@

```python
from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from pandas._libs import lib

from pandas.core.dtypes.cast import (
    construct_1d_arraylike_from_scalar,
    find_common_type,
    infer_dtype_from,
)
from pandas.core.dtypes.common import is_scalar
from pandas.core.dtypes.generic import ABCSeries

from pandas.core.construction import array as pd_array

if TYPE_CHECKING:
    from pandas._typing import (
        ArrayLike,
        Scalar,
        Series,
    )


def case_when(
    caselist: list[tuple[ArrayLike, ArrayLike | Scalar]],
    default: ArrayLike | Scalar = lib.no_default,
) -> Series:
    """
    Replace values where the conditions are True.

    Parameters
    ----------
    caselist : List of tuples of conditions and expected replacements.
        Takes the form: ``(condition0, replacement0)``,
        ``(condition1, replacement1)``, ... .
        ``condition`` should be a 1-D boolean array.
        When multiple boolean conditions are satisfied,
        the first replacement is used.
        If ``condition`` is a Series, and the equivalent ``replacement``
        is a Series, they must have the same index.
        If there are multiple replacement options,
        and they are Series, they must have the same index.

    default : scalar, array-like, default None
        If provided, it is the replacement value to use
        if all conditions evaluate to False.
        If not specified, entries will be filled with the
        corresponding NULL value.

        .. versionadded:: 2.2.0

    Returns
    -------
    Series

    See Also
    --------
    Series.mask : Replace values where the condition is True.

    Examples
    --------
    >>> df = pd.DataFrame({
    ...     "a": [0,0,1,2],
    ...     "b": [0,3,4,5],
    ...     "c": [6,7,8,9]
    ... })
    >>> df
       a  b  c
    0  0  0  6
    1  0  3  7
    2  1  4  8
    3  2  5  9

    >>> caselist = [(df.a.gt(0), df.a), (df.b.gt(0), df.b)] # condition, replacement
    >>> pd.case_when(caselist=caselist, default=df.c) # default is optional
    0    6
    1    3
    2    1
    3    2
    Name: c, dtype: int64
    """
    from pandas import Series

    validate_case_when(caselist=caselist)

    conditions, replacements = zip(*caselist)
    common_dtypes = [infer_dtype_from(replacement)[0] for replacement in replacements]

    if default is not lib.no_default:
        arg_dtype, _ = infer_dtype_from(default)
        common_dtypes.append(arg_dtype)
    else:
        default = None
    if len(set(common_dtypes)) > 1:
        common_dtypes = find_common_type(common_dtypes)
        updated_replacements = []
        for condition, replacement in zip(conditions, replacements):
            if is_scalar(replacement):
                replacement = construct_1d_arraylike_from_scalar(
                    value=replacement, length=len(condition), dtype=common_dtypes
                )
            elif isinstance(replacement, ABCSeries):
                replacement = replacement.astype(common_dtypes)
            else:
                replacement = pd_array(replacement, dtype=common_dtypes)
            updated_replacements.append(replacement)
        replacements = updated_replacements
        if (default is not None) and isinstance(default, ABCSeries):
            default = default.astype(common_dtypes)
    else:
        common_dtypes = common_dtypes[0]
    if not isinstance(default, ABCSeries):
        cond_indices = [cond for cond in conditions if isinstance(cond, ABCSeries)]
        replacement_indices = [
            replacement
            for replacement in replacements
            if isinstance(replacement, ABCSeries)
        ]
        cond_length = None
        if replacement_indices:
            for left, right in zip(replacement_indices, replacement_indices[1:]):
                if not left.index.equals(right.index):
                    raise AssertionError(
                        "All replacement objects must have the same index."
                    )
        if cond_indices:
            for left, right in zip(cond_indices, cond_indices[1:]):
                if not left.index.equals(right.index):
                    raise AssertionError(
                        "All condition objects must have the same index."
                    )
            if replacement_indices:
                if not replacement_indices[0].index.equals(cond_indices[0].index):
                    raise AssertionError(
                        "All replacement objects and condition objects "
                        "should have the same index."
                    )
        else:
            conditions = [
                np.asanyarray(cond) if not hasattr(cond, "shape") else cond
                for cond in conditions
            ]
            cond_length = {len(cond) for cond in conditions}
            if len(cond_length) > 1:
                raise ValueError("The boolean conditions should have the same length.")
            cond_length = len(conditions[0])
            if not is_scalar(default):
                if len(default) != cond_length:
                    raise ValueError(
                        "length of `default` does not match the length "
                        "of any of the conditions."
                    )
            if not replacement_indices:
                for num, replacement in enumerate(replacements):
                    if is_scalar(replacement):
                        continue
                    if not hasattr(replacement, "shape"):
                        replacement = np.asanyarray(replacement)
                    if len(replacement) != cond_length:
                        raise ValueError(
                            f"Length of condition{num} does not match "
                            f"the length of replacement{num}; "
                            f"{cond_length} != {len(replacement)}"
                        )
        if cond_indices:
            default_index = cond_indices[0].index
        elif replacement_indices:
            default_index = replacement_indices[0].index
        else:
            default_index = range(cond_length)
        default = Series(default, index=default_index, dtype=common_dtypes)
    counter = reversed(range(len(conditions)))
    for position, condition, replacement in zip(
        counter, conditions[::-1], replacements[::-1]
    ):
        try:
            default = default.mask(
                condition, other=replacement, axis=0, inplace=False, level=None
            )
        except Exception as error:
            raise ValueError(
                f"Failed to apply condition{position} and replacement{position}."
            ) from error
    return default


def validate_case_when(caselist: list) -> None:
    """
    Validates the arguments for the case_when function.
    """

    if not isinstance(caselist, list):
        raise TypeError(
            f"The caselist argument should be a list; instead got {type(caselist)}"
        )

    if not len(caselist):
        raise ValueError(
            "provide at least one boolean condition, "
            "with a corresponding replacement."
        )

    for num, entry in enumerate(caselist):
        if not isinstance(entry, tuple):
            raise TypeError(
                f"Argument {num} must be a tuple; instead got {type(entry)}."
            )
        if len(entry) != 2:
            raise ValueError(
                f"Argument {num} must have length 2; "
                "a condition and replacement; "
                f"instead got length {len(entry)}."
            )
```
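
The core of the function above is the final loop: it starts from `default` and applies `Series.mask` over the conditions in reverse order, so that the first matching condition ends up winning. The following is a minimal sketch of that idea using only the public `Series.mask` API; `apply_cases` is a hypothetical helper written for illustration, not part of the PR, and it skips the dtype and index checks the PR performs. It reuses the data from the docstring example.

```python
import pandas as pd

df = pd.DataFrame({"a": [0, 0, 1, 2], "b": [0, 3, 4, 5], "c": [6, 7, 8, 9]})

# (condition, replacement) pairs, in priority order: the first match should win.
caselist = [(df["a"].gt(0), df["a"]), (df["b"].gt(0), df["b"])]


def apply_cases(caselist, default):
    """Hypothetical helper: apply the cases in reverse so earlier cases override later ones."""
    result = default
    for condition, replacement in reversed(caselist):
        # Series.mask replaces values where `condition` is True and returns a new Series.
        result = result.mask(condition, other=replacement)
    return result


result = apply_cases(caselist, default=df["c"])
print(result.tolist())  # [6, 3, 1, 2]
```

Row 0 matches neither condition and keeps the default from `c`; row 1 matches only the second condition; rows 2 and 3 match both, and the first condition's replacement is the one that survives because it is applied last.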
cc @MarcoGorelli
This might upcast, thoughts related to PDEP6?
I don't think this is related to PDEP6; we are creating a new Series and not doing something like setitem.
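
A small sketch of the distinction this reply draws, under the assumption that the upcast in question is the kind that happens when replacements of different dtypes are combined. PDEP-6 concerns upcasting in setitem-like operations on an existing object, whereas `Series.mask` without `inplace=True` returns a new Series and leaves the original untouched. The variable names here are illustrative only.

```python
import pandas as pd

original = pd.Series([1, 2, 3])  # int64
condition = pd.Series([True, False, False])

# mask returns a new Series; a float replacement forces the result to float64.
new = original.mask(condition, other=1.5)

print(original.dtype)  # int64, the original object keeps its dtype
print(new.dtype)       # float64, only the newly created Series is upcast
```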