Skip to content

Commit 30a2d73

Browse files
committed
update pytensor version, make xfail more elaborate
1 parent 59aadf0 commit 30a2d73

File tree

2 files changed

+28
-37
lines changed

2 files changed

+28
-37
lines changed

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ cloudpickle
44
fastprogress>=0.2.0
55
numpy>=1.15.0
66
pandas>=0.24.0
7-
pytensor>=2.18.1,<2.19
7+
pytensor>=2.19.0,<2.20
88
scipy>=1.4.1
99
typing-extensions>=3.7.4

tests/sampling/test_jax.py

+27-36
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,12 @@ def test_jax_PosDefMatrix():
8585
pytest.param(1),
8686
pytest.param(
8787
2,
88-
marks=pytest.mark.skipif(
89-
len(jax.devices()) < 2, reason="not enough devices"
90-
),
88+
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
9189
),
9290
],
9391
)
9492
@pytest.mark.parametrize("postprocessing_vectorize", ["scan", "vmap"])
95-
def test_transform_samples(
96-
sampler, postprocessing_backend, chains, postprocessing_vectorize
97-
):
93+
def test_transform_samples(sampler, postprocessing_backend, chains, postprocessing_vectorize):
9894
pytensor.config.on_opt_error = "raise"
9995
np.random.seed(13244)
10096

@@ -241,9 +237,7 @@ def test_replace_shared_variables():
241237
x = pytensor.shared(5, name="shared_x")
242238

243239
new_x = _replace_shared_variables([x])
244-
shared_variables = [
245-
var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)
246-
]
240+
shared_variables = [var for var in graph_inputs(new_x) if isinstance(var, SharedVariable)]
247241
assert not shared_variables
248242

249243
x.default_update = x + 1
@@ -333,30 +327,23 @@ def test_idata_kwargs(
333327

334328
posterior = idata.get("posterior")
335329
assert posterior is not None
336-
x_dim_expected = idata_kwargs.get(
337-
"dims", model_test_idata_kwargs.named_vars_to_dims
338-
)["x"][0]
330+
x_dim_expected = idata_kwargs.get("dims", model_test_idata_kwargs.named_vars_to_dims)["x"][0]
339331
assert x_dim_expected is not None
340332
assert posterior["x"].dims[-1] == x_dim_expected
341333

342-
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[
343-
x_dim_expected
344-
]
334+
x_coords_expected = idata_kwargs.get("coords", model_test_idata_kwargs.coords)[x_dim_expected]
345335
assert x_coords_expected is not None
346336
assert list(x_coords_expected) == list(posterior["x"].coords[x_dim_expected].values)
347337

348338
assert posterior["z"].dims[2] == "z_coord"
349339
assert np.all(
350-
posterior["z"].coords["z_coord"].values
351-
== np.array(["apple", "banana", "orange"])
340+
posterior["z"].coords["z_coord"].values == np.array(["apple", "banana", "orange"])
352341
)
353342

354343

355344
def test_get_batched_jittered_initial_points():
356345
with pm.Model() as model:
357-
x = pm.MvNormal(
358-
"x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3))
359-
)
346+
x = pm.MvNormal("x", mu=np.zeros(3), cov=np.eye(3), shape=(2, 3), initval=np.zeros((2, 3)))
360347

361348
# No jitter
362349
ips = _get_batched_jittered_initial_points(
@@ -365,17 +352,13 @@ def test_get_batched_jittered_initial_points():
365352
assert np.all(ips[0] == 0)
366353

367354
# Single chain
368-
ips = _get_batched_jittered_initial_points(
369-
model=model, chains=1, random_seed=1, initvals=None
370-
)
355+
ips = _get_batched_jittered_initial_points(model=model, chains=1, random_seed=1, initvals=None)
371356

372357
assert ips[0].shape == (2, 3)
373358
assert np.all(ips[0] != 0)
374359

375360
# Multiple chains
376-
ips = _get_batched_jittered_initial_points(
377-
model=model, chains=2, random_seed=1, initvals=None
378-
)
361+
ips = _get_batched_jittered_initial_points(model=model, chains=2, random_seed=1, initvals=None)
379362

380363
assert ips[0].shape == (2, 2, 3)
381364
assert np.all(ips[0][0] != ips[0][1])
@@ -395,9 +378,7 @@ def test_get_batched_jittered_initial_points():
395378
pytest.param(1),
396379
pytest.param(
397380
2,
398-
marks=pytest.mark.skipif(
399-
len(jax.devices()) < 2, reason="not enough devices"
400-
),
381+
marks=pytest.mark.skipif(len(jax.devices()) < 2, reason="not enough devices"),
401382
),
402383
],
403384
)
@@ -421,12 +402,8 @@ def test_seeding(chains, random_seed, sampler):
421402
assert all_equal
422403

423404
if chains > 1:
424-
assert np.all(
425-
result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1)
426-
)
427-
assert np.all(
428-
result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1)
429-
)
405+
assert np.all(result1.posterior["x"].sel(chain=0) != result1.posterior["x"].sel(chain=1))
406+
assert np.all(result2.posterior["x"].sel(chain=0) != result2.posterior["x"].sel(chain=1))
430407

431408

432409
@mock.patch("numpyro.infer.MCMC")
@@ -541,7 +518,21 @@ def test_vi_sampling_jax(method):
541518
pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
542519

543520

544-
@pytest.mark.xfail(reason="Due to https://github.com/pymc-devs/pytensor/issues/595")
521+
@pytest.mark.xfail(
522+
reason="""
523+
During equilibrium rewriter this error happens. Probably one of the routines in SVGD is problematic.
524+
525+
TypeError: The broadcast pattern of the output of scan
526+
(Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
527+
(Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
528+
1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
529+
while it is still variable in the output, or vice-verca. You have to make them consistent,
530+
e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
531+
532+
Instead of fixing this error it makes sense to rework the internals of the variational to utilize
533+
pytensor vectorize instead of scan.
534+
"""
535+
)
545536
def test_vi_sampling_jax_svgd():
546537
with pm.Model():
547538
x = pm.Normal("x")

0 commit comments

Comments
 (0)