|
45 | 45 | from pytensor.scalar.basic import add as add_as
|
46 | 46 | from pytensor.scalar.basic import scalar_maximum
|
47 | 47 | from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
|
48 |
| -from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros |
| 48 | +from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum |
49 | 49 | from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
|
50 | 50 | from pytensor.tensor.type import scalar
|
51 | 51 |
|
@@ -649,6 +649,42 @@ def elemwise_wrapper(*inputs):
|
649 | 649 | return elemwise_wrapper
|
650 | 650 |
|
651 | 651 |
|
| 652 | +@numba_funcify.register(Sum) |
| 653 | +def numba_funcify_Sum(op, node, **kwargs): |
| 654 | + axes = op.axis |
| 655 | + if axes is None: |
| 656 | + axes = list(range(node.inputs[0].ndim)) |
| 657 | + |
| 658 | + axes = list(axes) |
| 659 | + |
| 660 | + ndim_input = node.inputs[0].ndim |
| 661 | + |
| 662 | + if hasattr(op, "acc_dtype") and op.acc_dtype is not None: |
| 663 | + acc_dtype = op.acc_dtype |
| 664 | + else: |
| 665 | + acc_dtype = node.outputs[0].type.dtype |
| 666 | + |
| 667 | + np_acc_dtype = np.dtype(acc_dtype) |
| 668 | + |
| 669 | + if ndim_input == len(axes): |
| 670 | + @numba_njit |
| 671 | + def impl_sum(array): |
| 672 | + |
| 673 | + # TODO The accumulation itself should happen in acc_dtype... |
| 674 | + #return array.sum(axes).astype(np_acc_dtype) |
| 675 | + return np.asarray(array.sum())#.astype(np_acc_dtype) |
| 676 | + |
| 677 | + else: |
| 678 | + @numba_njit |
| 679 | + def impl_sum(array): |
| 680 | + |
| 681 | + # TODO The accumulation itself should happen in acc_dtype... |
| 682 | + #return array.sum(axes).astype(np_acc_dtype) |
| 683 | + return np.asarray(array.sum(axes))#.astype(np_acc_dtype) |
| 684 | + |
| 685 | + return impl_sum |
| 686 | + |
| 687 | + |
652 | 688 | @numba_funcify.register(CAReduce)
|
653 | 689 | def numba_funcify_CAReduce(op, node, **kwargs):
|
654 | 690 | axes = op.axis
|
|
0 commit comments