Make VI compatible with JAX backend #7103


Open · wants to merge 6 commits into base: main
27 changes: 2 additions & 25 deletions pymc/pytensorf.py
@@ -22,7 +22,6 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor import scalar
from pytensor.compile import Function, Mode, get_mode
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad
@@ -39,7 +38,7 @@
from pytensor.graph.op import Op
from pytensor.scalar.basic import Cast
from pytensor.scan.op import Scan
from pytensor.tensor.basic import _as_tensor_variable
from pytensor.tensor.basic import _as_tensor_variable, tensor_copy
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.op import RandomVariable
from pytensor.tensor.random.type import RandomType
@@ -412,29 +411,7 @@ def hessian_diag(f, vars=None, negate_output=True):
return empty_gradient


class IdentityOp(scalar.UnaryScalarOp):
@staticmethod
def st_impl(x):
return x

def impl(self, x):
return x

def grad(self, inp, grads):
return grads

def c_code(self, node, name, inp, out, sub):
return f"{out[0]} = {inp[0]};"

def __eq__(self, other):
return isinstance(self, type(other))

def __hash__(self):
return hash(type(self))


scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
identity = Elemwise(scalar_identity, name="identity")
identity = tensor_copy
Member:

Nitpick: just import it directly in the VI module, no need to define it in pytensorf?

Member Author:

It might be used by someone else, I assume.

ricardoV94 (Member) on Jan 17, 2024:

I don't think so, but even if we keep it we should add a deprecation warning.

Member:

Just do from pytensor... import tensor_copy as identity?
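
A minimal sketch of how those two suggestions could combine (illustrative only, not part of this diff; the module-level `__getattr__` shim and the warning wording are assumptions):

```python
# pymc/pytensorf.py (sketch): keep the old name importable, but warn that it is
# now just pytensor's tensor_copy.
import warnings

from pytensor.tensor.basic import tensor_copy


def __getattr__(name):
    # PEP 562 module-level __getattr__: fires only when `identity` is looked up.
    if name == "identity":
        warnings.warn(
            "pymc.pytensorf.identity is deprecated; use "
            "pytensor.tensor.basic.tensor_copy instead.",
            FutureWarning,
        )
        return tensor_copy
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```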



def make_shared_replacements(point, vars, model):
17 changes: 13 additions & 4 deletions pymc/variational/approximations.py
@@ -93,8 +93,8 @@ def create_shared_params(self, start=None, start_sigma=None):
rho = rho1

return {
"mu": pytensor.shared(pm.floatX(start), "mu"),
"rho": pytensor.shared(pm.floatX(rho), "rho"),
"mu": pytensor.shared(pm.floatX(start), "mu", shape=start.shape),
"rho": pytensor.shared(pm.floatX(rho), "rho", shape=rho.shape),
}

@node_property
@@ -137,7 +137,10 @@ def create_shared_params(self, start=None):
start = self._prepare_start(start)
n = self.ddim
L_tril = np.eye(n)[np.tril_indices(n)].astype(pytensor.config.floatX)
return {"mu": pytensor.shared(start, "mu"), "L_tril": pytensor.shared(L_tril, "L_tril")}
return {
"mu": pytensor.shared(start, "mu", shape=start.shape),
"L_tril": pytensor.shared(L_tril, "L_tril", shape=L_tril.shape),
}

@node_property
def L(self):
@@ -225,7 +228,13 @@ def create_shared_params(self, trace=None, size=None, jitter=1, start=None):
for j in range(len(trace)):
histogram[i] = DictToArrayBijection.map(trace.point(j, t)).data
i += 1
return dict(histogram=pytensor.shared(pm.floatX(histogram), "histogram"))
return dict(
histogram=pytensor.shared(
pm.floatX(histogram),
"histogram",
shape=histogram.shape,
)
)

def _check_trace(self):
trace = self._kwargs.get("trace", None)
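
The shape= arguments added above give each shared parameter a static type shape, which a strict backend like JAX can rely on instead of treating every dimension as unknown. A minimal sketch of the effect (illustrative only; the printed values are what pytensor typically reports for a length-3 array):

```python
import numpy as np
import pytensor

start = np.zeros(3, dtype=pytensor.config.floatX)

mu_dynamic = pytensor.shared(start, "mu")                    # no static shape
mu_static = pytensor.shared(start, "mu", shape=start.shape)  # as in this PR

print(mu_dynamic.type.shape)  # typically (None,): length not fixed in the type
print(mu_static.type.shape)   # (3,)
```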
40 changes: 37 additions & 3 deletions tests/sampling/test_jax.py
@@ -86,7 +86,8 @@ def test_jax_PosDefMatrix():
[
pytest.param(1),
pytest.param(
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
2,
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
),
],
)
@@ -265,7 +266,11 @@ def test_get_jaxified_logp():
@pytest.fixture(scope="module")
def model_test_idata_kwargs() -> pm.Model:
with pm.Model(
coords={"x_coord": ["a", "b"], "x_coord2": [1, 2], "z_coord": ["apple", "banana", "orange"]}
coords={
"x_coord": ["a", "b"],
"x_coord2": [1, 2],
"z_coord": ["apple", "banana", "orange"],
}
) as m:
x = pm.Normal("x", shape=(2,), dims=["x_coord"])
_ = pm.Normal("y", x, observed=[0, 0])
@@ -372,7 +377,8 @@ def test_get_batched_jittered_initial_points():
[
pytest.param(1),
pytest.param(
2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
2,
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
),
],
)
@@ -536,3 +542,31 @@ def test_dirichlet_multinomial_dims():
frozen_dm = freeze_dims_and_data(m)["dm"]
dm_draws = pm.draw(frozen_dm, mode="JAX")
np.testing.assert_equal(dm_draws, np.eye(3) * 5)


@pytest.mark.parametrize("method", ["advi", "fullrank_advi"])
def test_vi_sampling_jax(method):
with pm.Model() as model:
x = pm.Normal("x")
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
Member:

To be consistent with pm.sample and the nuts_sampler= arg, should we have a dedicated argument for the VI backend instead of kwargs?

Member:

I vote yes, this API looks super weird.

Member:

What looks weird? This is the compilation mode; it would be exactly the same if you wanted to use Numba or JAX for the PyMC NUTS sampler or for prior/posterior predictive.

The only thing I would change is the name fn_kwargs, which I think is called compile_kwargs in those other functions.

Member:

Wouldn't this be what the user would have to do if they wanted to run VI on JAX?

Member:

I don't understand the question; this PR just makes minor tweaks so the PyMC VI module can compile to JAX. It's not linking to specific JAX VI libraries.

Member:

We used this for sample_posterior_predictive for projects just last week, as we were sampling new variables that had heavy matmuls; it went down from hours to minutes.

Great idea, should definitely add it there too.

pm.sample is still useful as you can sample discrete variables with JAX this way.

That makes sense, I'm not opposed to adding it there. Maybe we can add a warning that the sampler is still running Python and they likely will want to use nuts_sampler.

ricardoV94 (Member) on Apr 1, 2024:

This is still doing Python loops; it's exactly the same argument you need for pm.sample.

It's different from linking to a JAX VI library, which is what would be equivalent to the nuts_sampler kwarg that Chris mentioned in the first comment.

Member:

> This is still doing Python loops; it's exactly the same argument you need for pm.sample.

Oh, I somehow assumed that VI was implemented mostly in PyTensor?

Member Author:

As for this, I'd prefer to keep this PR focused on backend compatibility and address possible API changes later in a new issue + PR. Agreed that there is an inconsistency we need to resolve, but handling it here would only delay pushing to main at least some working solution, which has already gone through many issues.

Member:

Agreed @ferrine. My only suggestion is to switch fn_kwargs to compile_kwargs, which we use in the other sample methods.
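
For reference, a minimal end-to-end usage under the API as it currently stands in this PR (mirroring the test above; the iteration and draw counts are arbitrary, and fn_kwargs is the keyword whose rename to compile_kwargs is being discussed):

```python
import pymc as pm

with pm.Model():
    pm.Normal("x")
    # fn_kwargs is passed through to function compilation, selecting the JAX mode.
    approx = pm.fit(1000, method="advi", fn_kwargs=dict(mode="JAX"))

idata = approx.sample(500)  # draw from the fitted approximation
```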



@pytest.mark.xfail(
reason="""
This error happens during the equilibrium rewriter. Probably one of the routines in SVGD is problematic.

TypeError: The broadcast pattern of the output of scan
(Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
(Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
ricardoV94 (Member) on Jul 9, 2024:

This actually is something wrong if the number of dimensions of a recurring output differs from that of the initial state. The difference between None and 1 is more annoying, but this one looks like an error.

1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
while it is still variable in the output, or vice-verca. You have to make them consistent,
e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.

Instead of fixing this error, it makes sense to rework the internals of the variational module to use pytensor's vectorize instead of scan.
"""
)
def test_vi_sampling_jax_svgd():
with pm.Model():
x = pm.Normal("x")
pm.fit(10, method="svgd", fn_kwargs=dict(mode="JAX"))
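
For context on the quoted error (illustrative only, not a fix for the SVGD failure): the helpers the message points to adjust whether an axis is statically known to have length 1.

```python
import pytensor.tensor as pt

x = pt.matrix("x")                      # static shape (None, None)
x_col = pt.specify_broadcastable(x, 1)  # pin axis 1 as broadcastable: (None, 1)
x_free = pt.unbroadcast(x_col, 1)       # relax it back to (None, None)

print(x.type.shape, x_col.type.shape, x_free.type.shape)
```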