Rewrite expand_dims implied in vector indices #1138

Open
@ricardoV94

Description

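The model behind #1132 cannot run in Numba's non-object mode because of an implicit `expand_dims` hidden in the vector indexing. Indexing with vectors that carry a trailing broadcastable dimension (`b` below) falls back to object mode, while the equivalent graph that indexes with plain vectors and expands the result afterwards (`c`) does not: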

```python
import pytensor
import pytensor.tensor as pt
import numpy as np

a = pt.tensor(shape=(None, None))
# Vector indices with a trailing broadcastable dim (implied expand_dims); output shape (10, 1)
b = a[pt.arange(10)[:, None], pt.arange(10)[:, None]]
# Equivalent graph: index with plain vectors, then expand_dims the result
c = a[pt.arange(10), pt.arange(10)][:, None]

# Issues UserWarning: Numba will use object mode to run AdvancedSubtensor's perform method
fn_b = pytensor.function([a], b, mode="NUMBA")

# Runs in non-object mode
fn_c = pytensor.function([a], c, mode="NUMBA")

test_a = pt.random.normal(size=(10, 10)).eval()
np.testing.assert_allclose(fn_b(test_a), fn_c(test_a))
```
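The requested rewrite rests on a NumPy identity: a length-1 axis shared by every integer index array can be squeezed out of the indices and restored on the result with `expand_dims`. The following is a minimal NumPy sketch of that equivalence, not the PyTensor rewrite itself. It assumes only integer index arrays applied to the leading axes of `x` (no slices, `None`, or boolean masks among the indices), and `index_with_squeezed_indices` is a hypothetical helper name.

```python
import numpy as np


def index_with_squeezed_indices(x, *idxs):
    # Hypothetical helper (not a PyTensor API). Assumes every entry of idxs is
    # an integer index array applied to the leading axes of x, with no slices,
    # None, or boolean masks mixed in.
    idxs = [np.asarray(i) for i in idxs]
    bcast_shape = np.broadcast_shapes(*(i.shape for i in idxs))
    ndim = len(bcast_shape)

    # Left-pad every index array with length-1 axes so all share the same
    # ndim, following NumPy's broadcasting rules.
    padded = [i.reshape((1,) * (ndim - i.ndim) + i.shape) for i in idxs]

    # Axes that are length 1 in every index array: the "implied expand_dims".
    expand_axes = tuple(
        ax for ax in range(ndim) if all(p.shape[ax] == 1 for p in padded)
    )
    if not expand_axes:
        return x[tuple(idxs)]

    # Index with the squeezed, lower-dimensional index arrays...
    squeezed = tuple(np.squeeze(p, axis=expand_axes) for p in padded)
    out = x[squeezed]
    # ...then restore the dropped axes. With advanced indices on the leading
    # axes only, the indexed dims come first in the result, so the axis
    # positions carry over unchanged.
    return np.expand_dims(out, axis=expand_axes)
```

Under these assumptions, `index_with_squeezed_indices(x, r[:, None], r[:, None])` computes `x[r, r][:, None]`, mirroring the `c` graph in the reproduction rather than the `b` graph.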