Skip to content

Commit f66814d

Browse files
committed
Specialized numba sum impl
1 parent adc66bb commit f66814d

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

+37-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
from pytensor.scalar.basic import add as add_as
4646
from pytensor.scalar.basic import scalar_maximum
4747
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
48-
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros
48+
from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
4949
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
5050
from pytensor.tensor.type import scalar
5151

@@ -649,6 +649,42 @@ def elemwise_wrapper(*inputs):
649649
return elemwise_wrapper
650650

651651

652+
@numba_funcify.register(Sum)
653+
def numba_funcify_Sum(op, node, **kwargs):
654+
axes = op.axis
655+
if axes is None:
656+
axes = list(range(node.inputs[0].ndim))
657+
658+
axes = list(axes)
659+
660+
ndim_input = node.inputs[0].ndim
661+
662+
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
663+
acc_dtype = op.acc_dtype
664+
else:
665+
acc_dtype = node.outputs[0].type.dtype
666+
667+
np_acc_dtype = np.dtype(acc_dtype)
668+
669+
if ndim_input == len(axes):
670+
@numba_njit
671+
def impl_sum(array):
672+
673+
# TODO The accumulation itself should happen in acc_dtype...
674+
#return array.sum(axes).astype(np_acc_dtype)
675+
return np.asarray(array.sum())#.astype(np_acc_dtype)
676+
677+
else:
678+
@numba_njit
679+
def impl_sum(array):
680+
681+
# TODO The accumulation itself should happen in acc_dtype...
682+
#return array.sum(axes).astype(np_acc_dtype)
683+
return np.asarray(array.sum(axes))#.astype(np_acc_dtype)
684+
685+
return impl_sum
686+
687+
652688
@numba_funcify.register(CAReduce)
653689
def numba_funcify_CAReduce(op, node, **kwargs):
654690
axes = op.axis

0 commit comments

Comments
 (0)