Skip to content

REF: de-duplicate reso-casting code #48953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pandas/_libs/tslibs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,13 @@
"periods_per_day",
"periods_per_second",
"is_supported_unit",
"npy_unit_to_abbrev",
]

from pandas._libs.tslibs import dtypes
from pandas._libs.tslibs.conversion import localize_pydatetime
from pandas._libs.tslibs.dtypes import (
Resolution,
is_supported_unit,
npy_unit_to_abbrev,
periods_per_day,
periods_per_second,
)
Expand Down
3 changes: 3 additions & 0 deletions pandas/_libs/tslibs/timedeltas.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def delta_to_nanoseconds(
) -> int: ...

class Timedelta(timedelta):
_reso: int
min: ClassVar[Timedelta]
max: ClassVar[Timedelta]
resolution: ClassVar[Timedelta]
Expand Down Expand Up @@ -153,4 +154,6 @@ class Timedelta(timedelta):
def freq(self) -> None: ...
@property
def is_populated(self) -> bool: ...
@property
def _unit(self) -> str: ...
def _as_unit(self, unit: str, round_ok: bool = ...) -> Timedelta: ...
7 changes: 7 additions & 0 deletions pandas/_libs/tslibs/timedeltas.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,13 @@ cdef class _Timedelta(timedelta):
max = MinMaxReso("max")
resolution = MinMaxReso("resolution")

@property
def _unit(self) -> str:
"""
The abbreviation associated with self._reso.
"""
return npy_unit_to_abbrev(self._reso)

@property
def days(self) -> int: # TODO(cython3): make cdef property
# NB: using the python C-API PyDateTime_DELTA_GET_DAYS will fail
Expand Down
2 changes: 2 additions & 0 deletions pandas/_libs/tslibs/timestamps.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,6 @@ class Timestamp(datetime):
def days_in_month(self) -> int: ...
@property
def daysinmonth(self) -> int: ...
@property
def _unit(self) -> str: ...
def _as_unit(self, unit: str, round_ok: bool = ...) -> Timestamp: ...
7 changes: 7 additions & 0 deletions pandas/_libs/tslibs/timestamps.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,13 @@ cdef class _Timestamp(ABCTimestamp):
)
return self._freq

@property
def _unit(self) -> str:
"""
The abbreviation associated with self._reso.
"""
return npy_unit_to_abbrev(self._reso)

# -----------------------------------------------------------------
# Constructors

Expand Down
52 changes: 20 additions & 32 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@
iNaT,
ints_to_pydatetime,
ints_to_pytimedelta,
npy_unit_to_abbrev,
to_offset,
)
from pandas._libs.tslibs.fields import (
Expand Down Expand Up @@ -1169,13 +1168,8 @@ def _add_datetimelike_scalar(self, other) -> DatetimeArray:
# Preserve our resolution
return DatetimeArray._simple_new(result, dtype=result.dtype)

if self._reso != other._reso:
# Just as with Timestamp/Timedelta, we cast to the higher resolution
if self._reso < other._reso:
unit = npy_unit_to_abbrev(other._reso)
self = self._as_unit(unit)
else:
other = other._as_unit(self._unit)
self, other = self._ensure_matching_resos(other)
self = cast("TimedeltaArray", self)

i8 = self.asi8
result = checked_add_with_arr(i8, other.value, arr_mask=self._isnan)
Expand Down Expand Up @@ -1208,16 +1202,10 @@ def _sub_datetimelike_scalar(self, other: datetime | np.datetime64):
# i.e. np.datetime64("NaT")
return self - NaT

other = Timestamp(other)
ts = Timestamp(other)

if other._reso != self._reso:
if other._reso < self._reso:
other = other._as_unit(self._unit)
else:
unit = npy_unit_to_abbrev(other._reso)
self = self._as_unit(unit)

return self._sub_datetimelike(other)
self, ts = self._ensure_matching_resos(ts)
return self._sub_datetimelike(ts)

@final
def _sub_datetime_arraylike(self, other):
Expand All @@ -1230,12 +1218,7 @@ def _sub_datetime_arraylike(self, other):
self = cast("DatetimeArray", self)
other = ensure_wrapped_if_datetimelike(other)

if other._reso != self._reso:
if other._reso < self._reso:
other = other._as_unit(self._unit)
else:
self = self._as_unit(other._unit)

self, other = self._ensure_matching_resos(other)
return self._sub_datetimelike(other)

@final
Expand Down Expand Up @@ -1319,17 +1302,11 @@ def _add_timedelta_arraylike(
raise ValueError("cannot add indices of unequal length")

other = ensure_wrapped_if_datetimelike(other)
other = cast("TimedeltaArray", other)
tda = cast("TimedeltaArray", other)
self = cast("DatetimeArray | TimedeltaArray", self)

if self._reso != other._reso:
# Just as with Timestamp/Timedelta, we cast to the higher resolution
if self._reso < other._reso:
self = self._as_unit(other._unit)
else:
other = other._as_unit(self._unit)

return self._add_timedeltalike(other)
self, tda = self._ensure_matching_resos(tda)
return self._add_timedeltalike(tda)

@final
def _add_timedeltalike(self, other: Timedelta | TimedeltaArray):
Expand Down Expand Up @@ -2098,6 +2075,17 @@ def _as_unit(self: TimelikeOpsT, unit: str) -> TimelikeOpsT:
new_values, dtype=new_dtype, freq=self.freq # type: ignore[call-arg]
)

# TODO: annotate other as DatetimeArray | TimedeltaArray | Timestamp | Timedelta
# with the return type matching input type. TypeVar?
def _ensure_matching_resos(self, other):
if self._reso != other._reso:
# Just as with Timestamp/Timedelta, we cast to the higher resolution
if self._reso < other._reso:
self = self._as_unit(other._unit)
else:
other = other._as_unit(self._unit)
return self, other

# --------------------------------------------------------------

def __array_ufunc__(self, ufunc: np.ufunc, method: str, *inputs, **kwargs):
Expand Down
13 changes: 5 additions & 8 deletions pandas/tests/scalar/timestamp/test_timestamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@
utc,
)

from pandas._libs.tslibs.dtypes import (
NpyDatetimeUnit,
npy_unit_to_abbrev,
)
from pandas._libs.tslibs.dtypes import NpyDatetimeUnit
from pandas._libs.tslibs.timezones import (
dateutil_gettz as gettz,
get_timezone,
Expand Down Expand Up @@ -964,7 +961,7 @@ def test_sub_datetimelike_mismatched_reso(self, ts_tz):
if ts._reso < other._reso:
# Case where rounding is lossy
other2 = other + Timedelta._from_value_and_reso(1, other._reso)
exp = ts._as_unit(npy_unit_to_abbrev(other._reso)) - other2
exp = ts._as_unit(other._unit) - other2

res = ts - other2
assert res == exp
Expand All @@ -975,7 +972,7 @@ def test_sub_datetimelike_mismatched_reso(self, ts_tz):
assert res._reso == max(ts._reso, other._reso)
else:
ts2 = ts + Timedelta._from_value_and_reso(1, ts._reso)
exp = ts2 - other._as_unit(npy_unit_to_abbrev(ts2._reso))
exp = ts2 - other._as_unit(ts2._unit)

res = ts2 - other
assert res == exp
Expand Down Expand Up @@ -1012,7 +1009,7 @@ def test_sub_timedeltalike_mismatched_reso(self, ts_tz):
if ts._reso < other._reso:
# Case where rounding is lossy
other2 = other + Timedelta._from_value_and_reso(1, other._reso)
exp = ts._as_unit(npy_unit_to_abbrev(other._reso)) + other2
exp = ts._as_unit(other._unit) + other2
res = ts + other2
assert res == exp
assert res._reso == max(ts._reso, other._reso)
Expand All @@ -1021,7 +1018,7 @@ def test_sub_timedeltalike_mismatched_reso(self, ts_tz):
assert res._reso == max(ts._reso, other._reso)
else:
ts2 = ts + Timedelta._from_value_and_reso(1, ts._reso)
exp = ts2 + other._as_unit(npy_unit_to_abbrev(ts2._reso))
exp = ts2 + other._as_unit(ts2._unit)

res = ts2 + other
assert res == exp
Expand Down
1 change: 0 additions & 1 deletion pandas/tests/tslibs/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def test_namespace():
"periods_per_day",
"periods_per_second",
"is_supported_unit",
"npy_unit_to_abbrev",
]

expected = set(submodules + api)
Expand Down