Skip to content

Commit 45113f8

Browse files
authored
BUG: Series.where with PeriodDtype raising on no-op (#45135)
1 parent f7f0cee commit 45113f8

File tree

3 files changed

+85
-50
lines changed

3 files changed

+85
-50
lines changed

doc/source/whatsnew/v1.4.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,8 @@ Period
908908
- Bug in :meth:`PeriodIndex.to_timestamp` when the index has ``freq="B"`` inferring ``freq="D"`` for its result instead of ``freq="B"`` (:issue:`44105`)
909909
- Bug in :class:`Period` constructor incorrectly allowing ``np.timedelta64("NaT")`` (:issue:`44507`)
910910
- Bug in :meth:`PeriodIndex.to_timestamp` giving incorrect values for indexes with non-contiguous data (:issue:`44100`)
911+
- Bug in :meth:`Series.where` with ``PeriodDtype`` incorrectly raising when the ``where`` call is a no-op, i.e. no values need to be replaced (e.g. ``cond`` is all-``True``) (:issue:`45135`)
912+
911913
-
912914

913915
Plotting

pandas/core/internals/blocks.py

Lines changed: 53 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1330,6 +1330,59 @@ class EABackedBlock(Block):
13301330

13311331
values: ExtensionArray
13321332

1333+
def where(self, other, cond) -> list[Block]:
1334+
arr = self.values.T
1335+
1336+
cond = extract_bool_array(cond)
1337+
1338+
other = self._maybe_squeeze_arg(other)
1339+
cond = self._maybe_squeeze_arg(cond)
1340+
1341+
if other is lib.no_default:
1342+
other = self.fill_value
1343+
1344+
icond, noop = validate_putmask(arr, ~cond)
1345+
if noop:
1346+
# GH#44181, GH#45135
1347+
# Avoid a) raising for Interval/PeriodDtype and b) unnecessary object upcast
1348+
return self.copy()
1349+
1350+
try:
1351+
res_values = arr._where(cond, other).T
1352+
except (ValueError, TypeError) as err:
1353+
if isinstance(err, ValueError):
1354+
# TODO(2.0): once DTA._validate_setitem_value deprecation
1355+
# is enforced, stop catching ValueError here altogether
1356+
if "Timezones don't match" not in str(err):
1357+
raise
1358+
1359+
if is_interval_dtype(self.dtype):
1360+
# TestSetitemFloatIntervalWithIntIntervalValues
1361+
blk = self.coerce_to_target_dtype(other)
1362+
if blk.dtype == _dtype_obj:
1363+
# For now at least only support casting e.g.
1364+
# Interval[int64]->Interval[float64]
1365+
raise
1366+
return blk.where(other, cond)
1367+
1368+
elif isinstance(self, NDArrayBackedExtensionBlock):
1369+
# NB: not (yet) the same as
1370+
# isinstance(values, NDArrayBackedExtensionArray)
1371+
if isinstance(self.dtype, PeriodDtype):
1372+
# TODO: don't special-case
1373+
# Note: this is the main place where the fallback logic
1374+
# is different from EABackedBlock.putmask.
1375+
raise
1376+
blk = self.coerce_to_target_dtype(other)
1377+
nbs = blk.where(other, cond)
1378+
return self._maybe_downcast(nbs, "infer")
1379+
1380+
else:
1381+
raise
1382+
1383+
nb = self.make_block_same_class(res_values)
1384+
return [nb]
1385+
13331386
def putmask(self, mask, new) -> list[Block]:
13341387
"""
13351388
See Block.putmask.__doc__
@@ -1648,36 +1701,6 @@ def shift(self, periods: int, axis: int = 0, fill_value: Any = None) -> list[Blo
16481701
new_values = self.values.shift(periods=periods, fill_value=fill_value)
16491702
return [self.make_block_same_class(new_values)]
16501703

1651-
def where(self, other, cond) -> list[Block]:
1652-
1653-
cond = extract_bool_array(cond)
1654-
assert not isinstance(other, (ABCIndex, ABCSeries, ABCDataFrame))
1655-
1656-
other = self._maybe_squeeze_arg(other)
1657-
cond = self._maybe_squeeze_arg(cond)
1658-
1659-
if other is lib.no_default:
1660-
other = self.fill_value
1661-
1662-
icond, noop = validate_putmask(self.values, ~cond)
1663-
if noop:
1664-
return self.copy()
1665-
1666-
try:
1667-
result = self.values._where(cond, other)
1668-
except TypeError:
1669-
if is_interval_dtype(self.dtype):
1670-
# TestSetitemFloatIntervalWithIntIntervalValues
1671-
blk = self.coerce_to_target_dtype(other)
1672-
if blk.dtype == _dtype_obj:
1673-
# For now at least only support casting e.g.
1674-
# Interval[int64]->Interval[float64]
1675-
raise
1676-
return blk.where(other, cond)
1677-
raise
1678-
1679-
return [self.make_block_same_class(result)]
1680-
16811704
def _unstack(
16821705
self,
16831706
unstacker,
@@ -1760,26 +1783,6 @@ def setitem(self, indexer, value):
17601783
values[indexer] = value
17611784
return self
17621785

1763-
def where(self, other, cond) -> list[Block]:
1764-
arr = self.values
1765-
1766-
cond = extract_bool_array(cond)
1767-
if other is lib.no_default:
1768-
other = self.fill_value
1769-
1770-
try:
1771-
res_values = arr.T._where(cond, other).T
1772-
except (ValueError, TypeError):
1773-
if isinstance(self.dtype, PeriodDtype):
1774-
# TODO: don't special-case
1775-
raise
1776-
blk = self.coerce_to_target_dtype(other)
1777-
nbs = blk.where(other, cond)
1778-
return self._maybe_downcast(nbs, "infer")
1779-
1780-
nb = self.make_block_same_class(res_values)
1781-
return [nb]
1782-
17831786
def diff(self, n: int, axis: int = 0) -> list[Block]:
17841787
"""
17851788
1st discrete difference.

pandas/tests/frame/indexing/test_where.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,36 @@ def test_where_interval_noop(self):
711711
res = ser.where(ser.notna())
712712
tm.assert_series_equal(res, ser)
713713

714+
@pytest.mark.parametrize(
715+
"dtype",
716+
[
717+
"timedelta64[ns]",
718+
"datetime64[ns]",
719+
"datetime64[ns, Asia/Tokyo]",
720+
"Period[D]",
721+
],
722+
)
723+
def test_where_datetimelike_noop(self, dtype):
724+
# GH#45135, analogue to GH#44181 for Period: don't raise on no-op
725+
# For td64/dt64/dt64tz we already don't raise, but also are
726+
# checking that we don't unnecessarily upcast to object.
727+
ser = Series(np.arange(3) * 10 ** 9, dtype=np.int64).view(dtype)
728+
df = ser.to_frame()
729+
mask = np.array([False, False, False])
730+
731+
res = ser.where(~mask, "foo")
732+
tm.assert_series_equal(res, ser)
733+
734+
mask2 = mask.reshape(-1, 1)
735+
res2 = df.where(~mask2, "foo")
736+
tm.assert_frame_equal(res2, df)
737+
738+
res3 = ser.mask(mask, "foo")
739+
tm.assert_series_equal(res3, ser)
740+
741+
res4 = df.mask(mask2, "foo")
742+
tm.assert_frame_equal(res4, df)
743+
714744

715745
def test_where_try_cast_deprecated(frame_or_series):
716746
obj = DataFrame(np.random.randn(4, 3))

0 commit comments

Comments
 (0)