37
37
38
38
from pandas .core .dtypes .cast import (
39
39
maybe_cast_pointwise_result ,
40
- maybe_cast_result_dtype ,
41
40
maybe_downcast_to_dtype ,
42
41
)
43
42
from pandas .core .dtypes .common import (
@@ -262,6 +261,41 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
262
261
out_dtype = "object"
263
262
return np .dtype (out_dtype )
264
263
264
+ def get_result_dtype (self , dtype : DtypeObj ) -> DtypeObj :
265
+ """
266
+ Get the desired dtype of a result based on the
267
+ input dtype and how it was computed.
268
+
269
+ Parameters
270
+ ----------
271
+ dtype : np.dtype or ExtensionDtype
272
+ Input dtype.
273
+
274
+ Returns
275
+ -------
276
+ np.dtype or ExtensionDtype
277
+ The desired dtype of the result.
278
+ """
279
+ from pandas .core .arrays .boolean import BooleanDtype
280
+ from pandas .core .arrays .floating import Float64Dtype
281
+ from pandas .core .arrays .integer import (
282
+ Int64Dtype ,
283
+ _IntegerDtype ,
284
+ )
285
+
286
+ how = self .how
287
+
288
+ if how in ["add" , "cumsum" , "sum" , "prod" ]:
289
+ if dtype == np .dtype (bool ):
290
+ return np .dtype (np .int64 )
291
+ elif isinstance (dtype , (BooleanDtype , _IntegerDtype )):
292
+ return Int64Dtype ()
293
+ elif how in ["mean" , "median" , "var" ] and isinstance (
294
+ dtype , (BooleanDtype , _IntegerDtype )
295
+ ):
296
+ return Float64Dtype ()
297
+ return dtype
298
+
265
299
def uses_mask (self ) -> bool :
266
300
return self .how in self ._MASKED_CYTHON_FUNCTIONS
267
301
@@ -564,7 +598,14 @@ def get_group_levels(self) -> list[Index]:
564
598
565
599
@final
566
600
def _ea_wrap_cython_operation (
567
- self , kind : str , values , how : str , axis : int , min_count : int = - 1 , ** kwargs
601
+ self ,
602
+ cy_op : WrappedCythonOp ,
603
+ kind : str ,
604
+ values ,
605
+ how : str ,
606
+ axis : int ,
607
+ min_count : int = - 1 ,
608
+ ** kwargs ,
568
609
) -> ArrayLike :
569
610
"""
570
611
If we have an ExtensionArray, unwrap, call _cython_operation, and
@@ -601,7 +642,7 @@ def _ea_wrap_cython_operation(
601
642
# other cast_blocklist methods dont go through cython_operation
602
643
return res_values
603
644
604
- dtype = maybe_cast_result_dtype (orig_values .dtype , how )
645
+ dtype = cy_op . get_result_dtype (orig_values .dtype )
605
646
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
606
647
# has no attribute "construct_array_type"
607
648
cls = dtype .construct_array_type () # type: ignore[union-attr]
@@ -618,7 +659,7 @@ def _ea_wrap_cython_operation(
618
659
# other cast_blocklist methods dont go through cython_operation
619
660
return res_values
620
661
621
- dtype = maybe_cast_result_dtype (orig_values .dtype , how )
662
+ dtype = cy_op . get_result_dtype (orig_values .dtype )
622
663
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
623
664
# has no attribute "construct_array_type"
624
665
cls = dtype .construct_array_type () # type: ignore[union-attr]
@@ -631,6 +672,7 @@ def _ea_wrap_cython_operation(
631
672
@final
632
673
def _masked_ea_wrap_cython_operation (
633
674
self ,
675
+ cy_op : WrappedCythonOp ,
634
676
kind : str ,
635
677
values : BaseMaskedArray ,
636
678
how : str ,
@@ -651,7 +693,7 @@ def _masked_ea_wrap_cython_operation(
651
693
res_values = self ._cython_operation (
652
694
kind , arr , how , axis , min_count , mask = mask , ** kwargs
653
695
)
654
- dtype = maybe_cast_result_dtype (orig_values .dtype , how )
696
+ dtype = cy_op . get_result_dtype (orig_values .dtype )
655
697
assert isinstance (dtype , BaseMaskedDtype )
656
698
cls = dtype .construct_array_type ()
657
699
@@ -694,11 +736,11 @@ def _cython_operation(
694
736
if is_extension_array_dtype (dtype ):
695
737
if isinstance (values , BaseMaskedArray ) and func_uses_mask :
696
738
return self ._masked_ea_wrap_cython_operation (
697
- kind , values , how , axis , min_count , ** kwargs
739
+ cy_op , kind , values , how , axis , min_count , ** kwargs
698
740
)
699
741
else :
700
742
return self ._ea_wrap_cython_operation (
701
- kind , values , how , axis , min_count , ** kwargs
743
+ cy_op , kind , values , how , axis , min_count , ** kwargs
702
744
)
703
745
704
746
elif values .ndim == 1 :
@@ -797,7 +839,7 @@ def _cython_operation(
797
839
if how not in cy_op .cast_blocklist :
798
840
# e.g. if we are int64 and need to restore to datetime64/timedelta64
799
841
# "rank" is the only member of cast_blocklist we get here
800
- dtype = maybe_cast_result_dtype (orig_values .dtype , how )
842
+ dtype = cy_op . get_result_dtype (orig_values .dtype )
801
843
op_result = maybe_downcast_to_dtype (result , dtype )
802
844
else :
803
845
op_result = result
0 commit comments