Skip to content

Commit b02ffe2

Browse files
authored
CoW: Add reference tracking to index when created from series (#51803)
1 parent 2f43b41 commit b02ffe2

13 files changed

+327
-14
lines changed

pandas/_libs/internals.pyx

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,16 @@ cdef class BlockValuesRefs:
890890
"""
891891
self.referenced_blocks.append(weakref.ref(blk))
892892

893+
def add_index_reference(self, index: object) -> None:
894+
"""Adds a new reference to our reference collection when creating an index.
895+
896+
Parameters
897+
----------
898+
index: object
899+
The index that the new reference should point to.
900+
"""
901+
self.referenced_blocks.append(weakref.ref(index))
902+
893903
def has_reference(self) -> bool:
894904
"""Checks if block has foreign references.
895905

pandas/core/frame.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5937,7 +5937,7 @@ def set_index(
59375937
names.append(None)
59385938
# from here, col can only be a column label
59395939
else:
5940-
arrays.append(frame[col]._values)
5940+
arrays.append(frame[col])
59415941
names.append(col)
59425942
if drop:
59435943
to_remove.append(col)

pandas/core/indexes/base.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,10 @@
7373
rewrite_exception,
7474
)
7575

76-
from pandas.core.dtypes.astype import astype_array
76+
from pandas.core.dtypes.astype import (
77+
astype_array,
78+
astype_is_view,
79+
)
7780
from pandas.core.dtypes.cast import (
7881
LossySetitemError,
7982
can_hold_element,
@@ -458,6 +461,8 @@ def _engine_type(
458461

459462
str = CachedAccessor("str", StringMethods)
460463

464+
_references = None
465+
461466
# --------------------------------------------------------------------
462467
# Constructors
463468

@@ -478,6 +483,10 @@ def __new__(
478483

479484
data_dtype = getattr(data, "dtype", None)
480485

486+
refs = None
487+
if not copy and isinstance(data, (ABCSeries, Index)):
488+
refs = data._references
489+
481490
# range
482491
if isinstance(data, (range, RangeIndex)):
483492
result = RangeIndex(start=data, copy=copy, name=name)
@@ -551,7 +560,7 @@ def __new__(
551560
klass = cls._dtype_to_subclass(arr.dtype)
552561

553562
arr = klass._ensure_array(arr, arr.dtype, copy=False)
554-
return klass._simple_new(arr, name)
563+
return klass._simple_new(arr, name, refs=refs)
555564

556565
@classmethod
557566
def _ensure_array(cls, data, dtype, copy: bool):
@@ -629,7 +638,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
629638
# See each method's docstring.
630639

631640
@classmethod
632-
def _simple_new(cls, values: ArrayLike, name: Hashable = None) -> Self:
641+
def _simple_new(cls, values: ArrayLike, name: Hashable = None, refs=None) -> Self:
633642
"""
634643
We require that we have a dtype compat for the values. If we are passed
635644
a non-dtype compat, then coerce using the constructor.
@@ -643,6 +652,9 @@ def _simple_new(cls, values: ArrayLike, name: Hashable = None) -> Self:
643652
result._name = name
644653
result._cache = {}
645654
result._reset_identity()
655+
result._references = refs
656+
if refs is not None:
657+
refs.add_index_reference(result)
646658

647659
return result
648660

@@ -739,13 +751,13 @@ def _shallow_copy(self, values, name: Hashable = no_default) -> Self:
739751
"""
740752
name = self._name if name is no_default else name
741753

742-
return self._simple_new(values, name=name)
754+
return self._simple_new(values, name=name, refs=self._references)
743755

744756
def _view(self) -> Self:
745757
"""
746758
fastpath to make a shallow copy, i.e. new object with same data.
747759
"""
748-
result = self._simple_new(self._values, name=self._name)
760+
result = self._simple_new(self._values, name=self._name, refs=self._references)
749761

750762
result._cache = self._cache
751763
return result
@@ -955,7 +967,7 @@ def view(self, cls=None):
955967
# of types.
956968
arr_cls = idx_cls._data_cls
957969
arr = arr_cls(self._data.view("i8"), dtype=dtype)
958-
return idx_cls._simple_new(arr, name=self.name)
970+
return idx_cls._simple_new(arr, name=self.name, refs=self._references)
959971

960972
result = self._data.view(cls)
961973
else:
@@ -1011,7 +1023,15 @@ def astype(self, dtype, copy: bool = True):
10111023
new_values = astype_array(values, dtype=dtype, copy=copy)
10121024

10131025
# pass copy=False because any copying will be done in the astype above
1014-
return Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
1026+
result = Index(new_values, name=self.name, dtype=new_values.dtype, copy=False)
1027+
if (
1028+
not copy
1029+
and self._references is not None
1030+
and astype_is_view(self.dtype, dtype)
1031+
):
1032+
result._references = self._references
1033+
result._references.add_index_reference(result)
1034+
return result
10151035

10161036
_index_shared_docs[
10171037
"take"
@@ -5183,7 +5203,7 @@ def _getitem_slice(self, slobj: slice) -> Self:
51835203
Fastpath for __getitem__ when we know we have a slice.
51845204
"""
51855205
res = self._data[slobj]
5186-
result = type(self)._simple_new(res, name=self._name)
5206+
result = type(self)._simple_new(res, name=self._name, refs=self._references)
51875207
if "_engine" in self._cache:
51885208
reverse = slobj.step is not None and slobj.step < 0
51895209
result._engine._update_from_sliced(self._engine, reverse=reverse) # type: ignore[union-attr] # noqa: E501
@@ -6707,7 +6727,11 @@ def infer_objects(self, copy: bool = True) -> Index:
67076727
)
67086728
if copy and res_values is values:
67096729
return self.copy()
6710-
return Index(res_values, name=self.name)
6730+
result = Index(res_values, name=self.name)
6731+
if not copy and res_values is values and self._references is not None:
6732+
result._references = self._references
6733+
result._references.add_index_reference(result)
6734+
return result
67116735

67126736
# --------------------------------------------------------------------
67136737
# Generated Arithmetic, Comparison, and Unary Methods

pandas/core/indexes/datetimes.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
is_datetime64tz_dtype,
3636
is_scalar,
3737
)
38+
from pandas.core.dtypes.generic import ABCSeries
3839
from pandas.core.dtypes.missing import is_valid_na_for_dtype
3940

4041
from pandas.core.arrays.datetimes import (
@@ -267,7 +268,7 @@ def strftime(self, date_format) -> Index:
267268
@doc(DatetimeArray.tz_convert)
268269
def tz_convert(self, tz) -> DatetimeIndex:
269270
arr = self._data.tz_convert(tz)
270-
return type(self)._simple_new(arr, name=self.name)
271+
return type(self)._simple_new(arr, name=self.name, refs=self._references)
271272

272273
@doc(DatetimeArray.tz_localize)
273274
def tz_localize(
@@ -346,8 +347,11 @@ def __new__(
346347
yearfirst=yearfirst,
347348
ambiguous=ambiguous,
348349
)
350+
refs = None
351+
if not copy and isinstance(data, (Index, ABCSeries)):
352+
refs = data._references
349353

350-
subarr = cls._simple_new(dtarr, name=name)
354+
subarr = cls._simple_new(dtarr, name=name, refs=refs)
351355
return subarr
352356

353357
# --------------------------------------------------------------------

pandas/core/indexes/multi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,7 @@ def __new__(
352352
result._codes = new_codes
353353

354354
result._reset_identity()
355+
result._references = None
355356

356357
return result
357358

pandas/core/indexes/period.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
from pandas.core.dtypes.common import is_integer
2828
from pandas.core.dtypes.dtypes import PeriodDtype
29+
from pandas.core.dtypes.generic import ABCSeries
2930
from pandas.core.dtypes.missing import is_valid_na_for_dtype
3031

3132
from pandas.core.arrays.period import (
@@ -221,6 +222,10 @@ def __new__(
221222
"second",
222223
}
223224

225+
refs = None
226+
if not copy and isinstance(data, (Index, ABCSeries)):
227+
refs = data._references
228+
224229
if not set(fields).issubset(valid_field_set):
225230
argument = list(set(fields) - valid_field_set)[0]
226231
raise TypeError(f"__new__() got an unexpected keyword argument {argument}")
@@ -261,7 +266,7 @@ def __new__(
261266
if copy:
262267
data = data.copy()
263268

264-
return cls._simple_new(data, name=name)
269+
return cls._simple_new(data, name=name, refs=refs)
265270

266271
# ------------------------------------------------------------------------
267272
# Data

pandas/core/indexes/range.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def _simple_new( # type: ignore[override]
177177
result._name = name
178178
result._cache = {}
179179
result._reset_identity()
180+
result._references = None
180181
return result
181182

182183
@classmethod

pandas/core/indexes/timedeltas.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
is_scalar,
1919
is_timedelta64_dtype,
2020
)
21+
from pandas.core.dtypes.generic import ABCSeries
2122

2223
from pandas.core.arrays import datetimelike as dtl
2324
from pandas.core.arrays.timedeltas import TimedeltaArray
@@ -172,7 +173,11 @@ def __new__(
172173
tdarr = TimedeltaArray._from_sequence_not_strict(
173174
data, freq=freq, unit=unit, dtype=dtype, copy=copy
174175
)
175-
return cls._simple_new(tdarr, name=name)
176+
refs = None
177+
if not copy and isinstance(data, (ABCSeries, Index)):
178+
refs = data._references
179+
180+
return cls._simple_new(tdarr, name=name, refs=refs)
176181

177182
# -------------------------------------------------------------------
178183

pandas/tests/copy_view/index/__init__.py

Whitespace-only changes.
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import pytest
2+
3+
from pandas import (
4+
DatetimeIndex,
5+
Series,
6+
Timestamp,
7+
date_range,
8+
)
9+
import pandas._testing as tm
10+
11+
12+
@pytest.mark.parametrize(
13+
"cons",
14+
[
15+
lambda x: DatetimeIndex(x),
16+
lambda x: DatetimeIndex(DatetimeIndex(x)),
17+
],
18+
)
19+
def test_datetimeindex(using_copy_on_write, cons):
20+
dt = date_range("2019-12-31", periods=3, freq="D")
21+
ser = Series(dt)
22+
idx = cons(ser)
23+
expected = idx.copy(deep=True)
24+
ser.iloc[0] = Timestamp("2020-12-31")
25+
if using_copy_on_write:
26+
tm.assert_index_equal(idx, expected)
27+
28+
29+
def test_datetimeindex_tz_convert(using_copy_on_write):
30+
dt = date_range("2019-12-31", periods=3, freq="D", tz="Europe/Berlin")
31+
ser = Series(dt)
32+
idx = DatetimeIndex(ser).tz_convert("US/Eastern")
33+
expected = idx.copy(deep=True)
34+
ser.iloc[0] = Timestamp("2020-12-31", tz="Europe/Berlin")
35+
if using_copy_on_write:
36+
tm.assert_index_equal(idx, expected)
37+
38+
39+
def test_datetimeindex_tz_localize(using_copy_on_write):
40+
dt = date_range("2019-12-31", periods=3, freq="D")
41+
ser = Series(dt)
42+
idx = DatetimeIndex(ser).tz_localize("Europe/Berlin")
43+
expected = idx.copy(deep=True)
44+
ser.iloc[0] = Timestamp("2020-12-31")
45+
if using_copy_on_write:
46+
tm.assert_index_equal(idx, expected)
47+
48+
49+
def test_datetimeindex_isocalendar(using_copy_on_write):
50+
dt = date_range("2019-12-31", periods=3, freq="D")
51+
ser = Series(dt)
52+
df = DatetimeIndex(ser).isocalendar()
53+
expected = df.index.copy(deep=True)
54+
ser.iloc[0] = Timestamp("2020-12-31")
55+
if using_copy_on_write:
56+
tm.assert_index_equal(df.index, expected)

0 commit comments

Comments
 (0)