Description
There are many problems in machine learning that require differentiating through random variables. This specifically came up in the context of pymc-devs/pymc#7799, but it was also implicated in #555.
Right now, `pt.grad` refuses to go through a random variable. That's probably the correct behavior. But we could have another function, `pt.stochastic_grad`, that would do it and return a stochastic gradient. It would also take an `n_samples` argument, since what we're actually computing is a sample-based Monte Carlo estimate of the gradients with respect to the parameters.
Most often, the RV would have to know a "reparameterization trick" to split its parameters from the source of randomness. The canonical example is the non-centered normal parameterization. Given a loss function $\mathcal{L}$, `pt.stochastic_grad` would compute the gradient of the expected loss with respect to the RV parameters $\theta$: $\nabla_\theta \mathbb{E}_x[\mathcal{L}(g(x, \theta))] = \mathbb{E}_x[\nabla_\theta \mathcal{L}(g(x, \theta))]$, where $x$ is a parameter-free source of randomness and $g$ maps it to a draw of the RV. The so-called reparameterization trick just does a non-centered parameterization, $g(x, \theta) = \mu + \sigma x$ with $x \sim \mathcal{N}(0, 1)$ and $\theta = (\mu, \sigma)$.
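And the (expected) sensitivity equations for the parameters of the normal are:

$$\frac{\partial}{\partial \mu} \mathbb{E}_x[\mathcal{L}(g(x, \theta))] = \mathbb{E}_x[\mathcal{L}'(g(x, \theta))], \qquad \frac{\partial}{\partial \sigma} \mathbb{E}_x[\mathcal{L}(g(x, \theta))] = \mathbb{E}_x[x \, \mathcal{L}'(g(x, \theta))],$$

since $\partial g / \partial \mu = 1$ and $\partial g / \partial \sigma = x$.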
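A minimal sketch of the reparameterized estimate for the normal, in plain NumPy rather than PyTensor: draw parameter-free noise $x$, push it through $g(x, \theta) = \mu + \sigma x$, and average the chain-rule terms over draws. The helper `reparam_grad_normal` is hypothetical, and the toy loss $\mathcal{L}(z) = z^2$ is chosen only because $\mathbb{E}[(\mu + \sigma x)^2] = \mu^2 + \sigma^2$, so the exact gradient $(2\mu, 2\sigma)$ is available for comparison.

```python
import numpy as np


def reparam_grad_normal(dloss, mu, sigma, n_samples, rng):
    """Reparameterization-trick estimate of the gradient of E[L(z)], z ~ N(mu, sigma).

    Hypothetical helper. dloss is the derivative L'(z) of the loss.
    Returns Monte Carlo estimates of (d/dmu, d/dsigma).
    """
    x = rng.standard_normal(n_samples)  # parameter-free source of randomness
    z = mu + sigma * x                  # non-centered draw g(x, theta)
    dz = dloss(z)                       # L'(g(x, theta)) for each draw
    grad_mu = np.mean(dz)               # dg/dmu = 1
    grad_sigma = np.mean(x * dz)        # dg/dsigma = x
    return grad_mu, grad_sigma


rng = np.random.default_rng(0)
mu, sigma = 1.5, 0.5
# Toy loss L(z) = z**2, so L'(z) = 2 z and the exact gradient is (2 mu, 2 sigma) = (3.0, 1.0).
g_mu, g_sigma = reparam_grad_normal(lambda z: 2 * z, mu, sigma, 100_000, rng)
print(g_mu, g_sigma)  # close to 3.0 and 1.0, up to Monte Carlo error
```

Here `n_samples` plays the role of the proposed `n_samples` argument: it trades compute for a smaller Monte Carlo standard error.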
It would be easy enough for a `normal_rv` to know this, and to supply these formulas when requested by the hypothetical `pt.stochastic_grad`.
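One way an RV could "know" its reparameterization is a per-distribution table that a hypothetical `stochastic_grad` consults. The sketch below is plain Python/NumPy, and every name in it (`Reparam`, `REPARAMETERIZATIONS`, `expected_grad`) is invented for illustration; it is not how PyTensor dispatches on RV ops. Distributions without an entry raise, which is where a fallback estimator could plug in.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Sequence

import numpy as np


@dataclass(frozen=True)
class Reparam:
    """Hypothetical record of how an RV splits parameters from noise: z = g(x, theta)."""

    draw_noise: Callable[[np.random.Generator, int], np.ndarray]  # x, independent of theta
    transform: Callable[..., np.ndarray]  # g(x, *theta)
    jacobian: Callable[..., Sequence[np.ndarray]]  # dg/dtheta_i, one array per parameter


# Hypothetical registry keyed by distribution name.
REPARAMETERIZATIONS: Dict[str, Reparam] = {
    "normal": Reparam(
        draw_noise=lambda rng, n: rng.standard_normal(n),
        transform=lambda x, mu, sigma: mu + sigma * x,
        jacobian=lambda x, mu, sigma: (np.ones_like(x), x),
    ),
}


def expected_grad(dist, params, dloss, n_samples, rng):
    """Estimate the gradient of E[L(z)] w.r.t. params using the registered reparameterization."""
    if dist not in REPARAMETERIZATIONS:
        raise NotImplementedError(f"no reparameterization registered for {dist!r}")
    rp = REPARAMETERIZATIONS[dist]
    x = rp.draw_noise(rng, n_samples)
    dz = dloss(rp.transform(x, *params))  # L'(g(x, theta)) per draw
    # Chain rule for each parameter, averaged over the draws.
    return [float(np.mean(dz * dg)) for dg in rp.jacobian(x, *params)]


rng = np.random.default_rng(1)
# Same toy loss as before: L(z) = z**2, exact gradient (3.0, 1.0) at mu=1.5, sigma=0.5.
grads = expected_grad("normal", (1.5, 0.5), lambda z: 2 * z, 100_000, rng)
print(grads)
```

A table like this also gives a natural place to record which of several candidate reparameterizations a distribution should use.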
I guess other RVs also have reparameterizations (beta, Dirichlet, gamma, ...?), but in some cases there are multiple options and it's not clear which one is best to use when. Some thought would have to be given to how to handle that.
When a reparameterization doesn't exist, there are other, higher-variance options to compute the expected gradients (the REINFORCE gradients, for example). We could offer these as a fallback.
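For comparison, a NumPy sketch of the score-function (REINFORCE) estimator for the same normal example. It uses $\nabla_\theta \mathbb{E}_z[\mathcal{L}(z)] = \mathbb{E}_z[\mathcal{L}(z)\, \nabla_\theta \log p(z \mid \theta)]$, so it needs only $\mathcal{L}$ itself and the score of the density, not $\mathcal{L}'$ or a reparameterization. The function names are hypothetical, and the toy loss $\mathcal{L}(z) = z^2$ is reused so the exact gradient is $(2\mu, 2\sigma)$. Comparing the per-draw standard deviations of the two estimators shows the higher variance mentioned above.

```python
import numpy as np


def score_function_terms_normal(loss, mu, sigma, n_samples, rng):
    """Per-draw REINFORCE terms for the gradient of E[L(z)], z ~ N(mu, sigma).

    Hypothetical helper. Returns an array of shape (n_samples, 2); its column means
    are the estimates of (d/dmu, d/dsigma).
    """
    z = rng.normal(mu, sigma, n_samples)  # ordinary draws, no reparameterization needed
    score_mu = (z - mu) / sigma**2  # d/dmu log N(z | mu, sigma)
    score_sigma = (z - mu) ** 2 / sigma**3 - 1.0 / sigma  # d/dsigma log N(z | mu, sigma)
    loss_vals = loss(z)
    return np.stack([loss_vals * score_mu, loss_vals * score_sigma], axis=1)


def reparam_terms_normal(dloss, mu, sigma, n_samples, rng):
    """Per-draw reparameterization-trick terms, for a variance comparison."""
    x = rng.standard_normal(n_samples)
    dz = dloss(mu + sigma * x)  # L'(mu + sigma x)
    return np.stack([dz, x * dz], axis=1)  # dg/dmu = 1, dg/dsigma = x


rng = np.random.default_rng(2)
mu, sigma, n = 1.5, 0.5, 200_000
sf = score_function_terms_normal(lambda z: z**2, mu, sigma, n, rng)
rp = reparam_terms_normal(lambda z: 2 * z, mu, sigma, n, rng)
print("REINFORCE mean:", sf.mean(axis=0), "per-draw std:", sf.std(axis=0))
print("reparam   mean:", rp.mean(axis=0), "per-draw std:", rp.std(axis=0))
```

Both column means approach $(3, 1)$, but for this example the REINFORCE terms are several times more spread out, so a fallback like this would need a larger `n_samples` (or variance reduction) to reach the same accuracy.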
Basically, this issue proposes this API and invites discussion on whether we want this type of feature, and how to do it if so. The `pt.stochastic_grad` function would be novel as far as I know. Other packages require that you explicitly generate samples in your computation graph. For example, `torch` offers `torch.distributions.Normal(mu, sigma).rsample((n_draws,))`, which generates samples using the reparameterization trick, so the standard `loss.backward()` works. Here the user can't "accidentally" trigger stochastic gradients (because you have to call `rsample` instead of `sample`).
I'm less familiar with how `numpyro` works, but I believe that something like `numpyro.sample("z", dist.Normal(mu, sigma))` automatically uses the reparameterization trick if it's available. It doesn't have a special idiom like `rsample` to mark when it will or won't be used.