Skip to content

Commit 3435ebf

Browse files
authored
REF: make maybe_cast_result_dtype a WrappedCythonOp method (#41065)
1 parent 4caf4c7 commit 3435ebf

File tree

2 files changed

+50
-44
lines changed

2 files changed

+50
-44
lines changed

pandas/core/dtypes/cast.py

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -406,42 +406,6 @@ def maybe_cast_pointwise_result(
406406
return result
407407

408408

409-
def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
410-
"""
411-
Get the desired dtype of a result based on the
412-
input dtype and how it was computed.
413-
414-
Parameters
415-
----------
416-
dtype : DtypeObj
417-
Input dtype.
418-
how : str
419-
How the result was computed.
420-
421-
Returns
422-
-------
423-
DtypeObj
424-
The desired dtype of the result.
425-
"""
426-
from pandas.core.arrays.boolean import BooleanDtype
427-
from pandas.core.arrays.floating import Float64Dtype
428-
from pandas.core.arrays.integer import (
429-
Int64Dtype,
430-
_IntegerDtype,
431-
)
432-
433-
if how in ["add", "cumsum", "sum", "prod"]:
434-
if dtype == np.dtype(bool):
435-
return np.dtype(np.int64)
436-
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
437-
return Int64Dtype()
438-
elif how in ["mean", "median", "var"] and isinstance(
439-
dtype, (BooleanDtype, _IntegerDtype)
440-
):
441-
return Float64Dtype()
442-
return dtype
443-
444-
445409
def maybe_cast_to_extension_array(
446410
cls: type[ExtensionArray], obj: ArrayLike, dtype: ExtensionDtype | None = None
447411
) -> ArrayLike:

pandas/core/groupby/ops.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737

3838
from pandas.core.dtypes.cast import (
3939
maybe_cast_pointwise_result,
40-
maybe_cast_result_dtype,
4140
maybe_downcast_to_dtype,
4241
)
4342
from pandas.core.dtypes.common import (
@@ -262,6 +261,41 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
262261
out_dtype = "object"
263262
return np.dtype(out_dtype)
264263

264+
def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
265+
"""
266+
Get the desired dtype of a result based on the
267+
input dtype and how it was computed.
268+
269+
Parameters
270+
----------
271+
dtype : np.dtype or ExtensionDtype
272+
Input dtype.
273+
274+
Returns
275+
-------
276+
np.dtype or ExtensionDtype
277+
The desired dtype of the result.
278+
"""
279+
from pandas.core.arrays.boolean import BooleanDtype
280+
from pandas.core.arrays.floating import Float64Dtype
281+
from pandas.core.arrays.integer import (
282+
Int64Dtype,
283+
_IntegerDtype,
284+
)
285+
286+
how = self.how
287+
288+
if how in ["add", "cumsum", "sum", "prod"]:
289+
if dtype == np.dtype(bool):
290+
return np.dtype(np.int64)
291+
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
292+
return Int64Dtype()
293+
elif how in ["mean", "median", "var"] and isinstance(
294+
dtype, (BooleanDtype, _IntegerDtype)
295+
):
296+
return Float64Dtype()
297+
return dtype
298+
265299
def uses_mask(self) -> bool:
266300
return self.how in self._MASKED_CYTHON_FUNCTIONS
267301

@@ -564,7 +598,14 @@ def get_group_levels(self) -> list[Index]:
564598

565599
@final
566600
def _ea_wrap_cython_operation(
567-
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
601+
self,
602+
cy_op: WrappedCythonOp,
603+
kind: str,
604+
values,
605+
how: str,
606+
axis: int,
607+
min_count: int = -1,
608+
**kwargs,
568609
) -> ArrayLike:
569610
"""
570611
If we have an ExtensionArray, unwrap, call _cython_operation, and
@@ -601,7 +642,7 @@ def _ea_wrap_cython_operation(
601642
# other cast_blocklist methods dont go through cython_operation
602643
return res_values
603644

604-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
645+
dtype = cy_op.get_result_dtype(orig_values.dtype)
605646
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
606647
# has no attribute "construct_array_type"
607648
cls = dtype.construct_array_type() # type: ignore[union-attr]
@@ -618,7 +659,7 @@ def _ea_wrap_cython_operation(
618659
# other cast_blocklist methods dont go through cython_operation
619660
return res_values
620661

621-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
662+
dtype = cy_op.get_result_dtype(orig_values.dtype)
622663
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
623664
# has no attribute "construct_array_type"
624665
cls = dtype.construct_array_type() # type: ignore[union-attr]
@@ -631,6 +672,7 @@ def _ea_wrap_cython_operation(
631672
@final
632673
def _masked_ea_wrap_cython_operation(
633674
self,
675+
cy_op: WrappedCythonOp,
634676
kind: str,
635677
values: BaseMaskedArray,
636678
how: str,
@@ -651,7 +693,7 @@ def _masked_ea_wrap_cython_operation(
651693
res_values = self._cython_operation(
652694
kind, arr, how, axis, min_count, mask=mask, **kwargs
653695
)
654-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
696+
dtype = cy_op.get_result_dtype(orig_values.dtype)
655697
assert isinstance(dtype, BaseMaskedDtype)
656698
cls = dtype.construct_array_type()
657699

@@ -694,11 +736,11 @@ def _cython_operation(
694736
if is_extension_array_dtype(dtype):
695737
if isinstance(values, BaseMaskedArray) and func_uses_mask:
696738
return self._masked_ea_wrap_cython_operation(
697-
kind, values, how, axis, min_count, **kwargs
739+
cy_op, kind, values, how, axis, min_count, **kwargs
698740
)
699741
else:
700742
return self._ea_wrap_cython_operation(
701-
kind, values, how, axis, min_count, **kwargs
743+
cy_op, kind, values, how, axis, min_count, **kwargs
702744
)
703745

704746
elif values.ndim == 1:
@@ -797,7 +839,7 @@ def _cython_operation(
797839
if how not in cy_op.cast_blocklist:
798840
# e.g. if we are int64 and need to restore to datetime64/timedelta64
799841
# "rank" is the only member of cast_blocklist we get here
800-
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
842+
dtype = cy_op.get_result_dtype(orig_values.dtype)
801843
op_result = maybe_downcast_to_dtype(result, dtype)
802844
else:
803845
op_result = result

0 commit comments

Comments
 (0)