Make VI compatible with JAX backend #7103
Changes from all commits: 581118a, 5544852, 44a75b0, 6474f9f, 5dfd869, 829c6cb
```diff
@@ -22,7 +22,6 @@
 import pytensor.tensor as pt
 import scipy.sparse as sps

-from pytensor import scalar
 from pytensor.compile import Function, Mode, get_mode
 from pytensor.compile.builders import OpFromGraph
 from pytensor.gradient import grad
@@ -39,7 +38,7 @@
 from pytensor.graph.op import Op
 from pytensor.scalar.basic import Cast
 from pytensor.scan.op import Scan
-from pytensor.tensor.basic import _as_tensor_variable
+from pytensor.tensor.basic import _as_tensor_variable, tensor_copy
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.type import RandomType
```
```diff
@@ -412,29 +411,7 @@ def hessian_diag(f, vars=None, negate_output=True):
     return empty_gradient


-class IdentityOp(scalar.UnaryScalarOp):
-    @staticmethod
-    def st_impl(x):
-        return x
-
-    def impl(self, x):
-        return x
-
-    def grad(self, inp, grads):
-        return grads
-
-    def c_code(self, node, name, inp, out, sub):
-        return f"{out[0]} = {inp[0]};"
-
-    def __eq__(self, other):
-        return isinstance(self, type(other))
-
-    def __hash__(self):
-        return hash(type(self))
-
-
-scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
-identity = Elemwise(scalar_identity, name="identity")
+identity = tensor_copy


 def make_shared_replacements(point, vars, model):
```

Review comment on `identity = tensor_copy`:

> Just do
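For context on the replacement above: `tensor_copy` is PyTensor's ready-made elementwise identity, so it covers what the hand-rolled `IdentityOp` provided (a pass-through forward computation and gradient) while staying compatible with non-C backends such as JAX, which cannot use the custom `c_code`. A minimal sketch, assuming only that `pytensor` and `numpy` are installed:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.basic import tensor_copy

x = pt.vector("x")
y = tensor_copy(x)  # elementwise identity, like the removed IdentityOp

f = pytensor.function([x], y)
np.testing.assert_allclose(f(np.array([1.0, 2.0])), [1.0, 2.0])

# The gradient passes straight through, matching IdentityOp.grad
g = pytensor.grad(y.sum(), x)
np.testing.assert_allclose(pytensor.function([x], g)(np.array([1.0, 2.0])), [1.0, 1.0])
```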
```diff
@@ -86,7 +86,8 @@ def test_jax_PosDefMatrix():
     [
         pytest.param(1),
         pytest.param(
-            2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
+            2,
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
```
```diff
@@ -265,7 +266,11 @@ def test_get_jaxified_logp():
 @pytest.fixture(scope="module")
 def model_test_idata_kwargs() -> pm.Model:
     with pm.Model(
-        coords={"x_coord": ["a", "b"], "x_coord2": [1, 2], "z_coord": ["apple", "banana", "orange"]}
+        coords={
+            "x_coord": ["a", "b"],
+            "x_coord2": [1, 2],
+            "z_coord": ["apple", "banana", "orange"],
+        }
     ) as m:
         x = pm.Normal("x", shape=(2,), dims=["x_coord"])
         _ = pm.Normal("y", x, observed=[0, 0])
```
```diff
@@ -372,7 +377,8 @@ def test_get_batched_jittered_initial_points():
     [
         pytest.param(1),
         pytest.param(
-            2, marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices")
+            2,
+            marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
         ),
     ],
 )
```
```diff
@@ -536,3 +542,31 @@ def test_dirichlet_multinomial_dims():
     frozen_dm = freeze_dims_and_data(m)["dm"]
     dm_draws = pm.draw(frozen_dm, mode="JAX")
     np.testing.assert_equal(dm_draws, np.eye(3) * 5)
+
+
+@pytest.mark.parametrize("method", ["advi", "fullrank_advi"])
+def test_vi_sampling_jax(method):
+    with pm.Model() as model:
+        x = pm.Normal("x")
+        pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
```
Review discussion on `fn_kwargs=dict(mode="JAX")`:

> To be consistent with

> I vote yes, this API looks super weird.

> What looks weird? This is the compilation mode; it would be exactly the same if you wanted to use Numba or JAX for the PyMC NUTS sampler or for prior/posterior predictive. The only thing I would change is the name of `fn_kwargs`, which I think is called `compile_kwargs` in those other functions.

> Wouldn't this be what the user would have to do if they wanted to run VI on JAX?

> I don't understand the question; this PR is just doing minor tweaks so the PyMC VI module can compile to JAX. It's not linking to specific JAX VI libraries.

> Great idea, should definitely add it there too.

> That makes sense, I'm not opposed to adding it there. Maybe we can add a warning that the sampler is still running Python and they likely will want to use

> This is still doing Python loops; it's exactly the same argument you need for `pm.sample`. It's different from linking to a JAX VI library, which is what would be equivalent to the

> Oh, I somehow assumed that VI was implemented mostly in PyTensor?

> As for this, I'd prefer to focus this PR on backend compatibility and address possible API changes later in a new issue + PR. Agreed that there is an inconsistency we need to resolve, but doing it here would only defer pushing to main at least some working solution, one that has already gone through many issues.

> Agreed @ferrine. My only suggestion is to switch
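For readers following along, this is roughly the user-facing call the new test exercises. A sketch only: the reviewers above suggest `fn_kwargs` may be renamed (`compile_kwargs` is the spelling used by other PyMC entry points), and the step counts here are illustrative, not from the PR:

```python
import pymc as pm

with pm.Model():
    x = pm.Normal("x")
    # Compile the ADVI objective/update functions with the JAX backend.
    # The optimization loop itself still runs in Python, as noted above.
    approx = pm.fit(1000, method="advi", fn_kwargs=dict(mode="JAX"))

idata = approx.sample(500)  # draw from the fitted approximation
```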
```diff
+
+
+@pytest.mark.xfail(
+    reason="""
+    This error happens during the equilibrium rewriter; probably one of the routines in SVGD is problematic.
+
+    TypeError: The broadcast pattern of the output of scan
+    (Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
+    (Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
+    1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
+    while it is still variable in the output, or vice-verca. You have to make them consistent,
+    e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
+
+    Instead of fixing this error it makes sense to rework the internals of the variational module
+    to use pytensor vectorize instead of scan.
+    """
+)
+def test_vi_sampling_jax_svgd():
+    with pm.Model():
+        x = pm.Normal("x")
+        pm.fit(10, method="svgd", fn_kwargs=dict(mode="JAX"))
```

Review comment on the scan `TypeError`:

> This is actually something wrong if the number of dimensions of a recurring output differs from the initial state. The difference between None and 1 is more annoying, but this one looks like an error.
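The rework that the xfail reason hints at would replace the scan over particles with PyTensor's vectorization machinery. Below is a minimal sketch of that idea using `vectorize_graph` on a toy log-density; the names and the density are illustrative, not the actual SVGD internals:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Log-density graph for a single particle (toy standard normal).
particle = pt.vector("particle")
logp = -0.5 * (particle**2).sum()

# A batch of particles stacked along a new leading axis.
particles = pt.matrix("particles")

# Rewrite the single-particle graph to act on the whole batch at once;
# no scan node is involved, so no scan broadcast-pattern mismatch can arise.
batched_logp = vectorize_graph(logp, replace={particle: particles})

f = pytensor.function([particles], batched_logp)
print(f(np.ones((3, 2))))  # three per-particle log-density values
```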
Review discussion on keeping `identity` in `pymc.pytensorf`:

> Nitpick: just import it directly in the VI module, no need to define it in pytensorf?
> It might be used by someone else, I assume.
> I don't think so, but even if we keep it we should add a deprecation warning.
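If `identity` is kept in `pymc.pytensorf` for backwards compatibility, the suggested deprecation could look something like this module-level `__getattr__` shim. A hypothetical sketch: the warning category, message, and mechanism are assumptions, not part of the PR:

```python
import warnings

from pytensor.tensor.basic import tensor_copy


def __getattr__(name):
    # Lazily resolve the deprecated alias so existing imports keep working,
    # while warning users to switch to tensor_copy directly.
    if name == "identity":
        warnings.warn(
            "pymc.pytensorf.identity is deprecated; use "
            "pytensor.tensor.basic.tensor_copy instead.",
            FutureWarning,
            stacklevel=2,
        )
        return tensor_copy
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```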