Description
We have a few derivatives that are implemented as iterative power series approximations:
- `pytensor/pytensor/scalar/math.py`, line 678 at `2ebfbf1`
- `pytensor/pytensor/scalar/math.py`, line 754 at `2ebfbf1`
- `pytensor/pytensor/scalar/math.py`, line 1298 at `2ebfbf1`

And this one in the near future: aesara-devs/aesara#1288
These are currently implemented in Python only. We can't use Scan for them, because Elemwise requires the gradients to be composed exclusively of other Elemwise Ops, so that the gradient graph can be safely vectorized by simply passing in tensor inputs. See aesara-devs/aesara#512, aesara-devs/aesara#1178, aesara-devs/aesara#514.
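For a sense of what these Python-only loops look like, here is a minimal sketch of the pattern. It evaluates `exp(x)` from its power series; this is an illustrative series, not the actual ones used in `math.py`:

```python
import math

def exp_series(x, tol=1e-12, max_iter=10_000):
    """Accumulate power-series terms until they drop below `tol`.

    The same control-flow shape (a scalar carry updated in a Python
    loop until convergence) as the gradient implementations above.
    """
    term = 1.0  # x**0 / 0!
    total = term
    for n in range(1, max_iter):
        term *= x / n  # next term: x**n / n!
        total += term
        if abs(term) < tol:
            break
    return total

assert abs(exp_series(1.0) - math.e) < 1e-10
```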
If we had a scalar Scan Op, whose inputs and outputs are all scalars, we could wrap it in an Elemwise and use it in the gradients of such Ops.
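As a rough sketch of the intended semantics (all names here, such as `scalar_scan`, are hypothetical, not an existing API): the Op would thread a tuple of scalar carries through a step function, and because every input and output is scalar, broadcasting it over arrays is exactly the Elemwise pattern, emulated below with `np.vectorize`:

```python
import numpy as np

def scalar_scan(step, init, n_steps):
    """Hypothetical semantics: thread a tuple of scalar carries
    through `step` -- no sequences, no taps, no shared variables."""
    carry = init
    for _ in range(n_steps):
        carry = step(*carry)
    return carry

def geometric_sum(x):
    # Accumulate sum_{n<10} x**n with a (partial_sum, power) carry.
    total, _ = scalar_scan(
        lambda total, power: (total + power, power * x),
        init=(0.0, 1.0),
        n_steps=10,
    )
    return total

# All inputs/outputs are scalars, so vectorizing is just an
# Elemwise-style broadcast -- emulated here with np.vectorize.
print(np.vectorize(geometric_sum)(np.array([0.5, 1.0, 2.0])))
```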
This would allow us to have a single implementation for our Python/Numba/JAX backends without having to manually rewrite the same code for each backend.
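To illustrate why a single scalar definition could serve all three backends, the same step function can be compiled by Numba as-is and threaded through `lax.fori_loop` in JAX. This is only a sketch of the idea under the assumption that `numba` and `jax` are installed; it is not PyTensor's actual backend dispatch:

```python
import numba
from jax import lax

# One scalar step, written once in plain Python.
def step(total, power, x):
    return total + power, power * x

# Python backend: an ordinary loop.
def run_py(x, n_steps=10):
    total, power = 0.0, 1.0
    for _ in range(n_steps):
        total, power = step(total, power, x)
    return total

# Numba backend: jit the step, then the same loop around it.
step_nb = numba.njit(step)

@numba.njit
def run_nb(x, n_steps=10):
    total, power = 0.0, 1.0
    for _ in range(n_steps):
        total, power = step_nb(total, power, x)
    return total

# JAX backend: the same step threaded through lax.fori_loop.
def run_jax(x, n_steps=10):
    def body(i, carry):
        return step(carry[0], carry[1], x)
    total, _ = lax.fori_loop(0, n_steps, body, (0.0, 1.0))
    return total
```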
It may also be easier to optimize than the general-purpose Scan, so it could be used internally in other scenarios for better performance. For instance, this Op wouldn't have to worry about SharedVariables, RNGs, taps (better not to start down that path; modeling carryover explicitly avoids one of the biggest "unmanageable" complexities in Scan), and so on.
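As a small illustration of modeling carryover explicitly instead of through taps (hypothetical code, not a Scan API): a recurrence over `y[t-1]` and `y[t-2]` just packs both lagged values into the carry:

```python
def fib_step(prev, prev2):
    # Carry both lagged values forward; no tap machinery needed.
    return prev + prev2, prev

def fib(n):
    carry = (1.0, 0.0)
    for _ in range(n):
        carry = fib_step(*carry)
    return carry[0]

assert fib(5) == 8.0  # 1, 1, 2, 3, 5, 8
```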
This fits with the idea of creating multiple specialized Scan Ops, instead of trying to refactor the existing general-purpose one. Since the latter is the approach being explored by Aesara, we may cover more ground across libraries by trying this new approach.