Disable runtime broadcasting in indexing operations #1348

Open
@ricardoV94

Description

We are inconsistent about which Ops allow runtime broadcasting and which don't.

import pytensor.tensor as pt

x = pt.vector("x", shape=(None,))  # Not known to have length 1 at runtime
out = pt.alloc(x, 3, 5)
try:
    out.eval({x: [1]})
except Exception as e:
    print(str(e).splitlines()[0])
# Runtime broadcasting not allowed. The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. If broadcasting was intended, use `specify_broadcastable` on the relevant input.

out = pt.zeros((10, 10))[[5, 6, 7], [0, 1, 2]].inc(x)
try:
    out.eval({x: [1]})
except Exception as e:
    print(str(e).splitlines()[0])
else:
    print("Did not raise")
# Did not raise
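
To make the two behaviours above concrete, here is a minimal pure-Python sketch of the kind of check that would reject runtime broadcasting. The function name `check_no_runtime_broadcast` and its signature are hypothetical, not PyTensor API; it only illustrates the rule that a dimension may broadcast when its static shape is already known to be 1, and not when the length 1 only shows up at runtime.

```python
def check_no_runtime_broadcast(static_shape, runtime_shape, target_shape):
    """Raise if a dimension would broadcast without being statically length 1.

    Hypothetical helper for illustration only (not part of PyTensor).
    `static_shape` uses None for dimensions whose length is unknown at
    graph-construction time. All three shapes are assumed to have the
    same number of dimensions.
    """
    for axis, (static, actual, target) in enumerate(
        zip(static_shape, runtime_shape, target_shape)
    ):
        if actual == target:
            continue  # lengths agree, no broadcasting needed
        if actual == 1 and static == 1:
            continue  # broadcasting was declared statically, so it is allowed
        if actual == 1:
            raise ValueError(f"Runtime broadcasting not allowed on axis {axis}")
        raise ValueError(f"Shape mismatch on axis {axis}: {actual} vs {target}")


# x declared with shape=(None,), fed a length-1 value, used where length 3 is needed
try:
    check_no_runtime_broadcast((None,), (1,), (3,))
except ValueError as e:
    print(e)

# Same value, but the dimension was declared broadcastable: accepted
check_no_runtime_broadcast((1,), (1,), (3,))
```

Under this rule the indexing example would fail the same way the `Alloc` example does, unless `x` were marked broadcastable beforehand.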

Note that wherever we allow runtime broadcasting, the gradient with respect to the broadcasted input is wrong, since we never implemented a mechanism to reduce over runtime-broadcasted dimensions.

print(pt.grad(out.sum(), x).eval({x: [1]}).shape)  # (3,), although x has shape (1,)
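
For reference, the missing reduction can be sketched in NumPy. This is an illustration of what a correct gradient would have to do, not how PyTensor implements gradients; the helper name `unbroadcast` is an assumption made for this example. Since `x` is added at three positions, `out.sum()` depends on `x[0]` three times, so the gradient should be `[3.]` with shape `(1,)`.

```python
import numpy as np


def unbroadcast(grad, input_shape):
    """Sum `grad` down to `input_shape`, undoing NumPy-style broadcasting.

    Hypothetical helper for illustration only (not part of PyTensor).
    """
    grad = np.asarray(grad)
    # Sum away leading dimensions that broadcasting prepended.
    extra = grad.ndim - len(input_shape)
    if extra > 0:
        grad = grad.sum(axis=tuple(range(extra)))
    # Sum over axes where the input had length 1 but the gradient does not.
    axes = tuple(
        i for i, n in enumerate(input_shape) if n == 1 and grad.shape[i] != 1
    )
    if axes:
        grad = grad.sum(axis=axes, keepdims=True)
    return grad


x = np.array([1.0])

# Gradient of out.sum() with respect to the broadcasted increment:
# one entry per indexed position.
naive_grad = np.ones(3)

# Gradient with respect to x itself, after reducing the broadcast axis.
fixed_grad = unbroadcast(naive_grad, x.shape)
print(naive_grad.shape, fixed_grad.shape)
```

Disallowing runtime broadcasting avoids the need for this step, because the reduction axes are then always known statically.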
