
ENH: Support arrow/parquet roundtrip for nullable integer / string extension dtypes #29483


Merged
1 change: 1 addition & 0 deletions doc/source/user_guide/io.rst
@@ -4717,6 +4717,7 @@ Several caveats.
* The ``pyarrow`` engine preserves the ``ordered`` flag of categorical dtypes with string types. ``fastparquet`` does not preserve the ``ordered`` flag.
* Unsupported types include ``Period`` and actual Python object types; these will raise a helpful error message
  on an attempt at serialization.
* The ``pyarrow`` engine preserves extension data types such as the nullable integer and string data types (requiring pyarrow >= 1.0.0).
Member

Does this only preserve integer and string extension types, or does it preserve all? I assume the former, but it's somewhat unclear in the note.

Member Author

It preserves all dtypes that implement the needed protocol methods (__arrow_array__, __from_arrow__), which for pandas is currently integer and string. Potentially any external EA as well, if it implements those protocols.

Tried to clarify this in the docs.
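For illustration, a minimal sketch of the two protocol hooks an extension type needs (MyArray / MyDtype are hypothetical stand-ins, not pandas classes; assuming pyarrow is installed):

import pyarrow

class MyArray:  # stand-in for a pandas ExtensionArray subclass
    def __arrow_array__(self, type=None):
        # called by pyarrow.array(...) when serializing to Arrow
        return pyarrow.array(list(self), type=type)

class MyDtype:  # stand-in for a pandas ExtensionDtype subclass
    def __from_arrow__(self, array):
        # called when converting a pyarrow.Array or pyarrow.ChunkedArray
        # back to pandas; must return an instance of the extension array
        raise NotImplementedError("sketch only")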


You can specify an ``engine`` to direct the serialization. This can be one of ``pyarrow``, ``fastparquet``, or ``auto``.
If the engine is NOT specified, then the ``pd.options.io.parquet.engine`` option is checked; if this is also ``auto``,
3 changes: 3 additions & 0 deletions doc/source/whatsnew/v1.0.0.rst
@@ -113,6 +113,9 @@ Other enhancements
- Added ``encoding`` argument to :meth:`DataFrame.to_string` for non-ascii text (:issue:`28766`)
- Added ``encoding`` argument to :func:`DataFrame.to_html` for non-ascii text (:issue:`28663`)
- :meth:`Styler.background_gradient` now accepts ``vmin`` and ``vmax`` arguments (:issue:`12145`)
- Roundtripping DataFrames with nullable integer or string data types to parquet
  (:meth:`~DataFrame.to_parquet` / :func:`read_parquet`) using the ``'pyarrow'`` engine
  now preserves those data types with pyarrow >= 1.0.0 (:issue:`20612`).

Build Changes
^^^^^^^^^^^^^
28 changes: 28 additions & 0 deletions pandas/core/arrays/integer.py
@@ -85,6 +85,34 @@ def construct_array_type(cls):
"""
return IntegerArray

def __from_arrow__(self, array):
    """Construct IntegerArray from passed pyarrow Array"""
    import pyarrow

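    # values can be a pyarrow.Array or a pyarrow.ChunkedArray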
    if isinstance(array, pyarrow.Array):
Member

When is this False?

Member Author

The passed pyarrow values can be either a pyarrow.Array or a pyarrow.ChunkedArray. I added a comment for this.

        chunks = [array]
    else:
        chunks = array.chunks

    results = []
    for arr in chunks:
        buflist = arr.buffers()
        data = np.frombuffer(buflist[1], dtype=self.type)[
            arr.offset : arr.offset + len(arr)
        ]
        bitmask = buflist[0]
        if bitmask is not None:
            mask = pyarrow.BooleanArray.from_buffers(
                pyarrow.bool_(), len(arr), [None, bitmask]
            )
            mask = np.asarray(mask)
        else:
            mask = np.ones(len(arr), dtype=bool)
        int_arr = IntegerArray(data.copy(), ~mask, copy=False)
        results.append(int_arr)

    return IntegerArray._concat_same_type(results)
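As an aside, a minimal standalone illustration of the Arrow buffer layout this method relies on (assuming pyarrow >= 0.15): buffers()[0] is the validity bitmap, or None when the array has no nulls, and buffers()[1] holds the values.

import numpy as np
import pyarrow

arr = pyarrow.array([1, None, 3], type=pyarrow.int64())
validity, values = arr.buffers()
# view the values buffer as int64 and slice to the logical length
data = np.frombuffer(values, dtype="int64")[arr.offset : arr.offset + len(arr)]
# decode the validity bitmap into a boolean array (True = valid)
mask = np.asarray(
    pyarrow.BooleanArray.from_buffers(pyarrow.bool_(), len(arr), [None, validity])
)
print(data)  # the slot behind the null holds an unspecified value
print(mask)  # [ True False  True]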


def integer_array(values, dtype=None, copy=False):
"""
16 changes: 16 additions & 0 deletions pandas/core/arrays/string_.py
@@ -85,6 +85,22 @@ def construct_array_type(cls) -> "Type[StringArray]":
def __repr__(self) -> str:
return "StringDtype"

def __from_arrow__(self, array):
    """Construct StringArray from passed pyarrow Array"""
    import pyarrow

    if isinstance(array, pyarrow.Array):
        chunks = [array]
    else:
        chunks = array.chunks

    results = []
    for arr in chunks:
        str_arr = StringArray(np.array(arr))
Member Author

@TomAugspurger Is there a way to turn off validation when creating a StringArray? (in this case I know the elements will be strings or None)
I didn't directly see a private method on StringArray that does this

Contributor

Not at the moment. In StringArray.__init__ we have

skip_validation = isinstance(values, type(self))

I'd be open to a public API for allowing the user / library to state that they have the right values.

One slight concern though: right now it's expected that the values are strings or nan, and we may change the NA value in the future. Libraries will have to adapt to that if they're claiming to be compatible.

Member Author

Ah, so the validation actually is needed here right now, even if just to convert None to nan?

Contributor

I think the None -> nan replacement happens in _from_sequence, though I have a TODO to move it to _validate since that already requires a pass over the data.

Member Author

Ah, good catch, I switched to _from_sequence for now, and added a test for it.
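To make that fix concrete, a hedged standalone sketch of constructing through _from_sequence (not necessarily the final code in the PR):

import numpy as np
import pyarrow
from pandas.arrays import StringArray

arr = pyarrow.array(["a", None, "c"])
# _from_sequence performs the None -> nan conversion that the plain
# StringArray(...) constructor does not
str_arr = StringArray._from_sequence(np.array(arr))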

        results.append(str_arr)

    return StringArray._concat_same_type(results)


class StringArray(PandasArray):
"""
14 changes: 14 additions & 0 deletions pandas/tests/arrays/string_/test_string.py
@@ -171,3 +171,17 @@ def test_arrow_array():
    arr = pa.array(data)
    expected = pa.array(list(data), type=pa.string(), from_pandas=True)
    assert arr.equals(expected)


@td.skip_if_no("pyarrow", min_version="0.15.1.dev")
def test_arrow_roundtrip():
# roundtrip possible from arrow 1.0.0
import pyarrow as pa

data = pd.array(["a", "b", "c"], dtype="string")
df = pd.DataFrame({"a": data})
table = pa.table(df)
assert table.field("a").type == "string"
result = table.to_pandas()
assert isinstance(result["a"].dtype, pd.StringDtype)
tm.assert_frame_equal(result, df)
12 changes: 12 additions & 0 deletions pandas/tests/arrays/test_integer.py
@@ -829,6 +829,18 @@ def test_arrow_array(data):
    assert arr.equals(expected)


@td.skip_if_no("pyarrow", min_version="0.15.1.dev")
def test_arrow_roundtrip(data):
# roundtrip possible from arrow 1.0.0
import pyarrow as pa

df = pd.DataFrame({"a": data})
table = pa.table(df)
assert table.field("a").type == str(data.dtype.numpy_dtype)
result = table.to_pandas()
tm.assert_frame_equal(result, df)


# TODO(jreback) - these need testing / are broken

# shift
14 changes: 10 additions & 4 deletions pandas/tests/io/test_parquet.py
@@ -514,13 +514,19 @@ def test_additional_extension_arrays(self, pa):
"b": pd.Series(["a", None, "c"], dtype="string"),
}
)
-    # currently de-serialized as plain int / object
-    expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object"))
+    if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"):
+        expected = df
+    else:
+        # de-serialized as plain int / object
+        expected = df.assign(a=df.a.astype("int64"), b=df.b.astype("object"))
    check_round_trip(df, pa, expected=expected)

    df = pd.DataFrame({"a": pd.Series([1, 2, 3, None], dtype="Int64")})
-    # if missing values in integer, currently de-serialized as float
-    expected = df.assign(a=df.a.astype("float64"))
+    if LooseVersion(pyarrow.__version__) >= LooseVersion("0.15.1.dev"):
+        expected = df
+    else:
+        # if missing values in integer, currently de-serialized as float
+        expected = df.assign(a=df.a.astype("float64"))
    check_round_trip(df, pa, expected=expected)
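For reference, the user-facing roundtrip these tests exercise (a minimal sketch assuming a recent pyarrow is installed; the file name is arbitrary):

import pandas as pd

df = pd.DataFrame(
    {
        "a": pd.array([1, 2, None], dtype="Int64"),
        "b": pd.array(["x", None, "z"], dtype="string"),
    }
)
df.to_parquet("example.parquet", engine="pyarrow")
result = pd.read_parquet("example.parquet", engine="pyarrow")
print(result.dtypes)  # a: Int64, b: string -- extension dtypes preserved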

