Description
#232 left this case out.
IIRC, JAX can only JIT while loops with a fixed output size, which corresponds to PyTensor while Scans where only the last state is used.
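For illustration, here is a minimal sketch of the kind of loop JAX can JIT, assuming the JAX backend would lower such a Scan to `jax.lax.while_loop`. The function `double_until_exceeds` is hypothetical and not part of PyTensor. The trip count depends on the data, but the loop carry keeps the same shape and dtype on every iteration, and only the final carry is returned. That is the analogue of a PyTensor while Scan whose only used output is the last state.

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.jit
def double_until_exceeds(x0, threshold):
    """Hypothetical example: double x until it exceeds threshold.

    Only the final state is returned. The carry (x, n_steps) has a fixed
    shape and dtype across iterations, as lax.while_loop requires, so the
    output shape is static even though the number of iterations is not.
    """

    def cond_fun(carry):
        x, _ = carry
        return x <= threshold

    def body_fun(carry):
        x, n = carry
        return x * 2.0, n + 1

    # lax.while_loop returns only the final carry, never the intermediate states.
    x_final, n_steps = lax.while_loop(cond_fun, body_fun, (x0, jnp.int32(0)))
    return x_final, n_steps


# 3 -> 6 -> 12 -> 24 -> 48 -> 96 -> 192, i.e. 6 iterations.
x_final, n_steps = double_until_exceeds(jnp.float32(3.0), jnp.float32(100.0))
print(x_final, n_steps)
```

Returning every intermediate `x` instead would need an output whose length depends on the input values, which `jax.jit` cannot express with static shapes.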