
BUG: merge_asof raising KeyError for extension dtypes #53458

Status: Merged (3 commits, Jun 1, 2023)
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
@@ -441,6 +441,7 @@ Groupby/resample/rolling
 Reshaping
 ^^^^^^^^^
 - Bug in :func:`crosstab` when ``dropna=False`` would not keep ``np.nan`` in the result (:issue:`10772`)
+- Bug in :func:`merge_asof` raising ``KeyError`` for extension dtypes (:issue:`52904`)
 - Bug in :meth:`DataFrame.agg` and :meth:`Series.agg` on non-unique columns would return incorrect type when dist-like argument passed in (:issue:`51099`)
 - Bug in :meth:`DataFrame.idxmin` and :meth:`DataFrame.idxmax`, where the axis dtype would be lost for empty frames (:issue:`53265`)
 - Bug in :meth:`DataFrame.merge` not merging correctly when having ``MultiIndex`` with single level (:issue:`52331`)
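For context, a minimal reproduction of the bug this whatsnew entry refers to (GH#52904), mirroring the regression test added at the bottom of this diff; before this change the merge_asof call below raised KeyError:

    import pandas as pd

    # the "on" column uses the nullable Int64 extension dtype
    left = pd.DataFrame({"join_col": [1, 3, 5], "left_val": [1, 2, 3]}).astype(
        {"join_col": "Int64"}
    )
    right = pd.DataFrame({"join_col": [2, 3, 4], "right_val": [1, 2, 3]}).astype(
        {"join_col": "Int64"}
    )
    # raised KeyError before this fix; now returns the expected asof merge
    result = pd.merge_asof(left, right, on="join_col")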
66 changes: 41 additions & 25 deletions pandas/core/arrays/arrow/array.py
@@ -149,6 +149,8 @@ def floordiv_compat(
     )

     from pandas import Series
+    from pandas.core.arrays.datetimes import DatetimeArray
+    from pandas.core.arrays.timedeltas import TimedeltaArray


 def get_unit_from_pa_dtype(pa_dtype):
@@ -1168,6 +1170,41 @@ def take(
                 indices_array[indices_array < 0] += len(self._pa_array)
             return type(self)(self._pa_array.take(indices_array))

+    def _maybe_convert_datelike_array(self):
+        """Maybe convert to a datelike array."""
+        pa_type = self._pa_array.type
+        if pa.types.is_timestamp(pa_type):
+            return self._to_datetimearray()
+        elif pa.types.is_duration(pa_type):
+            return self._to_timedeltaarray()
+        return self
+
+    def _to_datetimearray(self) -> DatetimeArray:
+        """Convert a pyarrow timestamp typed array to a DatetimeArray."""
+        from pandas.core.arrays.datetimes import (
+            DatetimeArray,
+            tz_to_dtype,
+        )
+
+        pa_type = self._pa_array.type
+        assert pa.types.is_timestamp(pa_type)
+        np_dtype = np.dtype(f"M8[{pa_type.unit}]")

[Review comment from a Member, on the line above] IIRC we do something like the below in another area of this file, correct? If so, can we reuse it?
[Reply from the Member who authored the PR] Yes, updated a few locations to reuse these methods.

+        dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
+        np_array = self._pa_array.to_numpy()
+        np_array = np_array.astype(np_dtype)
+        return DatetimeArray._simple_new(np_array, dtype=dtype)
+
+    def _to_timedeltaarray(self) -> TimedeltaArray:
+        """Convert a pyarrow duration typed array to a TimedeltaArray."""
+        from pandas.core.arrays.timedeltas import TimedeltaArray
+
+        pa_type = self._pa_array.type
+        assert pa.types.is_duration(pa_type)
+        np_dtype = np.dtype(f"m8[{pa_type.unit}]")
+        np_array = self._pa_array.to_numpy()
+        np_array = np_array.astype(np_dtype)
+        return TimedeltaArray._simple_new(np_array, dtype=np_dtype)
+
     @doc(ExtensionArray.to_numpy)
     def to_numpy(
         self,
@@ -1184,33 +1221,12 @@ def to_numpy(
             na_value = self.dtype.na_value

         pa_type = self._pa_array.type
-        if pa.types.is_timestamp(pa_type):
-            from pandas.core.arrays.datetimes import (
-                DatetimeArray,
-                tz_to_dtype,
-            )
-
-            np_dtype = np.dtype(f"M8[{pa_type.unit}]")
-            result = self._pa_array.to_numpy()
-            result = result.astype(np_dtype, copy=copy)
+        if pa.types.is_timestamp(pa_type) or pa.types.is_duration(pa_type):
+            result = self._maybe_convert_datelike_array()
             if dtype is None or dtype.kind == "O":
-                dta_dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                result = DatetimeArray._simple_new(result, dtype=dta_dtype)
                 result = result.to_numpy(dtype=object, na_value=na_value)
-            elif result.dtype != dtype:
-                result = result.astype(dtype, copy=False)
-            return result
-        elif pa.types.is_duration(pa_type):
-            from pandas.core.arrays.timedeltas import TimedeltaArray
-
-            np_dtype = np.dtype(f"m8[{pa_type.unit}]")
-            result = self._pa_array.to_numpy()
-            result = result.astype(np_dtype, copy=copy)
-            if dtype is None or dtype.kind == "O":
-                result = TimedeltaArray._simple_new(result, dtype=np_dtype)
-                result = result.to_numpy(dtype=object, na_value=na_value)
-            elif result.dtype != dtype:
-                result = result.astype(dtype, copy=False)
+            else:
+                result = result.to_numpy(dtype=dtype)
             return result
         elif pa.types.is_time(pa_type):
             # convert to list of python datetime.time objects before
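A quick sketch of what the new helpers do (they are private methods, shown here purely for illustration): a timestamp-typed ArrowExtensionArray converts to a DatetimeArray, a duration-typed one to a TimedeltaArray, and anything else is returned unchanged.

    import pandas as pd
    import pyarrow as pa

    arr = pd.array(
        ["2023-01-01", "2023-01-02"], dtype=pd.ArrowDtype(pa.timestamp("ns"))
    )
    dta = arr._maybe_convert_datelike_array()
    print(type(dta).__name__)  # "DatetimeArray"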
20 changes: 4 additions & 16 deletions pandas/core/arrays/datetimelike.py
@@ -2223,23 +2223,11 @@ def ensure_arraylike_for_datetimelike(data, copy: bool, cls_name: str):
     ):
         data = data.to_numpy("int64", na_value=iNaT)
         copy = False
-    elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "M":
-        from pandas.core.arrays import DatetimeArray
-        from pandas.core.arrays.datetimes import tz_to_dtype
-
-        pa_type = data._pa_array.type
-        dtype = tz_to_dtype(tz=pa_type.tz, unit=pa_type.unit)
-        data = data.to_numpy(f"M8[{pa_type.unit}]", na_value=iNaT)
-        data = DatetimeArray._simple_new(data, dtype=dtype)
+    elif isinstance(data, ArrowExtensionArray):
+        data = data._maybe_convert_datelike_array()
+        data = data.to_numpy()
         copy = False
-    elif isinstance(data, ArrowExtensionArray) and data.dtype.kind == "m":
-        pa_type = data._pa_array.type
-        dtype = np.dtype(f"m8[{pa_type.unit}]")
-        data = data.to_numpy(dtype, na_value=iNaT)
-        copy = False
-    elif not isinstance(data, (np.ndarray, ExtensionArray)) or isinstance(
-        data, ArrowExtensionArray
-    ):
+    elif not isinstance(data, (np.ndarray, ExtensionArray)):
         # GH#24539 e.g. xarray, dask object
         data = np.asarray(data)

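The consolidated ArrowExtensionArray branch above is exercised whenever a datetime-like object is constructed from Arrow-backed data; a sketch (assuming pyarrow is installed):

    import pandas as pd
    import pyarrow as pa

    arr = pd.array(
        ["2023-01-01", "2023-01-02"], dtype=pd.ArrowDtype(pa.timestamp("us"))
    )
    # routed through ensure_arraylike_for_datetimelike, which now delegates
    # to arr._maybe_convert_datelike_array()
    idx = pd.DatetimeIndex(arr)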
23 changes: 5 additions & 18 deletions pandas/core/indexes/base.py
@@ -151,7 +151,6 @@
     Categorical,
     ExtensionArray,
 )
-from pandas.core.arrays.datetimes import tz_to_dtype
 from pandas.core.arrays.string_ import StringArray
 from pandas.core.base import (
     IndexOpsMixin,
@@ -192,11 +191,7 @@
         MultiIndex,
         Series,
     )
-    from pandas.core.arrays import (
-        DatetimeArray,
-        PeriodArray,
-        TimedeltaArray,
-    )
+    from pandas.core.arrays import PeriodArray

 __all__ = ["Index"]

@@ -845,14 +840,10 @@ def _engine(

             pa_type = self._values._pa_array.type
             if pa.types.is_timestamp(pa_type):
-                dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                target_values = self._values.astype(dtype)
-                target_values = cast("DatetimeArray", target_values)
+                target_values = self._values._to_datetimearray()
                 return libindex.DatetimeEngine(target_values._ndarray)
             elif pa.types.is_duration(pa_type):
-                dtype = np.dtype(f"m8[{pa_type.unit}]")
-                target_values = self._values.astype(dtype)
-                target_values = cast("TimedeltaArray", target_values)
+                target_values = self._values._to_timedeltaarray()
                 return libindex.TimedeltaEngine(target_values._ndarray)

         if isinstance(target_values, ExtensionArray):
@@ -5117,14 +5108,10 @@ def _get_engine_target(self) -> ArrayLike:

             pa_type = vals._pa_array.type
             if pa.types.is_timestamp(pa_type):
-                dtype = tz_to_dtype(pa_type.tz, pa_type.unit)
-                vals = vals.astype(dtype)
-                vals = cast("DatetimeArray", vals)
+                vals = vals._to_datetimearray()
                 return vals._ndarray.view("i8")
             elif pa.types.is_duration(pa_type):
-                dtype = np.dtype(f"m8[{pa_type.unit}]")
-                vals = vals.astype(dtype)
-                vals = cast("TimedeltaArray", vals)
+                vals = vals._to_timedeltaarray()
                 return vals._ndarray.view("i8")
         if (
             type(self) is Index
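Both engine paths are reached by label lookups on an Arrow-backed Index; a sketch of the kind of call that now goes through _to_datetimearray() (assuming pyarrow is installed):

    import pandas as pd
    import pyarrow as pa

    idx = pd.Index(
        pd.array(["2023-01-01", "2023-01-02"], dtype=pd.ArrowDtype(pa.timestamp("ns")))
    )
    # building the lookup engine uses _to_datetimearray() internally
    idx.get_loc(idx[1])  # 1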
18 changes: 18 additions & 0 deletions pandas/core/reshape/merge.py
@@ -2112,6 +2112,12 @@ def injection(obj):
                 raise ValueError(f"Merge keys contain null values on {side} side")
             raise ValueError(f"{side} keys must be sorted")

+        if isinstance(left_values, ArrowExtensionArray):
+            left_values = left_values._maybe_convert_datelike_array()
+
+        if isinstance(right_values, ArrowExtensionArray):
+            right_values = right_values._maybe_convert_datelike_array()
+
         # initial type conversion as needed
         if needs_i8_conversion(getattr(left_values, "dtype", None)):
             if tolerance is not None:
@@ -2132,6 +2138,18 @@ def injection(obj):
             left_values = left_values.view("i8")
             right_values = right_values.view("i8")

+        if isinstance(left_values, BaseMaskedArray):
+            # we've verified above that no nulls exist
+            left_values = left_values._data
+        elif isinstance(left_values, ExtensionArray):
+            left_values = np.array(left_values)
+
+        if isinstance(right_values, BaseMaskedArray):
+            # we've verified above that no nulls exist
+            right_values = right_values._data
+        elif isinstance(right_values, ExtensionArray):
+            right_values = np.array(right_values)
+
         # a "by" parameter requires special handling
         if self.left_by is not None:
             # remove 'on' parameter from values if one existed
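A sketch of the user-facing path these merge.py changes fix: merge_asof on pyarrow-backed timestamp keys (both frames sorted on the key, as merge_asof requires). The keys arrive here as ArrowExtensionArrays, so they are first converted via _maybe_convert_datelike_array before the i8 view.

    import pandas as pd
    import pyarrow as pa

    ts = pd.ArrowDtype(pa.timestamp("ns"))
    trades = pd.DataFrame(
        {"time": pd.to_datetime(["2023-01-01 09:01"]), "qty": [10]}
    )
    quotes = pd.DataFrame(
        {
            "time": pd.to_datetime(["2023-01-01 09:00", "2023-01-01 09:02"]),
            "bid": [100.0, 101.0],
        }
    )
    trades["time"] = trades["time"].astype(ts)
    quotes["time"] = quotes["time"].astype(ts)
    result = pd.merge_asof(trades, quotes, on="time")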
38 changes: 38 additions & 0 deletions pandas/tests/reshape/merge/test_merge_asof.py
@@ -4,6 +4,8 @@
 import pytest
 import pytz

+import pandas.util._test_decorators as td
+
 import pandas as pd
 from pandas import (
     Index,
@@ -1589,3 +1591,39 @@ def test_merge_asof_raise_for_duplicate_columns():

     with pytest.raises(ValueError, match="column label 'a'"):
         merge_asof(left, right, left_on="left_val", right_on="a")
+
+
+@pytest.mark.parametrize(
+    "dtype",
+    [
+        "Int64",
+        pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
+        pytest.param("timestamp[s][pyarrow]", marks=td.skip_if_no("pyarrow")),
+    ],
+)
+def test_merge_asof_extension_dtype(dtype):
+    # GH 52904
+    left = pd.DataFrame(
+        {
+            "join_col": [1, 3, 5],
+            "left_val": [1, 2, 3],
+        }
+    )
+    right = pd.DataFrame(
+        {
+            "join_col": [2, 3, 4],
+            "right_val": [1, 2, 3],
+        }
+    )
+    left = left.astype({"join_col": dtype})
+    right = right.astype({"join_col": dtype})
+    result = merge_asof(left, right, on="join_col")
+    expected = pd.DataFrame(
+        {
+            "join_col": [1, 3, 5],
+            "left_val": [1, 2, 3],
+            "right_val": [np.nan, 2.0, 3.0],
+        }
+    )
+    expected = expected.astype({"join_col": dtype})
+    tm.assert_frame_equal(result, expected)
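
To run just the new test locally, a command along the lines of pytest pandas/tests/reshape/merge/test_merge_asof.py -k extension_dtype should select it; the pyarrow-parametrized cases skip automatically when pyarrow is not installed.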