Skip to content

Commit 5a462e9

Browse files
committed
Fix slow dot in Numba
1 parent 2d414d4 commit 5a462e9

File tree

2 files changed

+75
-29
lines changed

2 files changed

+75
-29
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 48 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -565,26 +565,27 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
565565
def int_to_float_fn(inputs, out_dtype):
566566
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
567567

568-
if all(
569-
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
570-
) and isinstance(np.dtype(out_dtype), np.floating):
568+
if (
569+
all(inp.type.dtype == out_dtype for inp in inputs)
570+
and np.dtype(out_dtype).kind == "f"
571+
):
571572

572-
@numba_njit
573+
@numba_njit(inline="always")
573574
def inputs_cast(x):
574575
return x
575576

576-
elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
577+
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
577578
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
578579

579-
@numba_njit
580+
@numba_njit(inline="always")
580581
def inputs_cast(x):
581582
return x.astype(args_dtype)
582583

583584
else:
584585
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
585586
args_dtype = np.dtype(f"f{args_dtype_sz}")
586587

587-
@numba_njit
588+
@numba_njit(inline="always")
588589
def inputs_cast(x):
589590
return x.astype(args_dtype)
590591

@@ -593,17 +594,49 @@ def inputs_cast(x):
593594

594595
@numba_funcify.register(Dot)
595596
def numba_funcify_Dot(op, node, **kwargs):
596-
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
597-
# float.
597+
# Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
598+
x, y = node.inputs
599+
[out] = node.outputs
598600

599-
out_dtype = node.outputs[0].type.numpy_dtype
600-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
601+
x_dtype = x.type.dtype
602+
y_dtype = y.type.dtype
603+
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
604+
out_dtype = out.type.dtype
601605

602-
@numba_njit
603-
def dot(x, y):
604-
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
606+
if x_dtype == dot_dtype and y_dtype == dot_dtype:
607+
608+
@numba_njit
609+
def dot(x, y):
610+
return np.asarray(np.dot(x, y))
611+
612+
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
613+
614+
@numba_njit
615+
def dot(x, y):
616+
return np.asarray(np.dot(x, y.astype(dot_dtype)))
617+
618+
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
619+
620+
@numba_njit
621+
def dot(x, y):
622+
return np.asarray(np.dot(x.astype(dot_dtype), y))
623+
624+
else:
625+
626+
@numba_njit()
627+
def dot(x, y):
628+
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
629+
630+
if out_dtype == dot_dtype:
631+
return dot
632+
633+
else:
634+
635+
@numba_njit
636+
def dot_with_cast(x, y):
637+
return dot(x, y).astype(out_dtype)
605638

606-
return dot
639+
return dot_with_cast
607640

608641

609642
@numba_funcify.register(Solve)

tests/link/numba/test_basic.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pytensor.link.numba.linker import NumbaLinker
3131
from pytensor.raise_op import assert_op
3232
from pytensor.scalar.basic import ScalarOp, as_scalar
33-
from pytensor.tensor import blas
33+
from pytensor.tensor import blas, tensor
3434
from pytensor.tensor.elemwise import Elemwise
3535
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
3636
from pytensor.tensor.sort import ArgSortOp, SortOp
@@ -603,43 +603,41 @@ def test_perform_type_convert():
603603

604604

605605
@pytest.mark.parametrize(
606-
"x, y, exc",
606+
"x, y",
607607
[
608608
(
609609
(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
610610
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
611-
None,
612611
),
613612
(
614613
(pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
615614
(pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
616-
None,
617615
),
618616
(
619617
(pt.lmatrix(), rng.poisson(size=(3, 2))),
620618
(pt.fvector(), rng.random(size=(2,)).astype("float32")),
621-
None,
622619
),
623620
(
624621
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
625622
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
626-
None,
623+
),
624+
(
625+
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
626+
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
627627
),
628628
],
629629
)
630-
def test_Dot(x, y, exc):
630+
def test_Dot(x, y):
631631
x, x_test_value = x
632632
y, y_test_value = y
633633

634634
g = ptm.Dot()(x, y)
635635

636-
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
637-
with cm:
638-
compare_numba_and_py(
639-
[x, y],
640-
[g],
641-
[x_test_value, y_test_value],
642-
)
636+
compare_numba_and_py(
637+
[x, y],
638+
[g],
639+
[x_test_value, y_test_value],
640+
)
643641

644642

645643
@pytest.mark.parametrize(
@@ -937,3 +935,18 @@ def test_Nonzero(input_data):
937935
compare_numba_and_py(
938936
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
939937
)
938+
939+
940+
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
941+
def test_mat_vec_dot_performance(dtype, benchmark):
942+
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
943+
x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
944+
out = ptm.dot(A, x)
945+
946+
fn = function([A, x], out, mode="NUMBA", trust_input=True)
947+
948+
rng = np.random.default_rng(948)
949+
A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
950+
x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
951+
np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
952+
benchmark(fn, A_test, x_test)

0 commit comments

Comments (0)