Skip to content

Commit 6d6dc1a

Browse files
authored
CoW: Ignore copy=True when copy_on_write is enabled (#51464)
* CoW: Ignore copy=True when copy_on_write is enabled
* Update
* Add concat and merge
* Add align
* Fix tests
* Fix ci
* Add transpose test
* Fix array manager
1 parent bdbdab1 commit 6d6dc1a

File tree

13 files changed

+192
-45
lines changed

13 files changed

+192
-45
lines changed

pandas/core/frame.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -11491,7 +11491,7 @@ def to_timestamp(
1149111491
>>> df2.index
1149211492
DatetimeIndex(['2023-01-31', '2024-01-31'], dtype='datetime64[ns]', freq=None)
1149311493
"""
11494-
new_obj = self.copy(deep=copy)
11494+
new_obj = self.copy(deep=copy and not using_copy_on_write())
1149511495

1149611496
axis_name = self._get_axis_name(axis)
1149711497
old_ax = getattr(self, axis_name)
@@ -11548,7 +11548,7 @@ def to_period(
1154811548
>>> idx.to_period("Y")
1154911549
PeriodIndex(['2001', '2002', '2003'], dtype='period[A-DEC]')
1155011550
"""
11551-
new_obj = self.copy(deep=copy)
11551+
new_obj = self.copy(deep=copy and not using_copy_on_write())
1155211552

1155311553
axis_name = self._get_axis_name(axis)
1155411554
old_ax = getattr(self, axis_name)

pandas/core/generic.py

+23-10
Original file line number | Diff line number | Diff line change
@@ -441,7 +441,7 @@ def set_flags(
441441
>>> df2.flags.allows_duplicate_labels
442442
False
443443
"""
444-
df = self.copy(deep=copy)
444+
df = self.copy(deep=copy and not using_copy_on_write())
445445
if allows_duplicate_labels is not None:
446446
df.flags["allows_duplicate_labels"] = allows_duplicate_labels
447447
return df
@@ -712,7 +712,7 @@ def _set_axis_nocheck(
712712
else:
713713
# With copy=False, we create a new object but don't copy the
714714
# underlying data.
715-
obj = self.copy(deep=copy)
715+
obj = self.copy(deep=copy and not using_copy_on_write())
716716
setattr(obj, obj._get_axis_name(axis), labels)
717717
return obj
718718

@@ -741,7 +741,7 @@ def swapaxes(
741741
j = self._get_axis_number(axis2)
742742

743743
if i == j:
744-
return self.copy(deep=copy)
744+
return self.copy(deep=copy and not using_copy_on_write())
745745

746746
mapping = {i: j, j: i}
747747

@@ -998,7 +998,7 @@ def _rename(
998998
index = mapper
999999

10001000
self._check_inplace_and_allows_duplicate_labels(inplace)
1001-
result = self if inplace else self.copy(deep=copy)
1001+
result = self if inplace else self.copy(deep=copy and not using_copy_on_write())
10021002

10031003
for axis_no, replacements in enumerate((index, columns)):
10041004
if replacements is None:
@@ -1214,6 +1214,9 @@ class name
12141214

12151215
inplace = validate_bool_kwarg(inplace, "inplace")
12161216

1217+
if copy and using_copy_on_write():
1218+
copy = False
1219+
12171220
if mapper is not lib.no_default:
12181221
# Use v0.23 behavior if a scalar or list
12191222
non_mapper = is_scalar(mapper) or (
@@ -5330,6 +5333,8 @@ def reindex(
53305333

53315334
# if all axes that are requested to reindex are equal, then only copy
53325335
# if indicated must have index names equal here as well as values
5336+
if copy and using_copy_on_write():
5337+
copy = False
53335338
if all(
53345339
self._get_axis(axis_name).identical(ax)
53355340
for axis_name, ax in axes.items()
@@ -5424,10 +5429,14 @@ def _reindex_with_indexers(
54245429
# If we've made a copy once, no need to make another one
54255430
copy = False
54265431

5427-
if (copy or copy is None) and new_data is self._mgr:
5432+
if (
5433+
(copy or copy is None)
5434+
and new_data is self._mgr
5435+
and not using_copy_on_write()
5436+
):
54285437
new_data = new_data.copy(deep=copy)
54295438
elif using_copy_on_write() and new_data is self._mgr:
5430-
new_data = new_data.copy(deep=copy)
5439+
new_data = new_data.copy(deep=False)
54315440

54325441
return self._constructor(new_data).__finalize__(self)
54335442

@@ -6292,6 +6301,9 @@ def astype(
62926301
2 2020-01-03
62936302
dtype: datetime64[ns]
62946303
"""
6304+
if copy and using_copy_on_write():
6305+
copy = False
6306+
62956307
if is_dict_like(dtype):
62966308
if self.ndim == 1: # i.e. Series
62976309
if len(dtype) > 1 or self.name not in dtype:
@@ -9550,6 +9562,8 @@ def _align_series(
95509562
fill_axis: Axis = 0,
95519563
):
95529564
is_series = isinstance(self, ABCSeries)
9565+
if copy and using_copy_on_write():
9566+
copy = False
95539567

95549568
if (not is_series and axis is None) or axis not in [None, 0, 1]:
95559569
raise ValueError("Must specify axis=0 or 1")
@@ -10318,8 +10332,7 @@ def truncate(
1031810332
if isinstance(ax, MultiIndex):
1031910333
setattr(result, self._get_axis_name(axis), ax.truncate(before, after))
1032010334

10321-
if copy or (copy is None and not using_copy_on_write()):
10322-
result = result.copy(deep=copy)
10335+
result = result.copy(deep=copy and not using_copy_on_write())
1032310336

1032410337
return result
1032510338

@@ -10400,7 +10413,7 @@ def _tz_convert(ax, tz):
1040010413
raise ValueError(f"The level {level} is not valid")
1040110414
ax = _tz_convert(ax, tz)
1040210415

10403-
result = self.copy(deep=copy)
10416+
result = self.copy(deep=copy and not using_copy_on_write())
1040410417
result = result.set_axis(ax, axis=axis, copy=False)
1040510418
return result.__finalize__(self, method="tz_convert")
1040610419

@@ -10582,7 +10595,7 @@ def _tz_localize(ax, tz, ambiguous, nonexistent):
1058210595
raise ValueError(f"The level {level} is not valid")
1058310596
ax = _tz_localize(ax, tz, ambiguous, nonexistent)
1058410597

10585-
result = self.copy(deep=copy)
10598+
result = self.copy(deep=copy and not using_copy_on_write())
1058610599
result = result.set_axis(ax, axis=axis, copy=False)
1058710600
return result.__finalize__(self, method="tz_localize")
1058810601

pandas/core/internals/managers.py

+4
Original file line number | Diff line number | Diff line change
@@ -444,6 +444,8 @@ def astype(self: T, dtype, copy: bool | None = False, errors: str = "raise") ->
444444
copy = False
445445
else:
446446
copy = True
447+
elif using_copy_on_write():
448+
copy = False
447449

448450
return self.apply(
449451
"astype",
@@ -459,6 +461,8 @@ def convert(self: T, copy: bool | None) -> T:
459461
copy = False
460462
else:
461463
copy = True
464+
elif using_copy_on_write():
465+
copy = False
462466

463467
return self.apply("convert", copy=copy, using_cow=using_copy_on_write())
464468

pandas/core/reshape/concat.py

+2
Original file line number | Diff line number | Diff line change
@@ -367,6 +367,8 @@ def concat(
367367
copy = False
368368
else:
369369
copy = True
370+
elif copy and using_copy_on_write():
371+
copy = False
370372

371373
op = _Concatenator(
372374
objs,

pandas/core/series.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -4026,7 +4026,7 @@ def swaplevel(
40264026
{examples}
40274027
"""
40284028
assert isinstance(self.index, MultiIndex)
4029-
result = self.copy(deep=copy)
4029+
result = self.copy(deep=copy and not using_copy_on_write())
40304030
result.index = self.index.swaplevel(i, j)
40314031
return result
40324032

@@ -5640,7 +5640,7 @@ def to_timestamp(
56405640
if not isinstance(self.index, PeriodIndex):
56415641
raise TypeError(f"unsupported Type {type(self.index).__name__}")
56425642

5643-
new_obj = self.copy(deep=copy)
5643+
new_obj = self.copy(deep=copy and not using_copy_on_write())
56445644
new_index = self.index.to_timestamp(freq=freq, how=how)
56455645
setattr(new_obj, "index", new_index)
56465646
return new_obj
@@ -5680,7 +5680,7 @@ def to_period(self, freq: str | None = None, copy: bool | None = None) -> Series
56805680
if not isinstance(self.index, DatetimeIndex):
56815681
raise TypeError(f"unsupported Type {type(self.index).__name__}")
56825682

5683-
new_obj = self.copy(deep=copy)
5683+
new_obj = self.copy(deep=copy and not using_copy_on_write())
56845684
new_index = self.index.to_period(freq=freq)
56855685
setattr(new_obj, "index", new_index)
56865686
return new_obj

pandas/tests/copy_view/test_functions.py

+30
Original file line number | Diff line number | Diff line change
@@ -181,6 +181,21 @@ def test_concat_mixed_series_frame(using_copy_on_write):
181181
tm.assert_frame_equal(result, expected)
182182

183183

184+
@pytest.mark.parametrize("copy", [True, None, False])
185+
def test_concat_copy_keyword(using_copy_on_write, copy):
186+
df = DataFrame({"a": [1, 2]})
187+
df2 = DataFrame({"b": [1.5, 2.5]})
188+
189+
result = concat([df, df2], axis=1, copy=copy)
190+
191+
if using_copy_on_write or copy is False:
192+
assert np.shares_memory(get_array(df, "a"), get_array(result, "a"))
193+
assert np.shares_memory(get_array(df2, "b"), get_array(result, "b"))
194+
else:
195+
assert not np.shares_memory(get_array(df, "a"), get_array(result, "a"))
196+
assert not np.shares_memory(get_array(df2, "b"), get_array(result, "b"))
197+
198+
184199
@pytest.mark.parametrize(
185200
"func",
186201
[
@@ -280,3 +295,18 @@ def test_merge_on_key_enlarging_one(using_copy_on_write, func, how):
280295
assert not np.shares_memory(get_array(result, "a"), get_array(df1, "a"))
281296
tm.assert_frame_equal(df1, df1_orig)
282297
tm.assert_frame_equal(df2, df2_orig)
298+
299+
300+
@pytest.mark.parametrize("copy", [True, None, False])
301+
def test_merge_copy_keyword(using_copy_on_write, copy):
302+
df = DataFrame({"a": [1, 2]})
303+
df2 = DataFrame({"b": [3, 4.5]})
304+
305+
result = df.merge(df2, copy=copy, left_index=True, right_index=True)
306+
307+
if using_copy_on_write or copy is False:
308+
assert np.shares_memory(get_array(df, "a"), get_array(result, "a"))
309+
assert np.shares_memory(get_array(df2, "b"), get_array(result, "b"))
310+
else:
311+
assert not np.shares_memory(get_array(df, "a"), get_array(result, "a"))
312+
assert not np.shares_memory(get_array(df2, "b"), get_array(result, "b"))

pandas/tests/copy_view/test_methods.py

+83-8
Original file line number | Diff line number | Diff line change
@@ -66,6 +66,7 @@ def test_copy_shallow(using_copy_on_write):
6666
lambda df, copy: df.rename(columns=str.lower, copy=copy),
6767
lambda df, copy: df.reindex(columns=["a", "c"], copy=copy),
6868
lambda df, copy: df.reindex_like(df, copy=copy),
69+
lambda df, copy: df.align(df, copy=copy)[0],
6970
lambda df, copy: df.set_axis(["a", "b", "c"], axis="index", copy=copy),
7071
lambda df, copy: df.rename_axis(index="test", copy=copy),
7172
lambda df, copy: df.rename_axis(columns="test", copy=copy),
@@ -84,6 +85,7 @@ def test_copy_shallow(using_copy_on_write):
8485
"rename",
8586
"reindex",
8687
"reindex_like",
88+
"align",
8789
"set_axis",
8890
"rename_axis0",
8991
"rename_axis1",
@@ -115,22 +117,96 @@ def test_methods_copy_keyword(
115117
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": [0.1, 0.2, 0.3]}, index=index)
116118
df2 = method(df, copy=copy)
117119

118-
share_memory = (using_copy_on_write and copy is not True) or copy is False
120+
share_memory = using_copy_on_write or copy is False
119121

120122
if request.node.callspec.id.startswith("reindex-"):
121123
# TODO copy=False without CoW still returns a copy in this case
122124
if not using_copy_on_write and not using_array_manager and copy is False:
123125
share_memory = False
124-
# TODO copy=True with CoW still returns a view
125-
if using_copy_on_write:
126-
share_memory = True
127126

128127
if share_memory:
129128
assert np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
130129
else:
131130
assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
132131

133132

133+
@pytest.mark.parametrize("copy", [True, None, False])
134+
@pytest.mark.parametrize(
135+
"method",
136+
[
137+
lambda ser, copy: ser.rename(index={0: 100}, copy=copy),
138+
lambda ser, copy: ser.reindex(index=ser.index, copy=copy),
139+
lambda ser, copy: ser.reindex_like(ser, copy=copy),
140+
lambda ser, copy: ser.align(ser, copy=copy)[0],
141+
lambda ser, copy: ser.set_axis(["a", "b", "c"], axis="index", copy=copy),
142+
lambda ser, copy: ser.rename_axis(index="test", copy=copy),
143+
lambda ser, copy: ser.astype("int64", copy=copy),
144+
lambda ser, copy: ser.swaplevel(0, 1, copy=copy),
145+
lambda ser, copy: ser.swapaxes(0, 0, copy=copy),
146+
lambda ser, copy: ser.truncate(0, 5, copy=copy),
147+
lambda ser, copy: ser.infer_objects(copy=copy),
148+
lambda ser, copy: ser.to_timestamp(copy=copy),
149+
lambda ser, copy: ser.to_period(freq="D", copy=copy),
150+
lambda ser, copy: ser.tz_localize("US/Central", copy=copy),
151+
lambda ser, copy: ser.tz_convert("US/Central", copy=copy),
152+
lambda ser, copy: ser.set_flags(allows_duplicate_labels=False, copy=copy),
153+
],
154+
ids=[
155+
"rename",
156+
"reindex",
157+
"reindex_like",
158+
"align",
159+
"set_axis",
160+
"rename_axis0",
161+
"astype",
162+
"swaplevel",
163+
"swapaxes",
164+
"truncate",
165+
"infer_objects",
166+
"to_timestamp",
167+
"to_period",
168+
"tz_localize",
169+
"tz_convert",
170+
"set_flags",
171+
],
172+
)
173+
def test_methods_series_copy_keyword(request, method, copy, using_copy_on_write):
174+
index = None
175+
if "to_timestamp" in request.node.callspec.id:
176+
index = period_range("2012-01-01", freq="D", periods=3)
177+
elif "to_period" in request.node.callspec.id:
178+
index = date_range("2012-01-01", freq="D", periods=3)
179+
elif "tz_localize" in request.node.callspec.id:
180+
index = date_range("2012-01-01", freq="D", periods=3)
181+
elif "tz_convert" in request.node.callspec.id:
182+
index = date_range("2012-01-01", freq="D", periods=3, tz="Europe/Brussels")
183+
elif "swaplevel" in request.node.callspec.id:
184+
index = MultiIndex.from_arrays([[1, 2, 3], [4, 5, 6]])
185+
186+
ser = Series([1, 2, 3], index=index)
187+
ser2 = method(ser, copy=copy)
188+
189+
share_memory = using_copy_on_write or copy is False
190+
191+
if share_memory:
192+
assert np.shares_memory(get_array(ser2), get_array(ser))
193+
else:
194+
assert not np.shares_memory(get_array(ser2), get_array(ser))
195+
196+
197+
@pytest.mark.parametrize("copy", [True, None, False])
198+
def test_transpose_copy_keyword(using_copy_on_write, copy, using_array_manager):
199+
df = DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
200+
result = df.transpose(copy=copy)
201+
share_memory = using_copy_on_write or copy is False or copy is None
202+
share_memory = share_memory and not using_array_manager
203+
204+
if share_memory:
205+
assert np.shares_memory(get_array(df, "a"), get_array(result, 0))
206+
else:
207+
assert not np.shares_memory(get_array(df, "a"), get_array(result, 0))
208+
209+
134210
# -----------------------------------------------------------------------------
135211
# DataFrame methods returning new DataFrame using shallow copy
136212

@@ -1119,14 +1195,13 @@ def test_set_flags(using_copy_on_write):
11191195
tm.assert_series_equal(ser, expected)
11201196

11211197

1122-
@pytest.mark.parametrize("copy_kwargs", [{"copy": True}, {}])
11231198
@pytest.mark.parametrize("kwargs", [{"mapper": "test"}, {"index": "test"}])
1124-
def test_rename_axis(using_copy_on_write, kwargs, copy_kwargs):
1199+
def test_rename_axis(using_copy_on_write, kwargs):
11251200
df = DataFrame({"a": [1, 2, 3, 4]}, index=Index([1, 2, 3, 4], name="a"))
11261201
df_orig = df.copy()
1127-
df2 = df.rename_axis(**kwargs, **copy_kwargs)
1202+
df2 = df.rename_axis(**kwargs)
11281203

1129-
if using_copy_on_write and not copy_kwargs:
1204+
if using_copy_on_write:
11301205
assert np.shares_memory(get_array(df2, "a"), get_array(df, "a"))
11311206
else:
11321207
assert not np.shares_memory(get_array(df2, "a"), get_array(df, "a"))

pandas/tests/frame/methods/test_reindex.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -149,7 +149,10 @@ def test_reindex_copies_ea(self, using_copy_on_write):
149149

150150
# pass both columns and index
151151
result2 = df.reindex(columns=cols, index=df.index, copy=True)
152-
assert not np.shares_memory(result2[0].array._data, df[0].array._data)
152+
if using_copy_on_write:
153+
assert np.shares_memory(result2[0].array._data, df[0].array._data)
154+
else:
155+
assert not np.shares_memory(result2[0].array._data, df[0].array._data)
153156

154157
@td.skip_array_manager_not_yet_implemented
155158
def test_reindex_date_fill_value(self):

0 commit comments

Comments
 (0)