Skip to content

Commit 45bbe9a

Browse files
authored
REF: dispatch TDBlock._can_hold_element to TimedeltaArray._validate_setitem_value (#38674)
1 parent 6c2631b commit 45bbe9a

File tree

3 files changed

+59
-18
lines changed

3 files changed

+59
-18
lines changed

pandas/core/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8858,7 +8858,7 @@ def _where(
88588858
elif len(cond[icond]) == len(other):
88598859

88608860
# try to not change dtype at first
8861-
new_other = np.asarray(self)
8861+
new_other = self._values
88628862
new_other = new_other.copy()
88638863
new_other[icond] = other
88648864
other = new_other

pandas/core/internals/blocks.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from datetime import timedelta
21
import inspect
32
import re
43
from typing import TYPE_CHECKING, Any, List, Optional, Type, Union, cast
@@ -8,7 +7,6 @@
87

98
from pandas._libs import (
109
Interval,
11-
NaT,
1210
Period,
1311
Timestamp,
1412
algos as libalgos,
@@ -86,6 +84,7 @@
8684

8785
if TYPE_CHECKING:
8886
from pandas import Index
87+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
8988

9089

9190
class Block(PandasObject):
@@ -919,7 +918,11 @@ def setitem(self, indexer, value):
919918
if self._can_hold_element(value):
920919
# We only get here for non-Extension Blocks, so _try_coerce_args
921920
# is only relevant for DatetimeBlock and TimedeltaBlock
922-
if lib.is_scalar(value):
921+
if self.dtype.kind in ["m", "M"]:
922+
arr = self.array_values().T
923+
arr[indexer] = value
924+
return self
925+
elif lib.is_scalar(value):
923926
value = convert_scalar_for_putitemlike(value, values.dtype)
924927

925928
else:
@@ -1064,6 +1067,17 @@ def putmask(
10641067
if self._can_hold_element(new):
10651068
# We only get here for non-Extension Blocks, so _try_coerce_args
10661069
# is only relevant for DatetimeBlock and TimedeltaBlock
1070+
if self.dtype.kind in ["m", "M"]:
1071+
blk = self
1072+
if not inplace:
1073+
blk = self.copy()
1074+
arr = blk.array_values()
1075+
arr = cast("NDArrayBackedExtensionArray", arr)
1076+
if transpose:
1077+
arr = arr.T
1078+
arr.putmask(mask, new)
1079+
return [blk]
1080+
10671081
if lib.is_scalar(new):
10681082
new = convert_scalar_for_putitemlike(new, self.values.dtype)
10691083

@@ -2376,16 +2390,6 @@ def _maybe_coerce_values(self, values):
23762390
def _holder(self):
23772391
return TimedeltaArray
23782392

2379-
def _can_hold_element(self, element: Any) -> bool:
2380-
tipo = maybe_infer_dtype_type(element)
2381-
if tipo is not None:
2382-
return issubclass(tipo.type, np.timedelta64)
2383-
elif element is NaT:
2384-
return True
2385-
elif isinstance(element, (timedelta, np.timedelta64)):
2386-
return True
2387-
return is_valid_nat_for_dtype(element, self.dtype)
2388-
23892393
def fillna(self, value, **kwargs):
23902394
# TODO(EA2D): if we operated on array_values, TDA.fillna would handle
23912395
# raising here.

pandas/tests/indexing/test_indexing.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from pandas.core.dtypes.common import is_float_dtype, is_integer_dtype
1111

1212
import pandas as pd
13-
from pandas import DataFrame, Index, NaT, Series, date_range
13+
from pandas import DataFrame, Index, NaT, Series, date_range, offsets, timedelta_range
1414
import pandas._testing as tm
1515
from pandas.core.indexing import maybe_numeric_slice, non_reducing_slice
1616
from pandas.tests.indexing.common import _mklbl
@@ -970,19 +970,22 @@ class TestDatetimelikeCoercion:
970970
@pytest.mark.parametrize("indexer", [setitem, loc, iloc])
971971
def test_setitem_dt64_string_scalar(self, tz_naive_fixture, indexer):
972972
# dispatching _can_hold_element to underling DatetimeArray
973-
# TODO(EA2D) use tz_naive_fixture once DatetimeBlock is backed by DTA
974973
tz = tz_naive_fixture
975974

976975
dti = date_range("2016-01-01", periods=3, tz=tz)
977976
ser = Series(dti)
978977

979978
values = ser._values
980979

981-
indexer(ser)[0] = "2018-01-01"
980+
newval = "2018-01-01"
981+
values._validate_setitem_value(newval)
982+
983+
indexer(ser)[0] = newval
982984

983985
if tz is None:
984986
# TODO(EA2D): we can make this no-copy in tz-naive case too
985987
assert ser.dtype == dti.dtype
988+
assert ser._values._data is values._data
986989
else:
987990
assert ser._values is values
988991

@@ -993,7 +996,6 @@ def test_setitem_dt64_string_scalar(self, tz_naive_fixture, indexer):
993996
@pytest.mark.parametrize("indexer", [setitem, loc, iloc])
994997
def test_setitem_dt64_string_values(self, tz_naive_fixture, indexer, key, box):
995998
# dispatching _can_hold_element to underling DatetimeArray
996-
# TODO(EA2D) use tz_naive_fixture once DatetimeBlock is backed by DTA
997999
tz = tz_naive_fixture
9981000

9991001
if isinstance(key, slice) and indexer is loc:
@@ -1012,9 +1014,44 @@ def test_setitem_dt64_string_values(self, tz_naive_fixture, indexer, key, box):
10121014
if tz is None:
10131015
# TODO(EA2D): we can make this no-copy in tz-naive case too
10141016
assert ser.dtype == dti.dtype
1017+
assert ser._values._data is values._data
10151018
else:
10161019
assert ser._values is values
10171020

1021+
@pytest.mark.parametrize("scalar", ["3 Days", offsets.Hour(4)])
1022+
@pytest.mark.parametrize("indexer", [setitem, loc, iloc])
1023+
def test_setitem_td64_scalar(self, indexer, scalar):
1024+
# dispatching _can_hold_element to underling TimedeltaArray
1025+
tdi = timedelta_range("1 Day", periods=3)
1026+
ser = Series(tdi)
1027+
1028+
values = ser._values
1029+
values._validate_setitem_value(scalar)
1030+
1031+
indexer(ser)[0] = scalar
1032+
assert ser._values._data is values._data
1033+
1034+
@pytest.mark.parametrize("box", [list, np.array, pd.array])
1035+
@pytest.mark.parametrize(
1036+
"key", [[0, 1], slice(0, 2), np.array([True, True, False])]
1037+
)
1038+
@pytest.mark.parametrize("indexer", [setitem, loc, iloc])
1039+
def test_setitem_td64_string_values(self, indexer, key, box):
1040+
# dispatching _can_hold_element to underling TimedeltaArray
1041+
if isinstance(key, slice) and indexer is loc:
1042+
key = slice(0, 1)
1043+
1044+
tdi = timedelta_range("1 Day", periods=3)
1045+
ser = Series(tdi)
1046+
1047+
values = ser._values
1048+
1049+
newvals = box(["10 Days", "44 hours"])
1050+
values._validate_setitem_value(newvals)
1051+
1052+
indexer(ser)[key] = newvals
1053+
assert ser._values._data is values._data
1054+
10181055

10191056
def test_extension_array_cross_section():
10201057
# A cross-section of a homogeneous EA should be an EA

0 commit comments

Comments
 (0)