Consider implementing a scalar Scan Op #83

Closed
@ricardoV94

Description

We have a couple of derivatives that are implemented as iterative power series approximations:

class GammaIncDer(BinaryScalarOp):

class GammaIncCDer(BinaryScalarOp):

class BetaIncDer(ScalarOp):

And, in the near future, this one as well: aesara-devs/aesara#1288. (The sketch below illustrates the general pattern these implementations follow.)
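For context, these implementations all share the same shape: a plain Python loop that accumulates series terms until the increment drops below a tolerance. The following is only an illustrative sketch of that pattern; it does not reproduce the actual derivative series of the Ops above, but instead evaluates the regularized lower incomplete gamma function `P(k, x)` via its standard power series, just to show the kind of scalar loop involved:

```python
import math

def gammainc_series(k: float, x: float, tol: float = 1e-12, max_iter: int = 10_000) -> float:
    """P(k, x) = x**k * exp(-x) / Gamma(k) * sum_{n >= 0} x**n / (k * (k+1) * ... * (k+n)).

    Illustrative stand-in for the iterative pattern used by GammaIncDer & co.;
    it is *not* the derivative series those Ops actually implement.
    """
    prefactor = math.exp(k * math.log(x) - x - math.lgamma(k))
    term = 1.0 / k          # n = 0 term of the series
    total = term
    denom = k
    for _ in range(max_iter):
        denom += 1.0
        term *= x / denom   # ratio between consecutive series terms
        total += term
        if abs(term) < tol * abs(total):
            break           # converged
    return prefactor * total

# print(gammainc_series(1.0, 1.0))  # ~0.63212, i.e. 1 - exp(-1)
```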

These are currently implemented only in Python. We can't use Scan for them, because Elemwise requires its gradients to be composed exclusively of other Elemwise Ops, so that they can be safely vectorized by just passing in tensor inputs. See aesara-devs/aesara#512, aesara-devs/aesara#1178, and aesara-devs/aesara#514.

If we had a scalar Scan, one that expects all of its inputs and outputs to be scalars, we could wrap it in an Elemwise and use it in the gradients of such Ops.
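To make the idea concrete, here is a minimal, library-free sketch (the names `scalar_scan` and `elemwise_scan` are hypothetical, not an existing PyTensor/Aesara API). The step function maps scalars to scalars; because the loop never touches anything but scalars, applying it elementwise over arrays is trivially safe, which is what wrapping it in an Elemwise would give us at the graph level:

```python
import numpy as np

def scalar_scan(step, init: float, n_steps: int) -> float:
    """Iterate a scalar -> scalar step function, carrying a single scalar state."""
    state = init
    for _ in range(n_steps):
        state = step(state)
    return state

# Example scalar step: one Newton iteration converging to sqrt(2).
newton_sqrt2 = lambda s: 0.5 * (s + 2.0 / s)

print(scalar_scan(newton_sqrt2, 1.0, 10))        # ~1.41421356

# The "Elemwise" view: the very same scalar loop, broadcast over an array.
elemwise_scan = np.vectorize(lambda s0: scalar_scan(newton_sqrt2, s0, 10))
print(elemwise_scan(np.array([1.0, 2.0, 3.0])))  # all ~1.41421356
```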

This would allow us to have a single implementation for our Python/Numba/JAX backends without having to manually rewrite the same code for each backend.

It may also be easier to optimize than the general-purpose Scan, so it could be used internally in other scenarios as well for better performance. For instance, this Op wouldn't have to worry about SharedVariables, RNGs, taps (better not to start down that path at all and instead model carryover explicitly, as sketched below, since taps are one of the biggest "unmanageable" complexities in Scan), and so on.
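A small, hypothetical illustration of what "model carryover explicitly" means: instead of letting the Op index into past outputs via taps like `y[t-2]`, the step function receives every value it needs as part of its (scalar) state and returns the updated state:

```python
# Fibonacci with two scalar carries instead of taps on past outputs.
def fib_step(prev: int, prev2: int) -> tuple[int, int]:
    return prev + prev2, prev   # (new value, value carried over for the next step)

state = (1, 0)
for _ in range(10):
    state = fib_step(*state)
print(state[0])  # 89
```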

This fits with the idea of creating multiple specialized scan Ops instead of trying to refactor the existing general-purpose one. Since the latter is the approach being explored by Aesara, we may cover more ground across libraries by trying this new approach here.
