Skip to content

Commit 3f394c9

Browse files
added scalar condition to Series.where
1 parent 28a0c65 commit 3f394c9

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

pandas/core/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10102,7 +10102,7 @@ def _where(
1010210102
else:
1010310103
if not hasattr(cond, "shape"):
1010410104
cond = np.asanyarray(cond)
10105-
if cond.shape != self.shape:
10105+
if cond.shape != () and cond.shape != self.shape:
1010610106
raise ValueError("Array conditional must be same shape as self")
1010710107
cond = self._constructor(cond, **self._construct_axes_dict(), copy=False)
1010810108

pandas/tests/series/indexing/test_where.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,3 +464,17 @@ def test_where_datetimelike_categorical(tz_naive_fixture):
464464
res = pd.DataFrame(lvals).where(mask[:, None], pd.DataFrame(rvals))
465465

466466
tm.assert_frame_equal(res, pd.DataFrame(dr))
467+
468+
469+
def test_where_scalar_cond(self):
470+
# True
471+
ser = Series(pd.Categorical(["a", "b"]))
472+
result = ser.where(True)
473+
expected = ser
474+
tm.assert_series_equal(result, expected)
475+
476+
# False
477+
ser = Series(pd.Categorical(["a", "b"]))
478+
result = ser.where(False)
479+
expected = Series(pd.Categorical([None, None], categories=["a", "b"]))
480+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)