Skip to content

Commit 4c5382a

Browse files
added implementation + tests for scalar condition in .where
1 parent 3f394c9 commit 4c5382a

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

pandas/core/generic.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10102,7 +10102,12 @@ def _where(
1010210102
else:
1010310103
if not hasattr(cond, "shape"):
1010410104
cond = np.asanyarray(cond)
10105-
if cond.shape != () and cond.shape != self.shape:
10105+
if cond.shape == ():
10106+
# Note: DataFrame(True, index=[1,2,3], columns=["a", "b", "c"]) works
10107+
# but DataFrame(np.array(True), index=[1,2,3], columns=["a", "b", "c"]) does not
10108+
# hence we need to unpack scalar
10109+
cond = cond.item()
10110+
elif cond.shape != self.shape:
1010610111
raise ValueError("Array conditional must be same shape as self")
1010710112
cond = self._constructor(cond, **self._construct_axes_dict(), copy=False)
1010810113

pandas/tests/frame/indexing/test_where.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -159,10 +159,15 @@ def test_where_invalid(self):
159159
with pytest.raises(ValueError, match=msg):
160160
df.where(err2, other1)
161161

162-
with pytest.raises(ValueError, match=msg):
163-
df.mask(True)
164-
with pytest.raises(ValueError, match=msg):
165-
df.mask(0)
162+
def test_where_scalar_cond(self):
163+
df = DataFrame(np.random.randn(5, 3), columns=["A", "B", "C"])
164+
result = df.where(True)
165+
expected = df
166+
tm.assert_frame_equal(result, expected)
167+
168+
result = df.where(False)
169+
expected = DataFrame(np.nan, index=df.index, columns=df.columns)
170+
tm.assert_frame_equal(result, expected)
166171

167172
def test_where_set(self, where_frame, float_string_frame):
168173
# where inplace

pandas/tests/series/indexing/test_where.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,6 @@ def test_where_error():
147147
cond = s > 0
148148

149149
msg = "Array conditional must be same shape as self"
150-
with pytest.raises(ValueError, match=msg):
151-
s.where(1)
152150
with pytest.raises(ValueError, match=msg):
153151
s.where(cond[:3].values, -s)
154152

@@ -466,7 +464,7 @@ def test_where_datetimelike_categorical(tz_naive_fixture):
466464
tm.assert_frame_equal(res, pd.DataFrame(dr))
467465

468466

469-
def test_where_scalar_cond(self):
467+
def test_where_scalar_cond():
470468
# True
471469
ser = Series(pd.Categorical(["a", "b"]))
472470
result = ser.where(True)
@@ -477,4 +475,4 @@ def test_where_scalar_cond(self):
477475
ser = Series(pd.Categorical(["a", "b"]))
478476
result = ser.where(False)
479477
expected = Series(pd.Categorical([None, None], categories=["a", "b"]))
480-
tm.assert_series_equal(result, expected)
478+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)