Skip to content

Commit d6dbe6f

Browse files
authored
ENH: implement tm.shares_memory (#44747)
1 parent 4bacee5 commit d6dbe6f

File tree

11 files changed

+66
-22
lines changed

11 files changed

+66
-22
lines changed

pandas/_testing/__init__.py

+51
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,15 @@
110110
UInt64Index,
111111
)
112112
from pandas.core.arrays import (
113+
BaseMaskedArray,
113114
DatetimeArray,
115+
ExtensionArray,
114116
PandasArray,
115117
PeriodArray,
116118
TimedeltaArray,
117119
period_array,
118120
)
121+
from pandas.core.arrays._mixins import NDArrayBackedExtensionArray
119122

120123
if TYPE_CHECKING:
121124
from pandas import (
@@ -1050,3 +1053,51 @@ def at(x):
10501053

10511054
def iat(x):
10521055
return x.iat
1056+
1057+
1058+
# -----------------------------------------------------------------------------
1059+
1060+
1061+
def shares_memory(left, right) -> bool:
    """
    Pandas-compat for np.shares_memory.

    Pandas containers are unwrapped, one layer per recursive call, to the
    array(s) backing them; the comparison itself is done on those arrays.

    Parameters
    ----------
    left, right : np.ndarray, Index, Series, ExtensionArray or DataFrame
        Objects to compare. A DataFrame is only handled when its manager
        holds exactly one array.

    Returns
    -------
    bool
        Always False when ``left`` (after argument swapping) is a RangeIndex.

    Raises
    ------
    NotImplementedError
        If no branch below handles the combination of argument types.
    """
    if isinstance(left, np.ndarray) and isinstance(right, np.ndarray):
        return np.shares_memory(left, right)
    elif isinstance(left, np.ndarray):
        # Call with reversed args to get to unpacking logic below.
        return shares_memory(right, left)

    if isinstance(left, RangeIndex):
        return False
    if isinstance(left, MultiIndex):
        # _codes holds one codes array per level; passing the container
        # itself into the recursion would match no branch and raise, so
        # check each level's codes individually.
        return any(shares_memory(codes, right) for codes in left._codes)
    if isinstance(left, (Index, Series)):
        return shares_memory(left._values, right)

    if isinstance(left, NDArrayBackedExtensionArray):
        return shares_memory(left._ndarray, right)
    # Use the public pd.arrays namespace: the top-level pd.SparseArray alias
    # is deprecated and no longer exists in newer pandas versions.
    if isinstance(left, pd.arrays.SparseArray):
        return shares_memory(left.sp_values, right)

    if isinstance(left, ExtensionArray) and left.dtype == "string[pyarrow]":
        # https://github.com/pandas-dev/pandas/pull/43930#discussion_r736862669
        if isinstance(right, ExtensionArray) and right.dtype == "string[pyarrow]":
            left_pa_data = left._data
            right_pa_data = right._data
            # Only buffer 1 of the first chunk of each side is compared.
            # NOTE(review): presumably ``==`` on pyarrow buffers compares
            # contents rather than addresses -- confirm against pyarrow docs.
            left_buf1 = left_pa_data.chunk(0).buffers()[1]
            right_buf1 = right_pa_data.chunk(0).buffers()[1]
            return left_buf1 == right_buf1

    if isinstance(left, BaseMaskedArray) and isinstance(right, BaseMaskedArray):
        # By convention, we'll say these share memory if they share *either*
        #  the _data or the _mask
        return np.shares_memory(left._data, right._data) or np.shares_memory(
            left._mask, right._mask
        )

    if isinstance(left, DataFrame) and len(left._mgr.arrays) == 1:
        arr = left._mgr.arrays[0]
        return shares_memory(arr, right)

    raise NotImplementedError(type(left), type(right))

pandas/tests/arrays/categorical/test_constructors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def test_from_sequence_copy(self):
738738

739739
result = Categorical._from_sequence(cat, dtype=None, copy=True)
740740

741-
assert not np.shares_memory(result._codes, cat._codes)
741+
assert not tm.shares_memory(result, cat)
742742

743743
@pytest.mark.xfail(
744744
not IS64 or is_platform_windows(),

pandas/tests/arrays/floating/test_arithmetic.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,5 @@ def test_unary_float_operators(float_ea_dtype, source, neg_target, abs_target):
199199

200200
tm.assert_extension_array_equal(neg_result, neg_target)
201201
tm.assert_extension_array_equal(pos_result, arr)
202-
assert not np.shares_memory(pos_result._data, arr._data)
203-
assert not np.shares_memory(pos_result._mask, arr._mask)
202+
assert not tm.shares_memory(pos_result, arr)
204203
tm.assert_extension_array_equal(abs_result, abs_target)

pandas/tests/arrays/floating/test_astype.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ def test_astype_copy():
7878
# copy=True -> ensure both data and mask are actual copies
7979
result = arr.astype("Float64", copy=True)
8080
assert result is not arr
81-
assert not np.shares_memory(result._data, arr._data)
82-
assert not np.shares_memory(result._mask, arr._mask)
81+
assert not tm.shares_memory(result, arr)
8382
result[0] = 10
8483
tm.assert_extension_array_equal(arr, orig)
8584
result[0] = pd.NA
@@ -101,8 +100,7 @@ def test_astype_copy():
101100
orig = pd.array([0.1, 0.2, None], dtype="Float64")
102101

103102
result = arr.astype("Float32", copy=False)
104-
assert not np.shares_memory(result._data, arr._data)
105-
assert not np.shares_memory(result._mask, arr._mask)
103+
assert not tm.shares_memory(result, arr)
106104
result[0] = 10
107105
tm.assert_extension_array_equal(arr, orig)
108106
result[0] = pd.NA

pandas/tests/arrays/integer/test_arithmetic.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,5 @@ def test_unary_int_operators(any_signed_int_ea_dtype, source, neg_target, abs_ta
299299

300300
tm.assert_extension_array_equal(neg_result, neg_target)
301301
tm.assert_extension_array_equal(pos_result, arr)
302-
assert not np.shares_memory(pos_result._data, arr._data)
303-
assert not np.shares_memory(pos_result._mask, arr._mask)
302+
assert not tm.shares_memory(pos_result, arr)
304303
tm.assert_extension_array_equal(abs_result, abs_target)

pandas/tests/arrays/integer/test_dtypes.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,7 @@ def test_astype_copy():
153153
# copy=True -> ensure both data and mask are actual copies
154154
result = arr.astype("Int64", copy=True)
155155
assert result is not arr
156-
assert not np.shares_memory(result._data, arr._data)
157-
assert not np.shares_memory(result._mask, arr._mask)
156+
assert not tm.shares_memory(result, arr)
158157
result[0] = 10
159158
tm.assert_extension_array_equal(arr, orig)
160159
result[0] = pd.NA
@@ -176,8 +175,7 @@ def test_astype_copy():
176175
orig = pd.array([1, 2, 3, None], dtype="Int64")
177176

178177
result = arr.astype("Int32", copy=False)
179-
assert not np.shares_memory(result._data, arr._data)
180-
assert not np.shares_memory(result._mask, arr._mask)
178+
assert not tm.shares_memory(result, arr)
181179
result[0] = 10
182180
tm.assert_extension_array_equal(arr, orig)
183181
result[0] = pd.NA

pandas/tests/arrays/masked_shared.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""
22
Tests shared by MaskedArray subclasses.
33
"""
4-
import numpy as np
54

65
import pandas as pd
76
import pandas._testing as tm
@@ -56,7 +55,7 @@ class NumericOps:
5655

5756
def test_no_shared_mask(self, data):
5857
result = data + 1
59-
assert np.shares_memory(result._mask, data._mask) is False
58+
assert not tm.shares_memory(result, data)
6059

6160
def test_array(self, comparison_op, dtype):
6261
op = comparison_op

pandas/tests/arrays/test_array.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -184,15 +184,15 @@ def test_array_copy():
184184
a = np.array([1, 2])
185185
# default is to copy
186186
b = pd.array(a, dtype=a.dtype)
187-
assert np.shares_memory(a, b._ndarray) is False
187+
assert not tm.shares_memory(a, b)
188188

189189
# copy=True
190190
b = pd.array(a, dtype=a.dtype, copy=True)
191-
assert np.shares_memory(a, b._ndarray) is False
191+
assert not tm.shares_memory(a, b)
192192

193193
# copy=False
194194
b = pd.array(a, dtype=a.dtype, copy=False)
195-
assert np.shares_memory(a, b._ndarray) is True
195+
assert tm.shares_memory(a, b)
196196

197197

198198
cet = pytz.timezone("CET")

pandas/tests/arrays/test_numpy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_constructor_copy():
128128
arr = np.array([0, 1])
129129
result = PandasArray(arr, copy=True)
130130

131-
assert np.shares_memory(result._ndarray, arr) is False
131+
assert not tm.shares_memory(result, arr)
132132

133133

134134
def test_constructor_with_data(any_numpy_array):

pandas/tests/arrays/test_timedeltas.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,11 @@ def test_pos(self):
9999

100100
result = +arr
101101
tm.assert_timedelta_array_equal(result, arr)
102-
assert not np.shares_memory(result._ndarray, arr._ndarray)
102+
assert not tm.shares_memory(result, arr)
103103

104104
result2 = np.positive(arr)
105105
tm.assert_timedelta_array_equal(result2, arr)
106-
assert not np.shares_memory(result2._ndarray, arr._ndarray)
106+
assert not tm.shares_memory(result2, arr)
107107

108108
def test_neg(self):
109109
vals = np.array([-3600 * 10 ** 9, "NaT", 7200 * 10 ** 9], dtype="m8[ns]")

pandas/tests/series/indexing/test_setitem.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def _check_inplace(self, is_inplace, orig, arr, obj):
562562
if arr.dtype.kind in ["m", "M"]:
563563
# We may not have the same DTA/TDA, but will have the same
564564
# underlying data
565-
assert arr._data is obj._values._data
565+
assert arr._ndarray is obj._values._ndarray
566566
else:
567567
assert obj._values is arr
568568
else:

0 commit comments

Comments
 (0)