Skip to content

Commit 4730d0c

Browse files
authored
Fix numba impl of empty DimShuffle (#218)
Empty DimShuffles would return a scalar instead of an array, which can then lead to errors in ops that expect an array.
1 parent 8606498 commit 4730d0c

File tree

2 files changed

+9
-1
lines changed

2 files changed

+9
-1
lines changed

pytensor/link/numba/dispatch/elemwise.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -841,7 +841,7 @@ def dimshuffle_inner(x, shuffle):
841841

842842
@numba_basic.numba_njit
843843
def dimshuffle_inner(x, shuffle):
844-
return x.item()
844+
return np.reshape(x, ())
845845

846846
# Without the following wrapper function we would see this error:
847847
# E No implementation of function Function(<built-in function getitem>) found for signature:

tests/link/numba/test_elemwise.py

+8
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,14 @@ def test_Dimshuffle(v, new_order):
210210
)
211211

212212

213+
def test_Dimshuffle_returns_array():
214+
x = at.vector("x", shape=(1,))
215+
y = 2 * at_elemwise.DimShuffle([True], [])(x)
216+
func = pytensor.function([x], y, mode="NUMBA")
217+
out = func(np.zeros(1, dtype=config.floatX))
218+
assert out.ndim == 0
219+
220+
213221
@pytest.mark.parametrize(
214222
"careduce_fn, axis, v",
215223
[

0 commit comments

Comments
 (0)