Stochastic gradients in pytensor #1419

Closed
@jessegrabowski

Description

There are many problems in machine learning that require differentiating through random variables. This specifically came up in the context of pymc-devs/pymc#7799, but it was also implicated in #555.

Right now, pt.grad refuses to go through a random variable. That's probably the correct behavior. But we could have another function, pt.stochastic_grad, that would do it and return a stochastic gradient. It would also take an n_samples argument, since what we're actually computing is a sample-based Monte Carlo estimate of the gradient with respect to the parameters.
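A rough sketch of what calling it could look like (pt.stochastic_grad is the hypothetical function proposed here; the wrt keyword is just borrowed from pt.grad, and n_samples is the argument described above):

```python
import pytensor.tensor as pt

mu = pt.dscalar("mu")
sigma = pt.dscalar("sigma")

# A random variable sits directly in the loss graph
x = pt.random.normal(mu, sigma)
loss = (x - 2.0) ** 2

# Hypothetical API: Monte Carlo estimates of d E[loss] / d mu and d E[loss] / d sigma,
# averaged over n_samples reparameterized draws of x
g_mu, g_sigma = pt.stochastic_grad(loss, wrt=[mu, sigma], n_samples=256)
```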

Most often, the RV would have to know a "reparameterization trick" that splits its parameters from the source of randomness. The canonical example is the non-centered normal parameterization. Given a loss function $\mathcal{L}$ that depends on $x \sim N(\mu, \sigma)$, the proposed pt.stochastic_grad would compute the gradient of the expected loss with respect to the parameters $\theta$. The so-called reparameterization trick is just a non-centered parameterization, $x = g(z, \theta) = \mu + \sigma z, \quad z \sim N(0,1)$, which moves all dependence on $\theta$ out of the sampling step, so the gradient and the expectation can be exchanged, $\nabla_\theta \mathbb{E}_z [\mathcal{L}(g(z, \theta))] = \mathbb{E}_z [\nabla_\theta \mathcal{L}(g(z, \theta))]$, and estimated from draws of $z$:

$$ \approx \frac{1}{N} \sum_{i=1}^N \nabla_\theta \mathcal{L}(\mu + \sigma z_i) $$

And the (expected) sensitivity equations for the parameters of the normal are:

$$ \begin{align} \bar{\mu} &= \frac{1}{N} \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial x^{(i)}}, \\ \bar{\sigma} &= \frac{1}{N} \sum_{i=1}^N \frac{\partial \mathcal{L}}{\partial x^{(i)}} \cdot z^{(i)} \end{align} $$

It would be easy enough for a normal_rv to know this, and to supply these formulas when requested by the hypothetical pt.stochastic_grad.
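For comparison, the same gradients are already reachable today if the non-centered parameterization is written out by hand, so that pt.grad never has to differentiate through the RandomVariable itself; a minimal sketch (toy loss, arbitrary numbers):

```python
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.random.utils import RandomStream

mu = pt.dscalar("mu")
sigma = pt.dscalar("sigma")

srng = RandomStream(seed=123)
z = srng.normal(0.0, 1.0, size=(1000,))  # source of randomness, independent of mu/sigma
x = mu + sigma * z                        # non-centered ("reparameterized") draws

loss = pt.mean((x - 2.0) ** 2)            # toy loss
# Fine today: the gradient path never goes *through* the normal RV
g_mu, g_sigma = pt.grad(loss, [mu, sigma])

f = pytensor.function([mu, sigma], [g_mu, g_sigma])
print(f(0.5, 1.5))  # Monte Carlo estimates of 2*(mu - 2) and 2*sigma
```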

I guess other RVs also have reparameterizations (beta, dirichlet, gamma, ...?), but in some cases there are multiple options and it's not clear which one is best to use when. Some thought would have to be given to how to handle that.

When a reparameterization doesn't exist, there are other, higher-variance options to compute the expected gradients (the REINFORCE gradients, for example). We could offer these as a fallback.
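For reference, a minimal NumPy sketch of the score-function (REINFORCE) estimator for a normal RV, which needs no reparameterization but is typically much noisier than the pathwise estimate above (toy loss and numbers are just for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)


def reinforce_grads(mu, sigma, loss_fn, n_samples=100_000):
    """Score-function (REINFORCE) estimate of d E[loss] / d(mu, sigma) for x ~ N(mu, sigma)."""
    x = rng.normal(mu, sigma, size=n_samples)
    loss = loss_fn(x)
    # d log N(x | mu, sigma) / d mu  and  / d sigma
    dlogp_dmu = (x - mu) / sigma**2
    dlogp_dsigma = ((x - mu) ** 2 - sigma**2) / sigma**3
    # No baseline / control variate here, so the variance is large
    return np.mean(loss * dlogp_dmu), np.mean(loss * dlogp_dsigma)


# E[(x - 2)^2] with x ~ N(mu, sigma); true gradients are 2*(mu - 2) and 2*sigma
print(reinforce_grads(0.5, 1.5, lambda x: (x - 2.0) ** 2))
```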

Basically, this issue is proposing this API and inviting some discussion on whether we want this type of feature, and how to do it if so. The pt.stochastic_grad function would be novel as far as I know: other packages require that you explicitly generate samples in your computation graph. For example, torch offers torch.distributions.Normal(mu, sigma).rsample(n_draws), which generates samples using the reparameterization trick so that the standard loss.backward() works. There, the user can't "accidentally" trigger stochastic gradients, because you have to call rsample instead of sample.
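A minimal torch example of that pattern, for reference (numbers arbitrary):

```python
import torch

mu = torch.tensor(0.5, requires_grad=True)
sigma = torch.tensor(1.5, requires_grad=True)

# .rsample() draws via the reparameterization trick, so gradients flow back to mu and
# sigma; .sample() would detach the draws and loss.backward() would raise an error.
x = torch.distributions.Normal(mu, sigma).rsample((256,))
loss = ((x - 2.0) ** 2).mean()
loss.backward()

print(mu.grad, sigma.grad)  # Monte Carlo estimates of 2*(mu - 2) and 2*sigma
```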

I'm less familiar with how numpyro works, but I believe that something like numpyro.sample("z", dist.Normal(mu, sigma)) automatically implies the reparameterization trick when it's available; there's no special idiom like rsample to control whether or not it's used.
