|
19 | 19 | jax = pytest.importorskip("jax")
|
20 | 20 |
|
21 | 21 |
|
| 22 | +from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 |
| 23 | + |
| 24 | + |
22 | 25 | def test_random_RandomStream():
|
23 | 26 | """Two successive calls of a compiled graph using `RandomStream` should
|
24 | 27 | return different values.
|
@@ -377,6 +380,25 @@ def test_random_updates(rng_ctor):
|
377 | 380 | # https://stackoverflow.com/a/48603469
|
378 | 381 | lambda mean, scale: (mean / scale, 0, scale),
|
379 | 382 | ),
|
| 383 | + pytest.param( |
| 384 | + aer.vonmises, |
| 385 | + [ |
| 386 | + set_test_value( |
| 387 | + at.dvector(), |
| 388 | + np.array([-0.5, 1.3], dtype=np.float64), |
| 389 | + ), |
| 390 | + set_test_value( |
| 391 | + at.dvector(), |
| 392 | + np.array([5.5, 13.0], dtype=np.float64), |
| 393 | + ), |
| 394 | + ], |
| 395 | + (2,), |
| 396 | + "vonmises", |
| 397 | + lambda mu, kappa: (kappa, mu), |
| 398 | + marks=pytest.mark.skipif( |
| 399 | + not numpyro_available, reason="VonMises dispatch requires numpyro" |
| 400 | + ), |
| 401 | + ), |
380 | 402 | ],
|
381 | 403 | )
|
382 | 404 | def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_conv):
|
@@ -519,6 +541,83 @@ def test_negative_binomial():
|
519 | 541 | )
|
520 | 542 |
|
521 | 543 |
|
| 544 | +@pytest.mark.skipif(not numpyro_available, reason="Binomial dispatch requires numpyro") |
| 545 | +def test_binomial(): |
| 546 | + rng = shared(np.random.RandomState(123)) |
| 547 | + n = np.array([10, 40]) |
| 548 | + p = np.array([0.3, 0.7]) |
| 549 | + g = at.random.binomial(n, p, size=(10_000, 2), rng=rng) |
| 550 | + g_fn = function([], g, mode=jax_mode) |
| 551 | + samples = g_fn() |
| 552 | + np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1) |
| 553 | + np.testing.assert_allclose(samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.1) |
| 554 | + |
| 555 | + |
| 556 | +@pytest.mark.skipif( |
| 557 | + not numpyro_available, reason="BetaBinomial dispatch requires numpyro" |
| 558 | +) |
| 559 | +def test_beta_binomial(): |
| 560 | + rng = shared(np.random.RandomState(123)) |
| 561 | + n = np.array([10, 40]) |
| 562 | + a = np.array([1.5, 13]) |
| 563 | + b = np.array([0.5, 9]) |
| 564 | + g = at.random.betabinom(n, a, b, size=(10_000, 2), rng=rng) |
| 565 | + g_fn = function([], g, mode=jax_mode) |
| 566 | + samples = g_fn() |
| 567 | + np.testing.assert_allclose(samples.mean(axis=0), n * a / (a + b), rtol=0.1) |
| 568 | + np.testing.assert_allclose( |
| 569 | + samples.std(axis=0), |
| 570 | + np.sqrt((n * a * b * (a + b + n)) / ((a + b) ** 2 * (a + b + 1))), |
| 571 | + rtol=0.1, |
| 572 | + ) |
| 573 | + |
| 574 | + |
| 575 | +@pytest.mark.skipif( |
| 576 | + not numpyro_available, reason="Multinomial dispatch requires numpyro" |
| 577 | +) |
| 578 | +def test_multinomial(): |
| 579 | + rng = shared(np.random.RandomState(123)) |
| 580 | + n = np.array([10, 40]) |
| 581 | + p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) |
| 582 | + g = at.random.multinomial(n, p, size=(10_000, 2), rng=rng) |
| 583 | + g_fn = function([], g, mode=jax_mode) |
| 584 | + samples = g_fn() |
| 585 | + np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) |
| 586 | + np.testing.assert_allclose( |
| 587 | + samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 |
| 588 | + ) |
| 589 | + |
| 590 | + |
| 591 | +@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") |
| 592 | +def test_vonmises_mu_outside_circle(): |
| 593 | + # Scipy implementation does not behave as PyTensor/NumPy for mu outside the unit circle |
| 594 | + # We test that the random draws from the JAX dispatch work as expected in these cases |
| 595 | + rng = shared(np.random.RandomState(123)) |
| 596 | + mu = np.array([-30, 40]) |
| 597 | + kappa = np.array([100, 10]) |
| 598 | + g = at.random.vonmises(mu, kappa, size=(10_000, 2), rng=rng) |
| 599 | + g_fn = function([], g, mode=jax_mode) |
| 600 | + samples = g_fn() |
| 601 | + np.testing.assert_allclose( |
| 602 | + samples.mean(axis=0), (mu + np.pi) % (2.0 * np.pi) - np.pi, rtol=0.1 |
| 603 | + ) |
| 604 | + |
| 605 | + # Circvar only does the correct thing in more recent versions of Scipy |
| 606 | + # https://github.com/scipy/scipy/pull/5747 |
| 607 | + # np.testing.assert_allclose( |
| 608 | + # stats.circvar(samples, axis=0), |
| 609 | + # 1 - special.iv(1, kappa) / special.iv(0, kappa), |
| 610 | + # rtol=0.1, |
| 611 | + # ) |
| 612 | + |
| 613 | + # For now simple compare with std from numpy draws |
| 614 | + rng = np.random.default_rng(123) |
| 615 | + ref_samples = rng.vonmises(mu, kappa, size=(10_000, 2)) |
| 616 | + np.testing.assert_allclose( |
| 617 | + np.std(samples, axis=0), np.std(ref_samples, axis=0), rtol=0.1 |
| 618 | + ) |
| 619 | + |
| 620 | + |
522 | 621 | def test_random_unimplemented():
|
523 | 622 | """Compiling a graph with a non-supported `RandomVariable` should
|
524 | 623 | raise an error.
|
|
0 commit comments