@@ -85,16 +85,12 @@ def test_jax_PosDefMatrix():
85
85
pytest .param (1 ),
86
86
pytest .param (
87
87
2 ,
88
- marks = pytest .mark .skipif (
89
- len (jax .devices ()) < 2 , reason = "not enough devices"
90
- ),
88
+ marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" ),
91
89
),
92
90
],
93
91
)
94
92
@pytest .mark .parametrize ("postprocessing_vectorize" , ["scan" , "vmap" ])
95
- def test_transform_samples (
96
- sampler , postprocessing_backend , chains , postprocessing_vectorize
97
- ):
93
+ def test_transform_samples (sampler , postprocessing_backend , chains , postprocessing_vectorize ):
98
94
pytensor .config .on_opt_error = "raise"
99
95
np .random .seed (13244 )
100
96
@@ -241,9 +237,7 @@ def test_replace_shared_variables():
241
237
x = pytensor .shared (5 , name = "shared_x" )
242
238
243
239
new_x = _replace_shared_variables ([x ])
244
- shared_variables = [
245
- var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )
246
- ]
240
+ shared_variables = [var for var in graph_inputs (new_x ) if isinstance (var , SharedVariable )]
247
241
assert not shared_variables
248
242
249
243
x .default_update = x + 1
@@ -333,30 +327,23 @@ def test_idata_kwargs(
333
327
334
328
posterior = idata .get ("posterior" )
335
329
assert posterior is not None
336
- x_dim_expected = idata_kwargs .get (
337
- "dims" , model_test_idata_kwargs .named_vars_to_dims
338
- )["x" ][0 ]
330
+ x_dim_expected = idata_kwargs .get ("dims" , model_test_idata_kwargs .named_vars_to_dims )["x" ][0 ]
339
331
assert x_dim_expected is not None
340
332
assert posterior ["x" ].dims [- 1 ] == x_dim_expected
341
333
342
- x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[
343
- x_dim_expected
344
- ]
334
+ x_coords_expected = idata_kwargs .get ("coords" , model_test_idata_kwargs .coords )[x_dim_expected ]
345
335
assert x_coords_expected is not None
346
336
assert list (x_coords_expected ) == list (posterior ["x" ].coords [x_dim_expected ].values )
347
337
348
338
assert posterior ["z" ].dims [2 ] == "z_coord"
349
339
assert np .all (
350
- posterior ["z" ].coords ["z_coord" ].values
351
- == np .array (["apple" , "banana" , "orange" ])
340
+ posterior ["z" ].coords ["z_coord" ].values == np .array (["apple" , "banana" , "orange" ])
352
341
)
353
342
354
343
355
344
def test_get_batched_jittered_initial_points ():
356
345
with pm .Model () as model :
357
- x = pm .MvNormal (
358
- "x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 ))
359
- )
346
+ x = pm .MvNormal ("x" , mu = np .zeros (3 ), cov = np .eye (3 ), shape = (2 , 3 ), initval = np .zeros ((2 , 3 )))
360
347
361
348
# No jitter
362
349
ips = _get_batched_jittered_initial_points (
@@ -365,17 +352,13 @@ def test_get_batched_jittered_initial_points():
365
352
assert np .all (ips [0 ] == 0 )
366
353
367
354
# Single chain
368
- ips = _get_batched_jittered_initial_points (
369
- model = model , chains = 1 , random_seed = 1 , initvals = None
370
- )
355
+ ips = _get_batched_jittered_initial_points (model = model , chains = 1 , random_seed = 1 , initvals = None )
371
356
372
357
assert ips [0 ].shape == (2 , 3 )
373
358
assert np .all (ips [0 ] != 0 )
374
359
375
360
# Multiple chains
376
- ips = _get_batched_jittered_initial_points (
377
- model = model , chains = 2 , random_seed = 1 , initvals = None
378
- )
361
+ ips = _get_batched_jittered_initial_points (model = model , chains = 2 , random_seed = 1 , initvals = None )
379
362
380
363
assert ips [0 ].shape == (2 , 2 , 3 )
381
364
assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
@@ -395,9 +378,7 @@ def test_get_batched_jittered_initial_points():
395
378
pytest .param (1 ),
396
379
pytest .param (
397
380
2 ,
398
- marks = pytest .mark .skipif (
399
- len (jax .devices ()) < 2 , reason = "not enough devices"
400
- ),
381
+ marks = pytest .mark .skipif (len (jax .devices ()) < 2 , reason = "not enough devices" ),
401
382
),
402
383
],
403
384
)
@@ -421,12 +402,8 @@ def test_seeding(chains, random_seed, sampler):
421
402
assert all_equal
422
403
423
404
if chains > 1 :
424
- assert np .all (
425
- result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 )
426
- )
427
- assert np .all (
428
- result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 )
429
- )
405
+ assert np .all (result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 ))
406
+ assert np .all (result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 ))
430
407
431
408
432
409
@mock .patch ("numpyro.infer.MCMC" )
@@ -541,7 +518,21 @@ def test_vi_sampling_jax(method):
541
518
pm .fit (10 , method = method , fn_kwargs = dict (mode = "JAX" ))
542
519
543
520
544
- @pytest .mark .xfail (reason = "Due to https://github.com/pymc-devs/pytensor/issues/595" )
521
+ @pytest .mark .xfail (
522
+ reason = """
523
+ During equilibrium rewriter this error happens. Probably one of the routines in SVGD is problematic.
524
+
525
+ TypeError: The broadcast pattern of the output of scan
526
+ (Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
527
+ (Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
528
+ 1 in `output_info`. This can happen if one of the dimension is fixed to 1 in the input,
529
+ while it is still variable in the output, or vice-verca. You have to make them consistent,
530
+ e.g. using pytensor.tensor.{unbroadcast, specify_broadcastable}.
531
+
532
+ Instead of fixing this error it makes sense to rework the internals of the variational to utilize
533
+ pytensor vectorize instead of scan.
534
+ """
535
+ )
545
536
def test_vi_sampling_jax_svgd ():
546
537
with pm .Model ():
547
538
x = pm .Normal ("x" )
0 commit comments