Skip to content

Commit 31ed4c9

Browse files
committed
ENH/API: ExtensionArray.factorize
Adds factorize to the interface for ExtensionArray, with a default implementation. This is a stepping stone to groupby.
1 parent 38afa93 commit 31ed4c9

File tree

10 files changed

+144
-6
lines changed

10 files changed

+144
-6
lines changed

pandas/core/algorithms.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -146,10 +146,9 @@ def _reconstruct_data(values, dtype, original):
146146
Returns
147147
-------
148148
Index for extension types, otherwise ndarray cast to dtype
149-
150149
"""
151150
from pandas import Index
152-
if is_categorical_dtype(dtype):
151+
if is_extension_array_dtype(dtype):
153152
pass
154153
elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype):
155154
values = Index(original)._shallow_copy(values, name=None)
@@ -502,9 +501,9 @@ def factorize(values, sort=False, order=None, na_sentinel=-1, size_hint=None):
502501
values = _ensure_arraylike(values)
503502
original = values
504503

505-
if is_categorical_dtype(values):
504+
if is_extension_array_dtype(values):
506505
values = getattr(values, '_values', values)
507-
labels, uniques = values.factorize()
506+
labels, uniques = values.factorize(na_sentinel=na_sentinel)
508507
dtype = original.dtype
509508
else:
510509
values, dtype, _ = _ensure_data(values)

pandas/core/arrays/base.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,41 @@ def unique(self):
248248
uniques = unique(self.astype(object))
249249
return self._constructor_from_sequence(uniques)
250250

251+
def factorize(self, na_sentinel=-1):
252+
"""Encode the extension array as an enumerated type.
253+
254+
Parameters
255+
----------
256+
na_sentinel : int, default -1
257+
Value to use in the `labels` array to indicate missing values.
258+
259+
Returns
260+
-------
261+
labels : ndarray
262+
An integer NumPy array that's an indexer into the original
263+
ExtensionArray
264+
uniques : ExtensionArray
265+
An ExtensionArray containing the unique values of `self`.
266+
267+
See Also
268+
--------
269+
pandas.factorize : top-level factorize method that dispatches here.
270+
271+
Notes
272+
-----
273+
:meth:`pandas.factorize` offers a `sort` keyword as well.
274+
"""
275+
from pandas.core.algorithms import _factorize_array
276+
277+
mask = self.isna()
278+
arr = self.astype(object)
279+
arr[mask] = np.nan
280+
281+
labels, uniques = _factorize_array(arr, check_nulls=True,
282+
na_sentinel=na_sentinel)
283+
uniques = self._constructor_from_sequence(uniques)
284+
return labels, uniques
285+
251286
# ------------------------------------------------------------------------
252287
# Indexing methods
253288
# ------------------------------------------------------------------------

pandas/tests/extension/base/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@
44
class BaseExtensionTests(object):
55
assert_series_equal = staticmethod(tm.assert_series_equal)
66
assert_frame_equal = staticmethod(tm.assert_frame_equal)
7+
assert_extension_array_equal = staticmethod(
8+
tm.assert_extension_array_equal
9+
)

pandas/tests/extension/base/methods.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33

44
import pandas as pd
5+
import pandas.util.testing as tm
56

67
from .base import BaseExtensionTests
78

@@ -42,3 +43,22 @@ def test_unique(self, data, box, method):
4243
assert len(result) == 1
4344
assert isinstance(result, type(data))
4445
assert result[0] == duplicated[0]
46+
47+
@pytest.mark.parametrize('na_sentinel', [-1, -2])
48+
def test_factorize(self, data_for_grouping, na_sentinel):
49+
labels, uniques = pd.factorize(data_for_grouping,
50+
na_sentinel=na_sentinel)
51+
expected_labels = np.array([0, 0, na_sentinel,
52+
na_sentinel, 1, 1, 0, 2],
53+
dtype='int64')
54+
expected_uniques = data_for_grouping.take([0, 4, 7])
55+
56+
tm.assert_numpy_array_equal(labels, expected_labels)
57+
self.assert_extension_array_equal(uniques, expected_uniques)
58+
59+
def test_factorize_equivalence(self, data_for_grouping):
60+
l1, u1 = pd.factorize(data_for_grouping)
61+
l2, u2 = pd.factorize(data_for_grouping)
62+
63+
tm.assert_numpy_array_equal(l1, l2)
64+
self.assert_extension_array_equal(u1, u2)

pandas/tests/extension/category/test_categorical.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,11 @@ def na_value():
3434
return np.nan
3535

3636

37+
@pytest.fixture
38+
def data_for_grouping():
39+
return Categorical(['a', 'a', None, None, 'b', 'b', 'a', 'c'])
40+
41+
3742
class TestDtype(base.BaseDtypeTests):
3843
pass
3944

pandas/tests/extension/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ def na_cmp():
4646
def na_value():
4747
"""The scalar missing value for this type. Default 'None'"""
4848
return None
49+
50+
51+
@pytest.fixture
52+
def data_for_grouping():
53+
"""Data for factorization, grouping, and unique tests.
54+
55+
Expected to be like [B, B, NA, NA, A, A, B, C]
56+
57+
Where A < B < C and NA is missing
58+
"""
59+
raise NotImplementedError

pandas/tests/extension/decimal/test_decimal.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,15 @@ def na_value():
3535
return decimal.Decimal("NaN")
3636

3737

38+
@pytest.fixture
39+
def data_for_grouping():
40+
b = decimal.Decimal('1.0')
41+
a = decimal.Decimal('0.0')
42+
c = decimal.Decimal('2.0')
43+
na = decimal.Decimal('NaN')
44+
return DecimalArray([b, b, na, na, a, a, b, c])
45+
46+
3847
class TestDtype(base.BaseDtypeTests):
3948
pass
4049

pandas/tests/extension/json/array.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import numpy as np
99

10+
import pandas as pd
1011
from pandas.core.dtypes.base import ExtensionDtype
1112
from pandas.core.arrays import ExtensionArray
1213

@@ -104,6 +105,21 @@ def _concat_same_type(cls, to_concat):
104105
data = list(itertools.chain.from_iterable([x.data for x in to_concat]))
105106
return cls(data)
106107

108+
def factorize(self, na_sentinel=-1):
109+
frozen = tuple(tuple(x.items()) for x in self)
110+
labels, uniques = pd.factorize(frozen)
111+
112+
# fixup NA
113+
if self.isna().any():
114+
na_code = self.isna().argmax()
115+
116+
labels[labels == na_code] = na_sentinel
117+
labels[labels > na_code] -= 1
118+
119+
uniques = JSONArray([collections.UserDict(x)
120+
for x in uniques if x != ()])
121+
return labels, uniques
122+
107123

108124
def make_data():
109125
# TODO: Use a regular dict. See _NDFrameIndexer._setitem_with_indexer

pandas/tests/extension/json/test_json.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def na_cmp():
3939
return operator.eq
4040

4141

42+
@pytest.fixture
43+
def data_for_grouping():
44+
return JSONArray([
45+
{'b': 1}, {'b': 1},
46+
{}, {},
47+
{'a': 0, 'c': 2}, {'a': 0, 'c': 2},
48+
{'b': 1},
49+
{'c': 2},
50+
])
51+
52+
4253
class TestDtype(base.BaseDtypeTests):
4354
pass
4455

@@ -64,8 +75,10 @@ class TestMissing(base.BaseMissingTests):
6475

6576

6677
class TestMethods(base.BaseMethodsTests):
67-
@pytest.mark.skip(reason="Unhashable")
68-
def test_value_counts(self, all_data, dropna):
78+
unhashable = pytest.mark.skip(reason="Unhashable")
79+
80+
@unhashable
81+
def test_factorize(self):
6982
pass
7083

7184

pandas/util/testing.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121

2222
import pandas as pd
23+
from pandas.core.arrays.base import ExtensionArray
2324
from pandas.core.dtypes.missing import array_equivalent
2425
from pandas.core.dtypes.common import (
2526
is_datetimelike_v_numeric,
@@ -1083,6 +1084,32 @@ def _raise(left, right, err_msg):
10831084
return True
10841085

10851086

1087+
def assert_extension_array_equal(left, right):
1088+
"""Check that left and right ExtensionArrays are equal.
1089+
1090+
Parameters
1091+
----------
1092+
left, right : ExtensionArray
1093+
The two arrays to compare
1094+
1095+
Notes
1096+
-----
1097+
Missing values are checked separately from valid values.
1098+
A mask of missing values is computed for each and checked to match.
1099+
The remaining all-valid values are cast to object dtype and checked.
1100+
"""
1101+
assert isinstance(left, ExtensionArray)
1102+
assert left.dtype == right.dtype
1103+
left_na = left.isna()
1104+
right_na = right.isna()
1105+
assert_numpy_array_equal(left_na, right_na)
1106+
1107+
left_valid = left[~left_na].astype(object)
1108+
right_valid = right[~right_na].astype(object)
1109+
1110+
assert_numpy_array_equal(left_valid, right_valid)
1111+
1112+
10861113
# This could be refactored to use the NDFrame.equals method
10871114
def assert_series_equal(left, right, check_dtype=True,
10881115
check_index_type='equiv',

0 commit comments

Comments
 (0)