Skip to content

Commit 44a75b0

Browse files
committed
add expected fail to remember about svgd
1 parent 5544852 commit 44a75b0

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

tests/sampling/test_jax.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,10 @@ def test_vi_sampling_jax(method):
572572
with pm.Model() as model:
573573
x = pm.Normal("x")
574574
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
575+
576+
577+
@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
578+
def test_vi_sampling_jax_svgd():
579+
with pm.Model():
580+
x = pm.Normal("x")
581+
pm.fit(10, method="svgd", fn_kwargs=dict(mode="JAX"))

0 commit comments

Comments
 (0)