Skip to content

ENH: Add dtype argument to StringMethods get_dummies() #59577

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 28 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e6f9527
Add prefix, prefix_sep, dummy_na, and dtype args to StringMethods get…
aaronchucarroll Aug 21, 2024
dafb61d
Fix import issue
aaronchucarroll Aug 21, 2024
bb79ef2
Fix typing of dtype
aaronchucarroll Aug 21, 2024
24be84f
Fix NaN type issue
aaronchucarroll Aug 21, 2024
09b2fad
Support categorical string backend
aaronchucarroll Aug 21, 2024
50ed90c
Fix dtype type hints
aaronchucarroll Aug 21, 2024
9e95485
Add dtype to get_dummies docstring
aaronchucarroll Aug 21, 2024
9a47768
Fix get_dummies dtype docstring
aaronchucarroll Aug 21, 2024
0c94bff
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Aug 22, 2024
9702bf7
remove changes for unnecessary args
aaronchucarroll Sep 3, 2024
8793516
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 3, 2024
bad1038
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 3, 2024
163fe09
parametrize dtype tests
aaronchucarroll Sep 5, 2024
3d75fdc
Merge branch 'stringmethods-get-dummies' of https://github.com/aaronc…
aaronchucarroll Sep 5, 2024
d68bece
support pyarrow and nullable dtypes
aaronchucarroll Sep 5, 2024
c2aa7d5
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
0fd2401
fix pyarrow import error
aaronchucarroll Sep 5, 2024
920c865
skip pyarrow tests when not present
aaronchucarroll Sep 5, 2024
800f787
split pyarrow tests
aaronchucarroll Sep 5, 2024
d8149e6
Merge branch 'main' into stringmethods-get-dummies
aaronchucarroll Sep 5, 2024
6cbc3e8
parametrize pyarrow tests
aaronchucarroll Sep 7, 2024
532e139
change var name to dummies_dtype
aaronchucarroll Sep 7, 2024
cd5c2ab
fix string issue
aaronchucarroll Sep 7, 2024
822b3f4
consolidate conditionals
aaronchucarroll Sep 7, 2024
ba05a8d
add tests for str and pyarrow strings
aaronchucarroll Sep 7, 2024
37dddb8
skip pyarrow string tests if not present
aaronchucarroll Sep 7, 2024
6fbe183
add info to whatsnew doc
aaronchucarroll Sep 9, 2024
87a1ee8
change func to meth in doc info
aaronchucarroll Sep 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2539,20 +2539,39 @@ def _str_findall(self, pat: str, flags: int = 0) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(
self, sep: str = "|", dummy_na: bool = False, dtype: NpDtype | None = None
):
if dtype is None:
dtype = np.bool_
split = pc.split_pattern(self._pa_array, sep)
flattened_values = pc.list_flatten(split)
if dummy_na:
nan_mask = self._pa_array.is_null()
flattened_values = flattened_values.fill_null(pa.NA)
uniques = flattened_values.unique()
uniques_sorted = uniques.take(pa.compute.array_sort_indices(uniques))
if dummy_na:
if "__nan__" not in uniques_sorted.to_pylist():
uniques_sorted = pa.concat_arrays(
[uniques_sorted, pa.array(["__nan__"], type=uniques_sorted.type)]
)
lengths = pc.list_value_length(split).fill_null(0).to_numpy()
n_rows = len(self)
n_cols = len(uniques)
indices = pc.index_in(flattened_values, uniques_sorted).to_numpy()
indices = indices + np.arange(n_rows).repeat(lengths) * n_cols
dummies = np.zeros(n_rows * n_cols, dtype=np.bool_)
dummies = np.zeros(n_rows * n_cols, dtype=dtype)
dummies[indices] = True
dummies = dummies.reshape((n_rows, n_cols))
if dummy_na:
nan_column = nan_mask.to_numpy().reshape(-1, 1)
dummies = np.hstack([dummies, nan_column])
result = type(self)(pa.array(list(dummies)))
if dummy_na:
uniques_sorted = pa.array(
["NaN" if x == "__nan__" else x for x in uniques_sorted.to_pylist()]
)
return result, uniques_sorted.to_pylist()

def _str_index(self, sub: str, start: int = 0, end: int | None = None) -> Self:
Expand Down
8 changes: 6 additions & 2 deletions pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2686,11 +2686,15 @@ def _str_map(
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(
self, sep: str = "|", dummy_na: bool = False, dtype: NpDtype | None = None
):
# sep may not be in categories. Just bail on this.
from pandas.core.arrays import NumpyExtensionArray

return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
return NumpyExtensionArray(self.astype(str))._str_get_dummies(
sep, dummy_na, dtype
)

# ------------------------------------------------------------------------
# GroupBy Methods
Expand Down
15 changes: 11 additions & 4 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
Expand Down Expand Up @@ -488,12 +489,18 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None):
return super()._str_find(sub, start, end)
return self._convert_int_dtype(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
def _str_get_dummies(
self, sep: str = "|", dummy_na: bool = False, dtype: NpDtype | None = None
):
if dtype is None:
dtype = np.int64
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(
sep, dummy_na, dtype
)
if len(labels) == 0:
return np.empty(shape=(0, 0), dtype=np.int64), labels
return np.empty(shape=(0, 0), dtype=dtype), labels
dummies = np.vstack(dummies_pa.to_numpy())
return dummies.astype(np.int64, copy=False), labels
return dummies.astype(dtype, copy=False), labels

def _convert_int_dtype(self, result):
if self.dtype.na_value is np.nan:
Expand Down
73 changes: 71 additions & 2 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,12 @@
from collections.abc import (
Callable,
Hashable,
Iterable,
Iterator,
)

from pandas._typing import NpDtype

from pandas import (
DataFrame,
Index,
Expand Down Expand Up @@ -2398,7 +2401,14 @@ def wrap(
return self._wrap_result(result)

@forbid_nonstring_types(["bytes"])
def get_dummies(self, sep: str = "|"):
def get_dummies(
self,
sep: str = "|",
prefix: str | Iterable[str] | dict[str, str] | None = None,
prefix_sep: str = "_",
dummy_na: bool = False,
dtype: NpDtype | None = None,
):
"""
Return DataFrame of dummy/indicator variables for Series.

Expand All @@ -2409,6 +2419,17 @@ def get_dummies(self, sep: str = "|"):
----------
sep : str, default "|"
String to split on.
prefix : str, list of str, or dict of str, default None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't users just prefix the columns by calling .rename after the call to get_dummies?

String to prepend to the resulting column names.
Pass a list with length equal to the number of dummy columns
to give each column its own prefix. Alternatively, `prefix`
can be a dictionary mapping column names to prefixes.
prefix_sep : str, default '_'
If appending prefix, separator/delimiter to use.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't users just add the separator to their prefix, e.g. prefix="prefix_"?

dummy_na : bool, default False
Add a column to indicate NaNs; if False, NaNs are ignored.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me users can already do this in a straightforward manner:

pd.concat([ser.str.get_dummies(), ser.isna().rename("NaN")], axis=1)

Is this not sufficient?

dtype : dtype, default np.int64
Data type for new columns. Only a single dtype is allowed.

Returns
-------
Expand All @@ -2433,10 +2454,58 @@ def get_dummies(self, sep: str = "|"):
0 1 1 0
1 0 0 0
2 1 0 1

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dummy_na=True)
a b c NaN
0 1 1 0 0
1 0 0 0 1
2 1 0 1 0

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(prefix="prefix")
prefix_a prefix_b prefix_c
0 1 1 0
1 0 0 0
2 1 0 1

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(
... prefix={"a": "alpha", "b": "beta", "c": "gamma"}
... )
alpha_a beta_b gamma_c
0 1 1 0
1 0 0 0
2 1 0 1

>>> pd.Series(["a|b", np.nan, "a|c"]).str.get_dummies(dtype=bool)
a b c
0 True True False
1 False False False
2 True False True
"""
# we need to cast to Series of strings as only that has all
# methods available for making the dummies...
result, name = self._data.array._str_get_dummies(sep)
result, name = self._data.array._str_get_dummies(sep, dummy_na, dtype)
name = [np.nan if x == "NaN" else x for x in name]
if isinstance(prefix, str):
name = [f"{prefix}{prefix_sep}{col}" for col in name]
elif isinstance(prefix, dict):
if len(prefix) != len(name):
len_msg = (
f"Length of 'prefix' ({len(prefix)}) did not match the "
"length of the columns being encoded "
f"({len(name)})."
)
raise ValueError(len_msg)
name = [f"{prefix[col]}{prefix_sep}{col}" for col in name]
elif isinstance(prefix, list):
if len(prefix) != len(name):
len_msg = (
f"Length of 'prefix' ({len(prefix)}) did not match the "
"length of the columns being encoded "
f"({len(name)})."
)
raise ValueError(len_msg)
name = [f"{prefix[i]}{prefix_sep}{col}" for i, col in enumerate(name)]

return self._wrap_result(
result,
name=name,
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import re

from pandas._typing import (
NpDtype,
Scalar,
Self,
)
Expand Down Expand Up @@ -163,7 +164,9 @@ def _str_wrap(self, width: int, **kwargs):
pass

@abc.abstractmethod
def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(
self, sep: str = "|", dummy_na: bool = False, dtype: NpDtype | None = None
):
pass

@abc.abstractmethod
Expand Down
12 changes: 10 additions & 2 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,13 @@ def _str_wrap(self, width: int, **kwargs):
tw = textwrap.TextWrapper(**kwargs)
return self._str_map(lambda s: "\n".join(tw.wrap(s)))

def _str_get_dummies(self, sep: str = "|"):
def _str_get_dummies(
self, sep: str = "|", dummy_na: bool = False, dtype: NpDtype | None = None
):
from pandas import Series

if dtype is None:
dtype = np.int64
arr = Series(self).fillna("")
try:
arr = sep + arr + sep
Expand All @@ -386,7 +390,7 @@ def _str_get_dummies(self, sep: str = "|"):
tags.update(ts)
tags2 = sorted(tags - {""})

dummies = np.empty((len(arr), len(tags2)), dtype=np.int64)
dummies = np.empty((len(arr), len(tags2)), dtype=dtype)

def _isin(test_elements: str, element: str) -> bool:
return element in test_elements
Expand All @@ -396,6 +400,10 @@ def _isin(test_elements: str, element: str) -> bool:
dummies[:, i] = lib.map_infer(
arr.to_numpy(), functools.partial(_isin, element=pat)
)
if dummy_na:
nan_col = Series(self).isna().astype(dtype).to_numpy()
dummies = np.column_stack((dummies, nan_col))
tags2.append("NaN")
return dummies, tags2

def _str_upper(self):
Expand Down
56 changes: 56 additions & 0 deletions pandas/tests/strings/test_get_dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,59 @@ def test_get_dummies_with_name_dummy_index():
[(1, 1, 0, 0), (0, 0, 1, 1), (0, 1, 0, 1)], names=("a", "b", "c", "name")
)
tm.assert_index_equal(result, expected)


def test_get_dummies_with_prefix(any_string_dtype):
    """A single string ``prefix`` is expected on every dummy column label."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)

    # The row holding the missing value is expected to be all zeros.
    indicator_rows = [[1, 1, 0], [1, 0, 1], [0, 0, 0]]
    column_labels = ["prefix_a", "prefix_b", "prefix_c"]
    expected = DataFrame(indicator_rows, columns=column_labels)

    actual = ser.str.get_dummies(sep="|", prefix="prefix")

    tm.assert_frame_equal(actual, expected)


def test_get_dummies_with_prefix_sep(any_string_dtype):
    """``prefix_sep`` is expected to matter only when a ``prefix`` is given."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)

    # No prefix: the custom separator should leave the labels untouched.
    unprefixed = ser.str.get_dummies(sep="|", prefix=None, prefix_sep="__")
    expected_unprefixed = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["a", "b", "c"],
    )
    tm.assert_frame_equal(unprefixed, expected_unprefixed)

    # With a prefix: the custom separator sits between prefix and label.
    prefixed = ser.str.get_dummies(sep="|", prefix="col", prefix_sep="__")
    expected_prefixed = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["col__a", "col__b", "col__c"],
    )
    tm.assert_frame_equal(prefixed, expected_prefixed)


def test_get_dummies_with_dummy_na(any_string_dtype):
    """``dummy_na=True`` is expected to add a trailing column labelled NaN."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)

    # The last column is the missing-value indicator: set only for the NaN row.
    expected = DataFrame(
        [[1, 1, 0, 0], [1, 0, 1, 0], [0, 0, 0, 1]],
        columns=["a", "b", "c", np.nan],
    )

    actual = ser.str.get_dummies(sep="|", dummy_na=True)

    tm.assert_frame_equal(actual, expected)


def test_get_dummies_with_dtype(any_string_dtype):
    """``dtype=bool`` is expected to yield boolean indicator columns."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    dummies = ser.str.get_dummies(sep="|", dtype=bool)

    expected = DataFrame(
        [[True, True, False], [True, False, True], [False, False, False]],
        columns=["a", "b", "c"],
    )
    tm.assert_frame_equal(dummies, expected)

    # Check the per-column dtypes explicitly as well as the frame comparison.
    assert (dummies.dtypes == bool).all()


def test_get_dummies_with_prefix_dict(any_string_dtype):
    """A dict ``prefix`` is expected to give each dummy label its own prefix."""
    ser = Series(["a|b", "a|c", np.nan], dtype=any_string_dtype)
    label_to_prefix = {"a": "alpha", "b": "beta", "c": "gamma"}

    actual = ser.str.get_dummies(sep="|", prefix=label_to_prefix)

    expected = DataFrame(
        [[1, 1, 0], [1, 0, 1], [0, 0, 0]],
        columns=["alpha_a", "beta_b", "gamma_c"],
    )
    tm.assert_frame_equal(actual, expected)