Skip to content

Commit 9b7d707

Browse files
committed
Suppress caching warning when compiling Numba functions
1 parent 475fe3a commit 9b7d707

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

pytensor/link/numba/dispatch/basic.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import scipy.special
1515
from llvmlite import ir
1616
from numba import types
17-
from numba.core.errors import TypingError
17+
from numba.core.errors import NumbaWarning, TypingError
1818
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
1919
from numba.extending import box, overload
2020

@@ -61,6 +61,13 @@ def global_numba_func(func):
6161
def numba_njit(*args, **kwargs):
6262
kwargs.setdefault("cache", config.numba__cache)
6363

64+
# Supress caching warnings
65+
warnings.filterwarnings(
66+
"ignore",
67+
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
68+
category=NumbaWarning,
69+
)
70+
6471
if len(args) > 0 and callable(args[0]):
6572
return numba.njit(*args[1:], **kwargs)(args[0])
6673

tests/link/numba/test_basic.py

+12
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
import pytest
99

10+
from tests.tensor.test_math_scipy import scipy
11+
1012

1113
numba = pytest.importorskip("numba")
1214

@@ -1064,3 +1066,13 @@ def test_OpFromGraph():
10641066
zv = np.ones((2, 2), dtype=config.floatX) * 5
10651067

10661068
compare_numba_and_py(((x, y, z), (out,)), [xv, yv, zv])
1069+
1070+
1071+
@pytest.mark.filterwarnings("error")
1072+
def test_cache_warning_suppressed():
1073+
x = pt.vector("x", shape=(5,), dtype="float64")
1074+
out = pt.psi(x) * 2
1075+
fn = function([x], out, mode="NUMBA")
1076+
1077+
x_test = np.random.uniform(size=5)
1078+
np.testing.assert_allclose(fn(x_test), scipy.special.psi(x_test) * 2)

0 commit comments

Comments
 (0)