
Gradient of scan fails when it involves a shared variable #555

Open
@jessegrabowski


Before

Currently, this graph has valid gradients with respect to mu and sigma:

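```python
import pytensor
import pytensor.tensor as pt

mu = pt.dscalar('mu')
sigma = pt.dscalar('sigma')

# Parameter-free noise; z is a deterministic function of mu, sigma, and epsilon
epsilon = pt.random.normal(0, 1)
z = mu + sigma * epsilon

pt.grad(z, sigma).eval({mu: 1, sigma: 1})
# Out: Random draw from a N(0, 1)
```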

But this graph does not:

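```python
import numpy as np

x0 = pt.dscalar('x0')
rng = pytensor.shared(np.random.default_rng())

def step(x, mu, sigma, rng):
    # A RandomVariable node outputs (next_rng, draw)
    next_rng, epsilon = pt.random.normal(0, 1, rng=rng).owner.outputs
    next_x = x + mu + sigma * epsilon
    return next_x, {rng: next_rng}

traj, updates = pytensor.scan(step, outputs_info=[x0], non_sequences=[mu, sigma, rng], n_steps=10)
pt.grad(traj[-1], sigma).eval({mu: 1, sigma: 1, x0: 0})
# Out: Error, graph depends on a shared variable
```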

After

When the "reparameterization trick" is used, I would expect stochastic gradients to be computable for scan graphs as well.

Context for the issue:

The "reparameterization trick" is well known in the machine learning literature as a way to get stochastic gradients from graphs with sampling operations: the sample is rewritten as a deterministic function of the parameters and a parameter-free noise draw. It seems like we already support this, because this graph can be differentiated:

```python
epsilon = pt.random.normal(0, 1)
z = mu + sigma * epsilon

pt.grad(z, sigma).eval({mu: 1, sigma: 1})
```

But this graph cannot:

```python
z = pt.random.normal(mu, sigma)
pt.grad(z, sigma).eval({mu: 1, sigma: 1})
```

I'm not sure whether the fact that even the "good" version breaks down inside scan is a bug, a missing feature, or expected behavior. Consider the recurrence:

$$x_{t+1} = x_t + \mu + \sigma \epsilon_t$$
with $x_0$ given. Unrolling two steps, it seems like:

$$\frac{\partial x_2}{\partial \sigma} = \frac{\partial}{\partial \sigma}\left(x_0 + \mu + \sigma \epsilon_0 + \mu + \sigma \epsilon_1\right) = \epsilon_0 + \epsilon_1$$

I should get back the sum of the random draws for the sequence.
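This derivation can be checked numerically without PyTensor. The plain NumPy sketch below uses a hypothetical helper, `simulate`. It fixes the draws $\epsilon_0, \dots, \epsilon_{T-1}$, so the trajectory becomes a deterministic function of $\sigma$, and then compares a central finite difference of $x_T$ with the pathwise result $\partial x_T / \partial \sigma = \sum_t \epsilon_t$. Reusing the same draws for both evaluations (common random numbers) is what makes the finite difference meaningful.

```python
import numpy as np

def simulate(x0, mu, sigma, eps):
    # Hypothetical helper: iterate x_{t+1} = x_t + mu + sigma * eps_t for fixed draws
    x = x0
    for e in eps:
        x = x + mu + sigma * e
    return x

rng = np.random.default_rng(0)
eps = rng.standard_normal(10)  # eps_0, ..., eps_9, held fixed below

x0, mu, sigma, h = 0.0, 1.0, 1.0, 1e-6
fd = (simulate(x0, mu, sigma + h, eps) - simulate(x0, mu, sigma - h, eps)) / (2 * h)
pathwise = eps.sum()
print(fd, pathwise)
```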

Context: I'm trying to use pytensor to compute Greeks for options, which involves differentiating functions of sampled trajectories with respect to model parameters.
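For the Greeks use case, the same reparameterized gradient gives the usual pathwise Monte Carlo estimator. As a minimal sketch, assume the random walk above models the underlying with zero drift, no discounting, and a call payoff $\max(x_T - K, 0)$. The strike `K`, the path count, and the name `pathwise_vega` are illustrative choices and do not come from the issue. Differentiating the payoff path by path gives $\mathbf{1}\{x_T > K\} \sum_t \epsilon_t$, and averaging this over paths estimates the sensitivity of the expected payoff to $\sigma$.

```python
import numpy as np

def pathwise_vega(x0, mu, sigma, K, n_steps, n_paths, seed=0):
    # Illustrative sketch: pathwise derivative of E[max(x_T - K, 0)] w.r.t. sigma
    # for the walk x_{t+1} = x_t + mu + sigma * eps_t (no discounting assumed)
    rng = np.random.default_rng(seed)
    eps = rng.standard_normal((n_paths, n_steps))
    s = eps.sum(axis=1)                 # d x_T / d sigma on each path
    x_T = x0 + n_steps * mu + sigma * s
    payoff_grad = (x_T > K) * s         # chain rule through max(x_T - K, 0)
    return payoff_grad.mean()

vega = pathwise_vega(x0=0.0, mu=0.0, sigma=1.0, K=0.0, n_steps=10, n_paths=200_000)
# Reference: x_T ~ N(0, sigma^2 * T), so the exact value is sqrt(T / (2 * pi)) for K = 0
print(vega, np.sqrt(10 / (2 * np.pi)))
```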
