Skip to content

Commit 4a168d0

Browse files
authored
ENH: Add CoW optimization to interpolate (#51249)
1 parent f07e98b commit 4a168d0

File tree

5 files changed

+230
-37
lines changed

5 files changed

+230
-37
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -223,6 +223,9 @@ Copy-on-Write improvements
223223
- :meth:`DataFrame.to_period` / :meth:`Series.to_period`
224224
- :meth:`DataFrame.truncate`
225225
- :meth:`DataFrame.tz_convert` / :meth:`Series.tz_localize`
226+
- :meth:`DataFrame.interpolate` / :meth:`Series.interpolate`
227+
- :meth:`DataFrame.ffill` / :meth:`Series.ffill`
228+
- :meth:`DataFrame.bfill` / :meth:`Series.bfill`
226229
- :meth:`DataFrame.infer_objects` / :meth:`Series.infer_objects`
227230
- :meth:`DataFrame.astype` / :meth:`Series.astype`
228231
- :meth:`DataFrame.convert_dtypes` / :meth:`Series.convert_dtypes`

pandas/core/internals/blocks.py

Lines changed: 44 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,10 @@ def make_block(
228228

229229
@final
230230
def make_block_same_class(
231-
self, values, placement: BlockPlacement | None = None
231+
self,
232+
values,
233+
placement: BlockPlacement | None = None,
234+
refs: BlockValuesRefs | None = None,
232235
) -> Block:
233236
"""Wrap given values in a block of same type as self."""
234237
# Pre-2.0 we called ensure_wrapped_if_datetimelike because fastparquet
@@ -237,7 +240,7 @@ def make_block_same_class(
237240
placement = self._mgr_locs
238241

239242
# We assume maybe_coerce_values has already been called
240-
return type(self)(values, placement=placement, ndim=self.ndim)
243+
return type(self)(values, placement=placement, ndim=self.ndim, refs=refs)
241244

242245
@final
243246
def __repr__(self) -> str:
@@ -421,7 +424,9 @@ def coerce_to_target_dtype(self, other) -> Block:
421424
return self.astype(new_dtype, copy=False)
422425

423426
@final
424-
def _maybe_downcast(self, blocks: list[Block], downcast=None) -> list[Block]:
427+
def _maybe_downcast(
428+
self, blocks: list[Block], downcast=None, using_cow: bool = False
429+
) -> list[Block]:
425430
if downcast is False:
426431
return blocks
427432

@@ -431,23 +436,26 @@ def _maybe_downcast(self, blocks: list[Block], downcast=None) -> list[Block]:
431436
# but ATM it breaks too much existing code.
432437
# split and convert the blocks
433438

434-
return extend_blocks([blk.convert() for blk in blocks])
439+
return extend_blocks(
440+
[blk.convert(using_cow=using_cow, copy=not using_cow) for blk in blocks]
441+
)
435442

436443
if downcast is None:
437444
return blocks
438445

439-
return extend_blocks([b._downcast_2d(downcast) for b in blocks])
446+
return extend_blocks([b._downcast_2d(downcast, using_cow) for b in blocks])
440447

441448
@final
442449
@maybe_split
443-
def _downcast_2d(self, dtype) -> list[Block]:
450+
def _downcast_2d(self, dtype, using_cow: bool = False) -> list[Block]:
444451
"""
445452
downcast specialized to 2D case post-validation.
446453
447454
Refactored to allow use of maybe_split.
448455
"""
449456
new_values = maybe_downcast_to_dtype(self.values, dtype=dtype)
450-
return [self.make_block(new_values)]
457+
refs = self.refs if using_cow and new_values is self.values else None
458+
return [self.make_block(new_values, refs=refs)]
451459

452460
def convert(
453461
self,
@@ -1209,13 +1217,16 @@ def interpolate(
12091217
limit_area: str | None = None,
12101218
fill_value: Any | None = None,
12111219
downcast: str | None = None,
1220+
using_cow: bool = False,
12121221
**kwargs,
12131222
) -> list[Block]:
12141223

12151224
inplace = validate_bool_kwarg(inplace, "inplace")
12161225

12171226
if not self._can_hold_na:
12181227
# If there are no NAs, then interpolate is a no-op
1228+
if using_cow:
1229+
return [self.copy(deep=False)]
12191230
return [self] if inplace else [self.copy()]
12201231

12211232
try:
@@ -1224,8 +1235,10 @@ def interpolate(
12241235
m = None
12251236
if m is None and self.dtype.kind != "f":
12261237
# only deal with floats
1227-
# bc we already checked that can_hold_na, we dont have int dtype here
1238+
# bc we already checked that can_hold_na, we don't have int dtype here
12281239
# test_interp_basic checks that we make a copy here
1240+
if using_cow:
1241+
return [self.copy(deep=False)]
12291242
return [self] if inplace else [self.copy()]
12301243

12311244
if self.is_object and self.ndim == 2 and self.shape[0] != 1 and axis == 0:
@@ -1244,7 +1257,15 @@ def interpolate(
12441257
**kwargs,
12451258
)
12461259

1247-
data = self.values if inplace else self.values.copy()
1260+
refs = None
1261+
if inplace:
1262+
if using_cow and self.refs.has_reference():
1263+
data = self.values.copy()
1264+
else:
1265+
data = self.values
1266+
refs = self.refs
1267+
else:
1268+
data = self.values.copy()
12481269
data = cast(np.ndarray, data) # bc overridden by ExtensionBlock
12491270

12501271
missing.interpolate_array_2d(
@@ -1259,8 +1280,8 @@ def interpolate(
12591280
**kwargs,
12601281
)
12611282

1262-
nb = self.make_block_same_class(data)
1263-
return nb._maybe_downcast([nb], downcast)
1283+
nb = self.make_block_same_class(data, refs=refs)
1284+
return nb._maybe_downcast([nb], downcast, using_cow)
12641285

12651286
def diff(self, n: int, axis: AxisInt = 1) -> list[Block]:
12661287
"""return block for the diff of the values"""
@@ -1632,6 +1653,7 @@ def interpolate(
16321653
inplace: bool = False,
16331654
limit: int | None = None,
16341655
fill_value=None,
1656+
using_cow: bool = False,
16351657
**kwargs,
16361658
):
16371659
values = self.values
@@ -2011,6 +2033,7 @@ def interpolate(
20112033
inplace: bool = False,
20122034
limit: int | None = None,
20132035
fill_value=None,
2036+
using_cow: bool = False,
20142037
**kwargs,
20152038
):
20162039
values = self.values
@@ -2020,12 +2043,20 @@ def interpolate(
20202043
# "Literal['linear']") [comparison-overlap]
20212044
if method == "linear": # type: ignore[comparison-overlap]
20222045
# TODO: GH#50950 implement for arbitrary EAs
2023-
data_out = values._ndarray if inplace else values._ndarray.copy()
2046+
refs = None
2047+
if using_cow:
2048+
if inplace and not self.refs.has_reference():
2049+
data_out = values._ndarray
2050+
refs = self.refs
2051+
else:
2052+
data_out = values._ndarray.copy()
2053+
else:
2054+
data_out = values._ndarray if inplace else values._ndarray.copy()
20242055
missing.interpolate_array_2d(
20252056
data_out, method=method, limit=limit, index=index, axis=axis
20262057
)
20272058
new_values = type(values)._simple_new(data_out, dtype=values.dtype)
2028-
return self.make_block_same_class(new_values)
2059+
return self.make_block_same_class(new_values, refs=refs)
20292060

20302061
elif values.ndim == 2 and axis == 0:
20312062
# NDArrayBackedExtensionArray.fillna assumes axis=1

pandas/core/internals/managers.py

Lines changed: 3 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -389,14 +389,9 @@ def diff(self: T, n: int, axis: AxisInt) -> T:
389389
return self.apply("diff", n=n, axis=axis)
390390

391391
def interpolate(self: T, inplace: bool, **kwargs) -> T:
392-
if inplace:
393-
# TODO(CoW) can be optimized to only copy those blocks that have refs
394-
if using_copy_on_write() and any(
395-
not self._has_no_reference_block(i) for i in range(len(self.blocks))
396-
):
397-
self = self.copy()
398-
399-
return self.apply("interpolate", inplace=inplace, **kwargs)
392+
return self.apply(
393+
"interpolate", inplace=inplace, **kwargs, using_cow=using_copy_on_write()
394+
)
400395

401396
def shift(self: T, periods: int, axis: AxisInt, fill_value) -> T:
402397
axis = self._normalize_axis(axis)
Lines changed: 164 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,164 @@
1+
import numpy as np
2+
import pytest
3+
4+
from pandas import (
5+
DataFrame,
6+
NaT,
7+
Series,
8+
Timestamp,
9+
)
10+
import pandas._testing as tm
11+
from pandas.tests.copy_view.util import get_array
12+
13+
14+
@pytest.mark.parametrize("method", ["pad", "nearest", "linear"])
15+
def test_interpolate_no_op(using_copy_on_write, method):
16+
df = DataFrame({"a": [1, 2]})
17+
df_orig = df.copy()
18+
19+
result = df.interpolate(method=method)
20+
21+
if using_copy_on_write:
22+
assert np.shares_memory(get_array(result, "a"), get_array(df, "a"))
23+
else:
24+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
25+
26+
result.iloc[0, 0] = 100
27+
28+
if using_copy_on_write:
29+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
30+
tm.assert_frame_equal(df, df_orig)
31+
32+
33+
@pytest.mark.parametrize("func", ["ffill", "bfill"])
34+
def test_interp_fill_functions(using_copy_on_write, func):
35+
# Check that these take the same code paths as interpolate
36+
df = DataFrame({"a": [1, 2]})
37+
df_orig = df.copy()
38+
39+
result = getattr(df, func)()
40+
41+
if using_copy_on_write:
42+
assert np.shares_memory(get_array(result, "a"), get_array(df, "a"))
43+
else:
44+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
45+
46+
result.iloc[0, 0] = 100
47+
48+
if using_copy_on_write:
49+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
50+
tm.assert_frame_equal(df, df_orig)
51+
52+
53+
@pytest.mark.parametrize("func", ["ffill", "bfill"])
54+
@pytest.mark.parametrize(
55+
"vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]]
56+
)
57+
def test_interpolate_triggers_copy(using_copy_on_write, vals, func):
58+
df = DataFrame({"a": vals})
59+
result = getattr(df, func)()
60+
61+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
62+
if using_copy_on_write:
63+
# Check that we don't have references when triggering a copy
64+
assert result._mgr._has_no_reference(0)
65+
66+
67+
@pytest.mark.parametrize(
68+
"vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]]
69+
)
70+
def test_interpolate_inplace_no_reference_no_copy(using_copy_on_write, vals):
71+
df = DataFrame({"a": vals})
72+
arr = get_array(df, "a")
73+
df.interpolate(method="linear", inplace=True)
74+
75+
assert np.shares_memory(arr, get_array(df, "a"))
76+
if using_copy_on_write:
77+
# Check that operating in place did not leave any references behind
78+
assert df._mgr._has_no_reference(0)
79+
80+
81+
@pytest.mark.parametrize(
82+
"vals", [[1, np.nan, 2], [Timestamp("2019-12-31"), NaT, Timestamp("2020-12-31")]]
83+
)
84+
def test_interpolate_inplace_with_refs(using_copy_on_write, vals):
85+
df = DataFrame({"a": [1, np.nan, 2]})
86+
df_orig = df.copy()
87+
arr = get_array(df, "a")
88+
view = df[:]
89+
df.interpolate(method="linear", inplace=True)
90+
91+
if using_copy_on_write:
92+
# Check that copy was triggered in interpolate and that we don't
93+
# have any references left
94+
assert not np.shares_memory(arr, get_array(df, "a"))
95+
tm.assert_frame_equal(df_orig, view)
96+
assert df._mgr._has_no_reference(0)
97+
assert view._mgr._has_no_reference(0)
98+
else:
99+
assert np.shares_memory(arr, get_array(df, "a"))
100+
101+
102+
def test_interpolate_cleaned_fill_method(using_copy_on_write):
103+
# Check that "method is set to None" case works correctly
104+
df = DataFrame({"a": ["a", np.nan, "c"], "b": 1})
105+
df_orig = df.copy()
106+
107+
result = df.interpolate(method="asfreq")
108+
109+
if using_copy_on_write:
110+
assert np.shares_memory(get_array(result, "a"), get_array(df, "a"))
111+
else:
112+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
113+
114+
result.iloc[0, 0] = Timestamp("2021-12-31")
115+
116+
if using_copy_on_write:
117+
assert not np.shares_memory(get_array(result, "a"), get_array(df, "a"))
118+
tm.assert_frame_equal(df, df_orig)
119+
120+
121+
def test_interpolate_object_convert_no_op(using_copy_on_write):
122+
df = DataFrame({"a": ["a", "b", "c"], "b": 1})
123+
arr_a = get_array(df, "a")
124+
df.interpolate(method="pad", inplace=True)
125+
126+
# Without CoW this still makes a copy (it should not); with CoW it must not
127+
if using_copy_on_write:
128+
assert df._mgr._has_no_reference(0)
129+
assert np.shares_memory(arr_a, get_array(df, "a"))
130+
131+
132+
def test_interpolate_object_convert_copies(using_copy_on_write):
133+
df = DataFrame({"a": Series([1, 2], dtype=object), "b": 1})
134+
arr_a = get_array(df, "a")
135+
df.interpolate(method="pad", inplace=True)
136+
137+
if using_copy_on_write:
138+
assert df._mgr._has_no_reference(0)
139+
assert not np.shares_memory(arr_a, get_array(df, "a"))
140+
141+
142+
def test_interpolate_downcast(using_copy_on_write):
143+
df = DataFrame({"a": [1, np.nan, 2.5], "b": 1})
144+
arr_a = get_array(df, "a")
145+
df.interpolate(method="pad", inplace=True, downcast="infer")
146+
147+
if using_copy_on_write:
148+
assert df._mgr._has_no_reference(0)
149+
assert np.shares_memory(arr_a, get_array(df, "a"))
150+
151+
152+
def test_interpolate_downcast_reference_triggers_copy(using_copy_on_write):
153+
df = DataFrame({"a": [1, np.nan, 2.5], "b": 1})
154+
df_orig = df.copy()
155+
arr_a = get_array(df, "a")
156+
view = df[:]
157+
df.interpolate(method="pad", inplace=True, downcast="infer")
158+
159+
if using_copy_on_write:
160+
assert df._mgr._has_no_reference(0)
161+
assert not np.shares_memory(arr_a, get_array(df, "a"))
162+
tm.assert_frame_equal(df_orig, view)
163+
else:
164+
tm.assert_frame_equal(df, view)

0 commit comments

Comments
 (0)