Skip to content

Commit 1007f66

Browse files
lukemanleyim-vinicius
authored and
im-vinicius
committed
BUG: Series.str.split(expand=True) for ArrowDtype(pa.string()) (pandas-dev#53532)
* BUG: Series.str.split(expand=True) for ArrowDtype(pa.string()) * whatsnew * min versions * ensure ArrowExtensionArray
1 parent db3f6bc commit 1007f66

File tree

3 files changed

+51
-5
lines changed

3 files changed

+51
-5
lines changed

doc/source/whatsnew/v2.0.3.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ Fixed regressions
2222
Bug fixes
2323
~~~~~~~~~
2424
- Bug in :func:`read_csv` when defining ``dtype`` with ``bool[pyarrow]`` for the ``"c"`` and ``"python"`` engines (:issue:`53390`)
25+
- Bug in :meth:`Series.str.split` and :meth:`Series.str.rsplit` with ``expand=True`` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`53532`)
26+
-
2527

2628
.. ---------------------------------------------------------------------------
2729
.. _whatsnew_203.other:

pandas/core/strings/accessor.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,14 +275,40 @@ def _wrap_result(
275275
if isinstance(result.dtype, ArrowDtype):
276276
import pyarrow as pa
277277

278+
from pandas.compat import pa_version_under11p0
279+
278280
from pandas.core.arrays.arrow.array import ArrowExtensionArray
279281

280-
max_len = pa.compute.max(
281-
result._pa_array.combine_chunks().value_lengths()
282-
).as_py()
283-
if result.isna().any():
282+
value_lengths = result._pa_array.combine_chunks().value_lengths()
283+
max_len = pa.compute.max(value_lengths).as_py()
284+
min_len = pa.compute.min(value_lengths).as_py()
285+
if result._hasna:
284286
# ArrowExtensionArray.fillna doesn't work for list scalars
285-
result._pa_array = result._pa_array.fill_null([None] * max_len)
287+
result = ArrowExtensionArray(
288+
result._pa_array.fill_null([None] * max_len)
289+
)
290+
if min_len < max_len:
291+
# append nulls to each scalar list element up to max_len
292+
if not pa_version_under11p0:
293+
result = ArrowExtensionArray(
294+
pa.compute.list_slice(
295+
result._pa_array,
296+
start=0,
297+
stop=max_len,
298+
return_fixed_size_list=True,
299+
)
300+
)
301+
else:
302+
all_null = np.full(max_len, fill_value=None, dtype=object)
303+
values = result.to_numpy()
304+
new_values = []
305+
for row in values:
306+
if len(row) < max_len:
307+
nulls = all_null[: max_len - len(row)]
308+
row = np.append(row, nulls)
309+
new_values.append(row)
310+
pa_type = result._pa_array.type
311+
result = ArrowExtensionArray(pa.array(new_values, type=pa_type))
286312
if name is not None:
287313
labels = name
288314
else:

pandas/tests/extension/test_arrow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,6 +2286,15 @@ def test_str_split():
22862286
)
22872287
tm.assert_frame_equal(result, expected)
22882288

2289+
result = ser.str.split("1", expand=True)
2290+
expected = pd.DataFrame(
2291+
{
2292+
0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
2293+
1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
2294+
}
2295+
)
2296+
tm.assert_frame_equal(result, expected)
2297+
22892298

22902299
def test_str_rsplit():
22912300
# GH 52401
@@ -2311,6 +2320,15 @@ def test_str_rsplit():
23112320
)
23122321
tm.assert_frame_equal(result, expected)
23132322

2323+
result = ser.str.rsplit("1", expand=True)
2324+
expected = pd.DataFrame(
2325+
{
2326+
0: ArrowExtensionArray(pa.array(["a", "a2cbcb", None])),
2327+
1: ArrowExtensionArray(pa.array(["cbcb", None, None])),
2328+
}
2329+
)
2330+
tm.assert_frame_equal(result, expected)
2331+
23142332

23152333
def test_str_unsupported_extract():
23162334
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))

0 commit comments

Comments
 (0)