Skip to content

Commit 87d3fe4

Browse files
authored
ENH: Implement convert_dtypes on block level (#55341)
* ENH: Implement convert_dtypes on block level * ENH: Implement convert_dtypes on block level * Update * Fix typing * BUG: Fix convert_dtypes for all na column and arrow backend BUG: Fix convert_dtypes for all na column and arrow backend * Update cast.py * Fix * Fix * Fix typing
1 parent ac6cec3 commit 87d3fe4

File tree

6 files changed

+90
-66
lines changed

6 files changed

+90
-66
lines changed

pandas/core/dtypes/cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1133,7 +1133,7 @@ def convert_dtypes(
11331133
base_dtype = inferred_dtype
11341134
if (
11351135
base_dtype.kind == "O" # type: ignore[union-attr]
1136-
and len(input_array) > 0
1136+
and input_array.size > 0
11371137
and isna(input_array).all()
11381138
):
11391139
import pyarrow as pa

pandas/core/generic.py

Lines changed: 10 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6940,36 +6940,16 @@ def convert_dtypes(
69406940
dtype: string
69416941
"""
69426942
check_dtype_backend(dtype_backend)
6943-
if self.ndim == 1:
6944-
return self._convert_dtypes(
6945-
infer_objects,
6946-
convert_string,
6947-
convert_integer,
6948-
convert_boolean,
6949-
convert_floating,
6950-
dtype_backend=dtype_backend,
6951-
)
6952-
else:
6953-
results = [
6954-
col._convert_dtypes(
6955-
infer_objects,
6956-
convert_string,
6957-
convert_integer,
6958-
convert_boolean,
6959-
convert_floating,
6960-
dtype_backend=dtype_backend,
6961-
)
6962-
for col_name, col in self.items()
6963-
]
6964-
if len(results) > 0:
6965-
result = concat(results, axis=1, copy=False, keys=self.columns)
6966-
cons = cast(type["DataFrame"], self._constructor)
6967-
result = cons(result)
6968-
result = result.__finalize__(self, method="convert_dtypes")
6969-
# https://github.com/python/mypy/issues/8354
6970-
return cast(Self, result)
6971-
else:
6972-
return self.copy(deep=None)
6943+
new_mgr = self._mgr.convert_dtypes( # type: ignore[union-attr]
6944+
infer_objects=infer_objects,
6945+
convert_string=convert_string,
6946+
convert_integer=convert_integer,
6947+
convert_boolean=convert_boolean,
6948+
convert_floating=convert_floating,
6949+
dtype_backend=dtype_backend,
6950+
)
6951+
res = self._constructor_from_mgr(new_mgr, axes=new_mgr.axes)
6952+
return res.__finalize__(self, method="convert_dtypes")
69736953

69746954
# ----------------------------------------------------------------------
69756955
# Filling NA's

pandas/core/internals/blocks.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from pandas._typing import (
3434
ArrayLike,
3535
AxisInt,
36+
DtypeBackend,
3637
DtypeObj,
3738
F,
3839
FillnaOptions,
@@ -55,6 +56,7 @@
5556
from pandas.core.dtypes.cast import (
5657
LossySetitemError,
5758
can_hold_element,
59+
convert_dtypes,
5860
find_result_type,
5961
maybe_downcast_to_dtype,
6062
np_can_hold_element,
@@ -636,6 +638,52 @@ def convert(
636638
res_values = maybe_coerce_values(res_values)
637639
return [self.make_block(res_values, refs=refs)]
638640

641+
def convert_dtypes(
642+
self,
643+
copy: bool,
644+
using_cow: bool,
645+
infer_objects: bool = True,
646+
convert_string: bool = True,
647+
convert_integer: bool = True,
648+
convert_boolean: bool = True,
649+
convert_floating: bool = True,
650+
dtype_backend: DtypeBackend = "numpy_nullable",
651+
) -> list[Block]:
652+
if infer_objects and self.is_object:
653+
blks = self.convert(copy=False, using_cow=using_cow)
654+
else:
655+
blks = [self]
656+
657+
if not any(
658+
[convert_floating, convert_integer, convert_boolean, convert_string]
659+
):
660+
return [b.copy(deep=copy) for b in blks]
661+
662+
rbs = []
663+
for blk in blks:
664+
# Determine dtype column by column
665+
sub_blks = [blk] if blk.ndim == 1 or self.shape[0] == 1 else blk._split()
666+
dtypes = [
667+
convert_dtypes(
668+
b.values,
669+
convert_string,
670+
convert_integer,
671+
convert_boolean,
672+
convert_floating,
673+
infer_objects,
674+
dtype_backend,
675+
)
676+
for b in sub_blks
677+
]
678+
if all(dtype == self.dtype for dtype in dtypes):
679+
# Avoid block splitting if no dtype changes
680+
rbs.append(blk.copy(deep=copy))
681+
continue
682+
683+
for dtype, b in zip(dtypes, sub_blks):
684+
rbs.append(b.astype(dtype=dtype, copy=copy, squeeze=b.ndim != 1))
685+
return rbs
686+
639687
# ---------------------------------------------------------------------
640688
# Array-Like Methods
641689

@@ -651,6 +699,7 @@ def astype(
651699
copy: bool = False,
652700
errors: IgnoreRaise = "raise",
653701
using_cow: bool = False,
702+
squeeze: bool = False,
654703
) -> Block:
655704
"""
656705
Coerce to the new dtype.
@@ -665,12 +714,18 @@ def astype(
665714
- ``ignore`` : suppress exceptions. On error return original object
666715
using_cow: bool, default False
667716
Signaling if copy on write copy logic is used.
717+
squeeze : bool, default False
718+
squeeze values to ndim=1 if only one column is given
668719
669720
Returns
670721
-------
671722
Block
672723
"""
673724
values = self.values
725+
if squeeze and values.ndim == 2:
726+
if values.shape[0] != 1:
727+
raise ValueError("Can not squeeze with more than one column.")
728+
values = values[0, :] # type: ignore[call-overload]
674729

675730
new_values = astype_array_safe(values, dtype, copy=copy, errors=errors)
676731

pandas/core/internals/managers.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,16 @@ def convert(self, copy: bool | None) -> Self:
464464

465465
return self.apply("convert", copy=copy, using_cow=using_copy_on_write())
466466

467+
def convert_dtypes(self, **kwargs):
468+
if using_copy_on_write():
469+
copy = False
470+
else:
471+
copy = True
472+
473+
return self.apply(
474+
"convert_dtypes", copy=copy, using_cow=using_copy_on_write(), **kwargs
475+
)
476+
467477
def get_values_for_csv(
468478
self, *, float_format, date_format, decimal, na_rep: str = "nan", quoting=None
469479
) -> Self:

pandas/core/series.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,6 @@
6161
from pandas.core.dtypes.astype import astype_is_view
6262
from pandas.core.dtypes.cast import (
6363
LossySetitemError,
64-
convert_dtypes,
6564
maybe_box_native,
6665
maybe_cast_pointwise_result,
6766
)
@@ -167,7 +166,6 @@
167166
CorrelationMethod,
168167
DropKeep,
169168
Dtype,
170-
DtypeBackend,
171169
DtypeObj,
172170
FilePath,
173171
Frequency,
@@ -5556,39 +5554,6 @@ def between(
55565554

55575555
return lmask & rmask
55585556

5559-
# ----------------------------------------------------------------------
5560-
# Convert to types that support pd.NA
5561-
5562-
def _convert_dtypes(
5563-
self,
5564-
infer_objects: bool = True,
5565-
convert_string: bool = True,
5566-
convert_integer: bool = True,
5567-
convert_boolean: bool = True,
5568-
convert_floating: bool = True,
5569-
dtype_backend: DtypeBackend = "numpy_nullable",
5570-
) -> Series:
5571-
input_series = self
5572-
if infer_objects:
5573-
input_series = input_series.infer_objects()
5574-
if is_object_dtype(input_series.dtype):
5575-
input_series = input_series.copy(deep=None)
5576-
5577-
if convert_string or convert_integer or convert_boolean or convert_floating:
5578-
inferred_dtype = convert_dtypes(
5579-
input_series._values,
5580-
convert_string,
5581-
convert_integer,
5582-
convert_boolean,
5583-
convert_floating,
5584-
infer_objects,
5585-
dtype_backend,
5586-
)
5587-
result = input_series.astype(inferred_dtype)
5588-
else:
5589-
result = input_series.copy(deep=None)
5590-
return result
5591-
55925557
# error: Cannot determine type of 'isna'
55935558
@doc(NDFrame.isna, klass=_shared_doc_kwargs["klass"]) # type: ignore[has-type]
55945559
def isna(self) -> Series:

pandas/tests/frame/methods/test_convert_dtypes.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,3 +175,17 @@ def test_convert_dtypes_pyarrow_timestamp(self):
175175
expected = ser.astype("timestamp[ms][pyarrow]")
176176
result = expected.convert_dtypes(dtype_backend="pyarrow")
177177
tm.assert_series_equal(result, expected)
178+
179+
def test_convert_dtypes_avoid_block_splitting(self):
180+
# GH#55341
181+
df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": "a"})
182+
result = df.convert_dtypes(convert_integer=False)
183+
expected = pd.DataFrame(
184+
{
185+
"a": [1, 2, 3],
186+
"b": [4, 5, 6],
187+
"c": pd.Series(["a"] * 3, dtype="string[python]"),
188+
}
189+
)
190+
tm.assert_frame_equal(result, expected)
191+
assert result._mgr.nblocks == 2

0 commit comments

Comments
 (0)