2
2
3
3
import abc
4
4
from collections import defaultdict
5
+ import functools
5
6
from functools import partial
6
7
import inspect
7
8
from typing import (
29
30
NDFrameT ,
30
31
npt ,
31
32
)
33
+ from pandas .compat ._optional import import_optional_dependency
32
34
from pandas .errors import SpecificationError
33
35
from pandas .util ._decorators import cache_readonly
34
36
from pandas .util ._exceptions import find_stack_level
35
37
36
38
from pandas .core .dtypes .cast import is_nested_object
37
39
from pandas .core .dtypes .common import (
38
40
is_dict_like ,
41
+ is_extension_array_dtype ,
39
42
is_list_like ,
43
+ is_numeric_dtype ,
40
44
is_sequence ,
41
45
)
42
46
from pandas .core .dtypes .dtypes import (
@@ -121,6 +125,8 @@ def __init__(
121
125
result_type : str | None ,
122
126
* ,
123
127
by_row : Literal [False , "compat" , "_compat" ] = "compat" ,
128
+ engine : str = "python" ,
129
+ engine_kwargs : dict [str , bool ] | None = None ,
124
130
args ,
125
131
kwargs ,
126
132
) -> None :
@@ -133,6 +139,9 @@ def __init__(
133
139
self .args = args or ()
134
140
self .kwargs = kwargs or {}
135
141
142
+ self .engine = engine
143
+ self .engine_kwargs = {} if engine_kwargs is None else engine_kwargs
144
+
136
145
if result_type not in [None , "reduce" , "broadcast" , "expand" ]:
137
146
raise ValueError (
138
147
"invalid value for result_type, must be one "
@@ -601,6 +610,13 @@ def apply_list_or_dict_like(self) -> DataFrame | Series:
601
610
result: Series, DataFrame, or None
602
611
Result when self.func is a list-like or dict-like, None otherwise.
603
612
"""
613
+
614
+ if self .engine == "numba" :
615
+ raise NotImplementedError (
616
+ "The 'numba' engine doesn't support list-like/"
617
+ "dict likes of callables yet."
618
+ )
619
+
604
620
if self .axis == 1 and isinstance (self .obj , ABCDataFrame ):
605
621
return self .obj .T .apply (self .func , 0 , args = self .args , ** self .kwargs ).T
606
622
@@ -768,10 +784,16 @@ def __init__(
768
784
) -> None :
769
785
if by_row is not False and by_row != "compat" :
770
786
raise ValueError (f"by_row={ by_row } not allowed" )
771
- self .engine = engine
772
- self .engine_kwargs = engine_kwargs
773
787
super ().__init__ (
774
- obj , func , raw , result_type , by_row = by_row , args = args , kwargs = kwargs
788
+ obj ,
789
+ func ,
790
+ raw ,
791
+ result_type ,
792
+ by_row = by_row ,
793
+ engine = engine ,
794
+ engine_kwargs = engine_kwargs ,
795
+ args = args ,
796
+ kwargs = kwargs ,
775
797
)
776
798
777
799
# ---------------------------------------------------------------
@@ -792,6 +814,32 @@ def result_columns(self) -> Index:
792
814
def series_generator (self ) -> Generator [Series , None , None ]:
793
815
pass
794
816
817
+ @staticmethod
818
+ @functools .cache
819
+ @abc .abstractmethod
820
+ def generate_numba_apply_func (
821
+ func , nogil = True , nopython = True , parallel = False
822
+ ) -> Callable [[npt .NDArray , Index , Index ], dict [int , Any ]]:
823
+ pass
824
+
825
@abc.abstractmethod
def apply_with_numba(self):
    """
    Abstract hook: run ``self.func`` through the numba engine.

    No implementation is provided at this level; ``apply_series_numba``
    calls this method and pairs its result with ``self.result_index``.
    """
828
+
829
def validate_values_for_numba(self) -> None:
    """
    Check that every column of ``self.obj`` is usable by the numba engine.

    Columns are inspected in ``self.obj.dtypes`` order and the first
    offending column aborts the validation.

    Raises
    ------
    ValueError
        If a column's dtype is not numeric, or if the dtype is numeric but
        is an extension-array dtype.
    """
    # Validate that all column dtypes are OK
    for colname, dtype in self.obj.dtypes.items():
        # The numeric check runs first, so a non-numeric extension dtype is
        # reported as "not numeric" rather than "extension array".
        if not is_numeric_dtype(dtype):
            raise ValueError(
                f"Column {colname} must have a numeric dtype. "
                f"Found '{dtype}' instead"
            )
        if is_extension_array_dtype(dtype):
            raise ValueError(
                f"Column {colname} is backed by an extension array, "
                "which is not supported by the numba engine."
            )
842
+
795
843
@abc .abstractmethod
796
844
def wrap_results_for_axis (
797
845
self , results : ResType , res_index : Index
@@ -815,13 +863,12 @@ def values(self):
815
863
def apply (self ) -> DataFrame | Series :
816
864
"""compute the results"""
817
865
818
- if self .engine == "numba" and not self .raw :
819
- raise ValueError (
820
- "The numba engine in DataFrame.apply can only be used when raw=True"
821
- )
822
-
823
866
# dispatch to handle list-like or dict-like
824
867
if is_list_like (self .func ):
868
+ if self .engine == "numba" :
869
+ raise NotImplementedError (
870
+ "the 'numba' engine doesn't support lists of callables yet"
871
+ )
825
872
return self .apply_list_or_dict_like ()
826
873
827
874
# all empty
@@ -830,17 +877,31 @@ def apply(self) -> DataFrame | Series:
830
877
831
878
# string dispatch
832
879
if isinstance (self .func , str ):
880
+ if self .engine == "numba" :
881
+ raise NotImplementedError (
882
+ "the 'numba' engine doesn't support using "
883
+ "a string as the callable function"
884
+ )
833
885
return self .apply_str ()
834
886
835
887
# ufunc
836
888
elif isinstance (self .func , np .ufunc ):
889
+ if self .engine == "numba" :
890
+ raise NotImplementedError (
891
+ "the 'numba' engine doesn't support "
892
+ "using a numpy ufunc as the callable function"
893
+ )
837
894
with np .errstate (all = "ignore" ):
838
895
results = self .obj ._mgr .apply ("apply" , func = self .func )
839
896
# _constructor will retain self.index and self.columns
840
897
return self .obj ._constructor_from_mgr (results , axes = results .axes )
841
898
842
899
# broadcasting
843
900
if self .result_type == "broadcast" :
901
+ if self .engine == "numba" :
902
+ raise NotImplementedError (
903
+ "the 'numba' engine doesn't support result_type='broadcast'"
904
+ )
844
905
return self .apply_broadcast (self .obj )
845
906
846
907
# one axis empty
@@ -997,7 +1058,10 @@ def apply_broadcast(self, target: DataFrame) -> DataFrame:
997
1058
return result
998
1059
999
1060
def apply_standard(self):
    """
    Compute per-series results with the configured engine and wrap them.

    ``engine == "python"`` uses ``apply_series_generator``; any other engine
    value is routed to ``apply_series_numba``. The ``(results, res_index)``
    pair is then handed to ``wrap_results``.
    """
    if self.engine != "python":
        computed = self.apply_series_numba()
    else:
        computed = self.apply_series_generator()
    results, res_index = computed

    # wrap results
    return self.wrap_results(results, res_index)
@@ -1021,6 +1085,19 @@ def apply_series_generator(self) -> tuple[ResType, Index]:
1021
1085
1022
1086
return results , res_index
1023
1087
1088
def apply_series_numba(self):
    """
    Run the numba-backed apply after checking its preconditions.

    Returns
    -------
    tuple
        ``(results, self.result_index)`` where ``results`` comes from
        ``self.apply_with_numba()``.

    Raises
    ------
    NotImplementedError
        If ``engine_kwargs`` requests ``parallel``, or if the index or the
        columns contain duplicates.
    """
    parallel_requested = self.engine_kwargs.get("parallel", False)
    if parallel_requested:
        raise NotImplementedError(
            "Parallel apply is not supported when raw=False and engine='numba'"
        )

    if not (self.obj.index.is_unique and self.columns.is_unique):
        raise NotImplementedError(
            "The index/columns must be unique when raw=False and engine='numba'"
        )

    # dtype validation may raise ValueError before anything is computed
    self.validate_values_for_numba()
    return self.apply_with_numba(), self.result_index
1100
+
1024
1101
def wrap_results (self , results : ResType , res_index : Index ) -> DataFrame | Series :
1025
1102
from pandas import Series
1026
1103
@@ -1060,6 +1137,49 @@ class FrameRowApply(FrameApply):
1060
1137
def series_generator (self ) -> Generator [Series , None , None ]:
1061
1138
return (self .obj ._ixs (i , axis = 1 ) for i in range (len (self .columns )))
1062
1139
1140
@staticmethod
@functools.cache
def generate_numba_apply_func(
    func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build (and memoize) a numba-jitted driver applying ``func`` per column.

    The returned function takes ``(values, col_names, df_index)``, builds a
    Series from each column ``values[:, j]`` named after ``col_names[j]``,
    calls the jitted ``func`` on it, and returns the results in a dict keyed
    by column position.

    ``nogil``, ``nopython`` and ``parallel`` are forwarded to ``numba.jit``.
    Results are cached on the full argument tuple via ``functools.cache``.
    """
    numba = import_optional_dependency("numba")
    from pandas import Series

    # Helper that casts string objects to numpy strings.
    # Importing it also has the side effect of loading our numba extensions.
    from pandas.core._numba.extensions import maybe_cast_str

    udf = numba.extending.register_jitable(func)

    # Original note: the parallel argument is effectively disabled for this
    # path since the dicts in numba aren't thread-safe; the flag is still
    # forwarded to numba.jit exactly as received.
    @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
    def numba_func(values, col_names, df_index):
        results = {}
        n_cols = values.shape[1]
        for col in range(n_cols):
            # Create the series for this column
            column = values[:, col]
            label = maybe_cast_str(col_names[col])
            ser = Series(column, index=df_index, name=label)
            results[col] = udf(ser)
        return results

    return numba_func
1168
+
1169
def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` through the numba driver and return a plain dict.

    The driver is obtained from ``generate_numba_apply_func`` using
    ``self.engine_kwargs`` and is called with ``self.values`` plus the
    columns and index as yielded by ``set_numba_data``.
    """
    # Obtain the driver first; the extensions import below comes after it,
    # matching the original statement order.
    compiled = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )
    from pandas.core._numba.extensions import set_numba_data

    with set_numba_data(self.obj.index) as index:
        with set_numba_data(self.columns) as columns:
            # Convert from numba dict to regular dict: our isinstance checks
            # in the df constructor don't pass for numba's typed dict.
            typed_results = compiled(self.values, columns, index)
            res = dict(typed_results)
    return res
1182
+
1063
1183
@property
def result_index(self) -> Index:
    """Index used to label the results; here it is ``self.columns``."""
    labels = self.columns
    return labels
@@ -1143,6 +1263,52 @@ def series_generator(self) -> Generator[Series, None, None]:
1143
1263
object .__setattr__ (ser , "_name" , name )
1144
1264
yield ser
1145
1265
1266
@staticmethod
@functools.cache
def generate_numba_apply_func(
    func, nogil=True, nopython=True, parallel=False
) -> Callable[[npt.NDArray, Index, Index], dict[int, Any]]:
    """
    Build (and memoize) a numba-jitted driver applying ``func`` per row.

    The returned function takes ``(values, col_names_index, index)``, builds
    a Series from a copy of each row ``values[i]`` (indexed by
    ``col_names_index`` and named after ``index[i]``), calls the jitted
    ``func`` on it, and returns the results in a dict keyed by row position.

    ``nogil``, ``nopython`` and ``parallel`` are forwarded to ``numba.jit``.
    Results are cached on the full argument tuple via ``functools.cache``.
    """
    numba = import_optional_dependency("numba")
    from pandas import Series
    from pandas.core._numba.extensions import maybe_cast_str

    udf = numba.extending.register_jitable(func)

    @numba.jit(nogil=nogil, nopython=nopython, parallel=parallel)
    def numba_func(values, col_names_index, index):
        results = {}
        # Original note: the parallel argument is effectively disabled for
        # this path since the dicts in numba aren't thread-safe.
        n_rows = values.shape[0]
        for row in range(n_rows):
            # Create the series for this row
            # TODO: values corrupted without the copy
            row_values = values[row].copy()
            label = maybe_cast_str(index[row])
            ser = Series(row_values, index=col_names_index, name=label)
            results[row] = udf(ser)

        return results

    return numba_func
1295
+
1296
def apply_with_numba(self) -> dict[int, Any]:
    """
    Apply ``self.func`` through the numba driver and return a plain dict.

    The driver is obtained from ``generate_numba_apply_func`` using
    ``self.engine_kwargs`` and is called with ``self.values`` plus the
    columns and index as yielded by ``set_numba_data``.
    """
    # Obtain the driver first; the extensions import below comes after it,
    # matching the original statement order.
    compiled = self.generate_numba_apply_func(
        cast(Callable, self.func), **self.engine_kwargs
    )

    from pandas.core._numba.extensions import set_numba_data

    with set_numba_data(self.obj.index) as index:
        with set_numba_data(self.columns) as columns:
            # Convert from numba dict to regular dict: our isinstance checks
            # in the df constructor don't pass for numba's typed dict.
            typed_results = compiled(self.values, columns, index)
            res = dict(typed_results)

    return res
1311
+
1146
1312
@property
def result_index(self) -> Index:
    """Index used to label the results; here it is ``self.index``."""
    labels = self.index
    return labels
0 commit comments