
BUG: Scan inner graphs are not optimized in NUMBA / JAX backends #20

Closed
@ricardoV94

Description


Reproducible code example:

import aesara
aesara.config.mode = "NUMBA"  # The bug shows up here; with the default mode the inner graph is optimized
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")
out = at.log(x)
op = OpFromGraph([x], [out], inline=True)

xs = at.vector("xs")
seq, _ = aesara.scan(
    fn=lambda x: op(x),
    sequences=[xs],
)

f = aesara.function([xs], seq)
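aesara.dprint(f)

Output (the OpFromGraph{inline=True} is still present in the Scan inner graph, i.e. it was never inlined):
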
for{cpu,scan_fn} [id A] 5
 |Shape_i{0} [id B] 0
 | |xs [id C]
 |Subtensor{int64:int64:int8} [id D] 4
 | |xs [id C]
 | |ScalarFromTensor [id E] 3
 | | |Elemwise{Composite{Switch(LE(i0, i1), i1, i2)}} [id F] 2
 | |   |Shape_i{0} [id B] 0
 | |   |TensorConstant{0} [id G]
 | |   |TensorConstant{0} [id H]
 | |ScalarFromTensor [id I] 1
 | | |Shape_i{0} [id B] 0
 | |ScalarConstant{1} [id J]
 |Shape_i{0} [id B] 0

Inner graphs:

for{cpu,scan_fn} [id A]
 >OpFromGraph{inline=True} [id K]
 > |*0-<TensorType(float64, ())> [id L] -> [id D]

OpFromGraph{inline=True} [id K]
 >Elemwise{log,no_inplace} [id M]
 > |*0-<TensorType(float64, ())> [id L]
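To illustrate the symptom, here is a toy model in plain Python of how an `OpFromGraph{inline=True}` can survive compilation. If a rewrite pass never descends into the Scan's inner graph, the inline rewrite never sees the node. Everything here (`Node`, `substitute`, `inline_ofgs`, `ops_in`) is hypothetical and is not Aesara's actual rewrite machinery. It is only a sketch of the difference between rewriting and not rewriting an inner graph.

```python
from dataclasses import dataclass, field, replace
from typing import List, Optional


@dataclass
class Node:
    """Toy graph node (hypothetical; not Aesara's Apply/Variable classes)."""
    op: str                         # e.g. "xs", "input", "log", "ofg", "scan"
    inputs: List["Node"] = field(default_factory=list)
    inner: Optional["Node"] = None  # inner-graph output for "ofg" and "scan"
    index: int = 0                  # only used by "input" placeholders


def substitute(node: Node, args: List[Node]) -> Node:
    """Replace "input" placeholders with outer arguments (nested inner graphs untouched)."""
    if node.op == "input":
        return args[node.index]
    return replace(node, inputs=[substitute(i, args) for i in node.inputs])


def inline_ofgs(node: Node, recurse_inner: bool) -> Node:
    """Inline every reachable "ofg" node; optionally descend into "scan" bodies."""
    inputs = [inline_ofgs(i, recurse_inner) for i in node.inputs]
    if node.op == "ofg":
        # Rewrite the OpFromGraph body first, then splice in the outer inputs.
        body = inline_ofgs(node.inner, recurse_inner)
        return substitute(body, inputs)
    inner = node.inner
    if node.op == "scan" and recurse_inner:
        # Toy analogue of rewriting the Scan inner graph, which this issue
        # reports is skipped in NUMBA / JAX mode.
        inner = inline_ofgs(inner, recurse_inner)
    return replace(node, inputs=inputs, inner=inner)


def ops_in(node: Node) -> set:
    """Collect op names in a graph, including nested inner graphs."""
    found = {node.op}
    for i in node.inputs:
        found |= ops_in(i)
    if node.inner is not None:
        found |= ops_in(node.inner)
    return found


# Toy version of the reproducer: scan over xs whose body is log(x) wrapped in an OpFromGraph.
ofg = Node("ofg", inputs=[Node("input", index=0)],
           inner=Node("log", inputs=[Node("input", index=0)]))
scan = Node("scan", inputs=[Node("xs")], inner=ofg)

print(sorted(ops_in(inline_ofgs(scan, recurse_inner=False))))
# ['input', 'log', 'ofg', 'scan', 'xs']
print(sorted(ops_in(inline_ofgs(scan, recurse_inner=True))))
# ['input', 'log', 'scan', 'xs']
```

In this toy, the `"ofg"` node only disappears when the pass recurses into the Scan body. That mirrors the dprint output above, where the inlinable `OpFromGraph` is still present inside `for{cpu,scan_fn}`.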
