Skip to content

A more low level implementation of vectorize in numba #92

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jan 4, 2023
8 changes: 4 additions & 4 deletions pytensor/link/numba/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def in_seq_empty_tuple(x, y):


def to_scalar(x):
raise NotImplementedError()
return np.asarray(x).item()


@numba.extending.overload(to_scalar)
Expand Down Expand Up @@ -543,7 +543,7 @@ def {fn_name}({", ".join(input_names)}):
{index_prologue}
{indices_creation_src}
{index_body}
return z
return np.asarray(z)
"""

return subtensor_def_src
Expand Down Expand Up @@ -665,7 +665,7 @@ def numba_funcify_Shape_i(op, **kwargs):

@numba_njit
def shape_i(x):
return np.shape(x)[i]
return np.asarray(np.shape(x)[i])

return shape_i

Expand Down Expand Up @@ -698,7 +698,7 @@ def numba_funcify_Reshape(op, **kwargs):

@numba_njit
def reshape(x, shape):
return x.item()
return np.asarray(x.item())

else:

Expand Down
Loading