Skip to content

Commit 92e458d

Browse files
authored
BUG: is_numeric_dtype(ArrowDtype[numeric]) not returning True (#50572)
* BUG: is_numeric_dtype(ArrowDtype[numeric]) not returning True
* Adjust decimal type to be numeric
* Some arrow tests now working
1 parent 5483590 commit 92e458d

File tree

5 files changed

+43
-16
lines changed

5 files changed

+43
-16
lines changed

doc/source/whatsnew/v2.0.0.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -989,6 +989,7 @@ ExtensionArray
989989
- Bug in :meth:`Series.round` for pyarrow-backed dtypes raising ``AttributeError`` (:issue:`50437`)
990990
- Bug when concatenating an empty DataFrame with an ExtensionDtype to another DataFrame with the same ExtensionDtype, the resulting dtype turned into object (:issue:`48510`)
991991
- Bug in :meth:`array.PandasArray.to_numpy` raising with ``NA`` value when ``na_value`` is specified (:issue:`40638`)
992+
- Bug in :func:`api.types.is_numeric_dtype` where a custom :class:`ExtensionDtype` would not be reported as numeric even if its ``_is_numeric`` property returned ``True`` (:issue:`50563`)
992993

993994
Styler
994995
^^^^^^

pandas/core/dtypes/common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,6 +1200,8 @@ def is_numeric_dtype(arr_or_dtype) -> bool:
12001200
"""
12011201
return _is_dtype_type(
12021202
arr_or_dtype, classes_and_not_datetimelike(np.number, np.bool_)
1203+
) or _is_dtype(
1204+
arr_or_dtype, lambda typ: isinstance(typ, ExtensionDtype) and typ._is_numeric
12031205
)
12041206

12051207

pandas/tests/dtypes/test_common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -556,6 +556,24 @@ def test_is_numeric_dtype():
556556
assert com.is_numeric_dtype(pd.Series([1, 2]))
557557
assert com.is_numeric_dtype(pd.Index([1, 2.0]))
558558

559+
class MyNumericDType(ExtensionDtype):
    """Minimal extension dtype that declares itself numeric.

    Used to check that ``is_numeric_dtype`` honours ``_is_numeric`` on a
    custom ``ExtensionDtype``. ``name`` and ``construct_array_type`` are
    deliberately left unimplemented and raise ``NotImplementedError``.
    """

    @property
    def type(self):
        # A non-numeric scalar type, so the dtype can only be recognised
        # as numeric through ``_is_numeric``.
        return str

    @property
    def name(self):
        raise NotImplementedError

    @classmethod
    def construct_array_type(cls):
        raise NotImplementedError

    @property
    def _is_numeric(self) -> bool:
        # Must be a property: ``is_numeric_dtype`` reads ``typ._is_numeric``
        # as an attribute. A plain method would yield a bound-method object,
        # which is always truthy, so the check would pass regardless of the
        # value returned here.
        return True
574+
575+
assert com.is_numeric_dtype(MyNumericDType())
576+
559577

560578
def test_is_float_dtype():
561579
assert not com.is_float_dtype(str)

pandas/tests/extension/test_arrow.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,10 @@
3737

3838
import pandas as pd
3939
import pandas._testing as tm
40-
from pandas.api.types import is_bool_dtype
40+
from pandas.api.types import (
41+
is_bool_dtype,
42+
is_numeric_dtype,
43+
)
4144
from pandas.tests.extension import base
4245

4346
pa = pytest.importorskip("pyarrow", minversion="1.0.1")
@@ -550,16 +553,6 @@ def test_groupby_extension_apply(
550553
):
551554
super().test_groupby_extension_apply(data_for_grouping, groupby_apply_op)
552555

553-
def test_in_numeric_groupby(self, data_for_grouping, request):
554-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
555-
if pa.types.is_integer(pa_dtype) or pa.types.is_floating(pa_dtype):
556-
request.node.add_marker(
557-
pytest.mark.xfail(
558-
reason="ArrowExtensionArray doesn't support .sum() yet.",
559-
)
560-
)
561-
super().test_in_numeric_groupby(data_for_grouping)
562-
563556
@pytest.mark.parametrize("as_index", [True, False])
564557
def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
565558
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
@@ -1446,6 +1439,19 @@ def test_is_bool_dtype():
14461439
tm.assert_series_equal(result, expected)
14471440

14481441

1442+
def test_is_numeric_dtype(data):
    """Check ``is_numeric_dtype`` against the pyarrow type behind ``data``.

    Floating, integer and decimal pyarrow types are expected to be reported
    as numeric; every other pyarrow type is expected not to be.
    """
    # GH 50563
    arrow_type = data.dtype.pyarrow_dtype
    numeric_checks = (
        pa.types.is_floating,
        pa.types.is_integer,
        pa.types.is_decimal,
    )
    expect_numeric = any(check(arrow_type) for check in numeric_checks)
    if not expect_numeric:
        assert not is_numeric_dtype(data)
    else:
        assert is_numeric_dtype(data)
1453+
1454+
14491455
def test_pickle_roundtrip(data):
14501456
# GH 42600
14511457
expected = pd.Series(data)

pandas/tests/io/json/test_json_table_schema_ext_dtype.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_build_table_schema(self):
4848
"fields": [
4949
{"name": "index", "type": "integer"},
5050
{"name": "A", "type": "any", "extDtype": "DateDtype"},
51-
{"name": "B", "type": "any", "extDtype": "decimal"},
51+
{"name": "B", "type": "number", "extDtype": "decimal"},
5252
{"name": "C", "type": "any", "extDtype": "string"},
5353
{"name": "D", "type": "integer", "extDtype": "Int64"},
5454
],
@@ -82,10 +82,10 @@ def test_as_json_table_type_ext_date_dtype(self):
8282
],
8383
)
8484
def test_as_json_table_type_ext_decimal_array_dtype(self, decimal_data):
85-
assert as_json_table_type(decimal_data.dtype) == "any"
85+
assert as_json_table_type(decimal_data.dtype) == "number"
8686

8787
def test_as_json_table_type_ext_decimal_dtype(self):
88-
assert as_json_table_type(DecimalDtype()) == "any"
88+
assert as_json_table_type(DecimalDtype()) == "number"
8989

9090
@pytest.mark.parametrize(
9191
"string_data",
@@ -180,7 +180,7 @@ def test_build_decimal_series(self, dc):
180180

181181
fields = [
182182
{"name": "id", "type": "integer"},
183-
{"name": "a", "type": "any", "extDtype": "decimal"},
183+
{"name": "a", "type": "number", "extDtype": "decimal"},
184184
]
185185

186186
schema = {"fields": fields, "primaryKey": ["id"]}
@@ -257,7 +257,7 @@ def test_to_json(self, df):
257257
fields = [
258258
OrderedDict({"name": "idx", "type": "integer"}),
259259
OrderedDict({"name": "A", "type": "any", "extDtype": "DateDtype"}),
260-
OrderedDict({"name": "B", "type": "any", "extDtype": "decimal"}),
260+
OrderedDict({"name": "B", "type": "number", "extDtype": "decimal"}),
261261
OrderedDict({"name": "C", "type": "any", "extDtype": "string"}),
262262
OrderedDict({"name": "D", "type": "integer", "extDtype": "Int64"}),
263263
]

0 commit comments

Comments
 (0)