
Use numba code for supported CAReduce cases #931


Closed
wants to merge 3 commits
52 changes: 38 additions & 14 deletions pytensor/link/numba/dispatch/elemwise.py
@@ -44,7 +44,16 @@
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.math import (
    All,
    Any,
    Argmax,
    Max,
    Mean,
    Min,
    MulWithoutZeros,
    Prod,
    ProdWithoutZeros,
    Sum,
)
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from pytensor.tensor.type import scalar

@@ -546,37 +555,52 @@ def ov_elemwise(*inputs):


@numba_funcify.register(Sum)
def numba_funcify_Sum(op, node, **kwargs):
@numba_funcify.register(Prod)
@numba_funcify.register(ProdWithoutZeros)
@numba_funcify.register(Max)
@numba_funcify.register(Min)
@numba_funcify.register(All)
@numba_funcify.register(Any)
def numba_funcify_CAReduce_specialized(op, node, **kwargs):
if isinstance(op, ProdWithoutZeros):
# ProdWithoutZeros is the same as Prod but the gradient can assume no zeros
np_op = np.prod
else:
np_op = getattr(np, op.__class__.__name__.lower())

axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))

axes = tuple(axes)
axes = tuple(sorted(axes))

ndim_input = node.inputs[0].ndim
out_dtype = np.dtype(node.outputs[0].dtype)

if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype

np_acc_dtype = np.dtype(acc_dtype)
if len(axes) == 0:

out_dtype = np.dtype(node.outputs[0].dtype)
@numba_njit(fastmath=True)
def impl_sum(array):
return np.asarray(array, dtype=out_dtype)

if ndim_input == len(axes):
elif (
len(axes) == 1
# Some Ops don't support axis in Numba
        and not isinstance(op, Prod | ProdWithoutZeros | All | Any | Mean | Max | Min)
):

@numba_njit(fastmath=True)
def impl_sum(array):
return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)
return np.asarray(np_op(array, axis=axes[0])).astype(out_dtype)

elif len(axes) == 0:
elif len(axes) == ndim_input:

@numba_njit(fastmath=True)
def impl_sum(array):
return np.asarray(array, dtype=out_dtype)
return np.asarray(np_op(array)).astype(out_dtype)

else:
# Slow path
impl_sum = numba_funcify_CAReduce(op, node, **kwargs)

return impl_sum
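
For reference, a minimal self-contained sketch of what the two fast paths above boil down to once the Op has been mapped to its NumPy reduction. The `make_reduce_impl` helper, the dtype, and the inputs below are illustrative placeholders, not part of this diff; the real dispatcher derives `np_op`, `out_dtype`, and the axes from the Op and the node as shown above.

import numpy as np
from numba import njit


def make_reduce_impl(np_op, out_dtype, axis=None):
    # Mirrors the fast paths above: reduce over all axes, or over a single
    # axis when the NumPy reduction accepts `axis` under Numba (e.g. np.sum).
    if axis is None:

        @njit(fastmath=True)
        def impl(array):
            return np.asarray(np_op(array)).astype(out_dtype)

    else:

        @njit(fastmath=True)
        def impl(array):
            return np.asarray(np_op(array, axis=axis)).astype(out_dtype)

    return impl


x = np.arange(6, dtype="float64").reshape(2, 3)
sum_all = make_reduce_impl(np.sum, np.dtype("float64"))
sum_axis0 = make_reduce_impl(np.sum, np.dtype("float64"), axis=0)
print(sum_all(x))    # 15.0 (0-d array)
print(sum_axis0(x))  # [3. 5. 7.]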
8 changes: 7 additions & 1 deletion pytensor/tensor/math.py
@@ -1304,12 +1304,18 @@ def complex_from_polar(abs, angle):


class Mean(FixedOpCAReduce):
# FIXME: Mean is not a true CAReduce in the PyTensor sense, because it needs to keep
# track of the number of elements already reduced in order to work iteratively.
# This should subclass a `ReduceOp` which `CAReduce` could also inherit from.
__props__ = ("axis",)
nfunc_spec = ("mean", 1, 1)

def __init__(self, axis=None):
super().__init__(ps.mean, axis)
assert self.axis is None or len(self.axis) == 1
if not (self.axis is None or len(self.axis) == 1):
raise NotImplementedError(
"Mean Op only supports axis=None or a single axis. Use `mean` function instead"
)

def __str__(self):
if self.axis is not None:
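A brief usage sketch of the stricter `Mean` constructor (illustrative only; it assumes just the `Mean` Op shown above and the public `mean` helper that the new error message points to):

import pytensor.tensor as pt
from pytensor.tensor.math import Mean, mean

x = pt.matrix("x")
m0 = Mean(axis=0)(x)        # still fine: axis=None or a single axis
m01 = mean(x, axis=(0, 1))  # multi-axis reductions go through the `mean` helper
# Mean(axis=(0, 1)) now raises NotImplementedError instead of tripping a bare assert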
188 changes: 38 additions & 150 deletions tests/link/numba/test_elemwise.py
@@ -236,157 +236,21 @@ def test_Dimshuffle_non_contiguous():
assert func(np.zeros(3), np.array([1])).ndim == 0


@pytest.mark.parametrize(
"careduce_fn, axis, v",
[
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Mean(axis)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(0, 1),
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(1, 0),
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: ProdWithoutZeros(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
0,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
1,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Max(axis)(x),
None,
set_test_value(
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Min(axis)(x),
None,
set_test_value(
pt.lmatrix(), np.arange(3 * 2, dtype=np.int64).reshape((3, 2))
),
),
],
)
def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
@pytest.mark.parametrize("axis", [0, None, (0, 1)])
@pytest.mark.parametrize("op", [Sum, Prod, ProdWithoutZeros, All, Any, Mean, Max, Min])
def test_CAReduce(op, axis):
if op == Mean and isinstance(axis, tuple) and len(axis) > 1:
pytest.xfail("Mean does not support multiple partial axes")

compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)
bool_reduction = op in (All, Any)
x = pt.tensor3("x", dtype=bool if bool_reduction else config.floatX)
g = op(axis=axis)(x)
g_fg = FunctionGraph([x], [g])

x_test = np.random.normal(size=(2, 3, 4)).astype(config.floatX)
if bool_reduction:
x_test = x_test > 0
compare_numba_and_py(g_fg, [x_test])


def test_scalar_Elemwise_Clip():
@@ -665,3 +665,27 @@ def test_elemwise_out_type():
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))

assert func(x_val).shape == (18,)


@pytest.mark.parametrize("axis", [0, 2, (0, 2), None])
@pytest.mark.parametrize("op", [Sum, Max, Any])
def test_careduce_benchmark(benchmark, op, axis):
rng = np.random.default_rng(123)
N = 256
    if op is Any:
        # Sparse boolean tensor: mostly False, with a few scattered True entries
        value = np.zeros((N, N, N), dtype="bool")
        true_arrays = rng.choice(N, size=N // 2, replace=False)
        true_rows = rng.choice(N, size=N // 2, replace=False)
        true_cols = rng.choice(N, size=N // 2, replace=False)
        value[true_arrays, true_rows, true_cols] = True
else:
value = rng.normal(size=(N, N, N))

x = pytensor.shared(value, name="x")
out = op(axis=axis)(x)

func = pytensor.function([], [out], mode="NUMBA")
# JIT compile first
func()
benchmark(func)
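
Note: the `benchmark` fixture used here is provided by pytest-benchmark; with that plugin installed, these cases can be selected with pytest's `-k careduce_benchmark` filter.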