|
13 | 13 | from pytensor.scan.op import Scan
|
14 | 14 | from pytensor.tensor import random
|
15 | 15 | from pytensor.tensor.math import gammaln, log
|
16 |
| -from pytensor.tensor.type import lscalar, scalar, vector |
| 16 | +from pytensor.tensor.type import dmatrix, dvector, lscalar, scalar, vector |
17 | 17 | from tests.link.jax.test_basic import compare_jax_and_py
|
18 | 18 |
|
19 | 19 |
|
@@ -317,3 +317,104 @@ def input_step_fn(y_tm1, y_tm3, a):
|
317 | 317 |
|
318 | 318 | test_input_vals = [np.array(10.0).astype(config.floatX)]
|
319 | 319 | compare_jax_and_py(out_fg, test_input_vals)
|
| 320 | + |
| 321 | + |
| 322 | +@pytest.mark.parametrize("x0_func", [dvector, dmatrix]) |
| 323 | +@pytest.mark.parametrize("A_func", [dmatrix, dmatrix]) |
| 324 | +def test_nd_scan_sit_sot(x0_func, A_func): |
| 325 | + x0 = x0_func("x0") |
| 326 | + A = A_func("A") |
| 327 | + |
| 328 | + n_steps = 3 |
| 329 | + k = 3 |
| 330 | + |
| 331 | + # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph |
| 332 | + xs, _ = scan( |
| 333 | + lambda X, A: A @ X, |
| 334 | + non_sequences=[A], |
| 335 | + outputs_info=[x0], |
| 336 | + n_steps=n_steps, |
| 337 | + mode=get_mode("JAX"), |
| 338 | + ) |
| 339 | + |
| 340 | + x0_val = ( |
| 341 | + np.arange(k, dtype=config.floatX) |
| 342 | + if x0.ndim == 1 |
| 343 | + else np.diag(np.arange(k, dtype=config.floatX)) |
| 344 | + ) |
| 345 | + A_val = np.eye(k, dtype=config.floatX) |
| 346 | + |
| 347 | + fg = FunctionGraph([x0, A], [xs]) |
| 348 | + test_input_vals = [x0_val, A_val] |
| 349 | + compare_jax_and_py(fg, test_input_vals) |
| 350 | + |
| 351 | + |
| 352 | +def test_nd_scan_sit_sot_with_seq(): |
| 353 | + n_steps = 3 |
| 354 | + k = 3 |
| 355 | + |
| 356 | + x = at.matrix("x0", shape=(n_steps, k)) |
| 357 | + A = at.matrix("A", shape=(k, k)) |
| 358 | + |
| 359 | + # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph |
| 360 | + xs, _ = scan( |
| 361 | + lambda X, A: A @ X, |
| 362 | + non_sequences=[A], |
| 363 | + sequences=[x], |
| 364 | + n_steps=n_steps, |
| 365 | + mode=get_mode("JAX"), |
| 366 | + ) |
| 367 | + |
| 368 | + x_val = np.arange(n_steps * k, dtype=config.floatX).reshape(n_steps, k) |
| 369 | + A_val = np.eye(k, dtype=config.floatX) |
| 370 | + |
| 371 | + fg = FunctionGraph([x, A], [xs]) |
| 372 | + test_input_vals = [x_val, A_val] |
| 373 | + compare_jax_and_py(fg, test_input_vals) |
| 374 | + |
| 375 | + |
| 376 | +def test_nd_scan_mit_sot(): |
| 377 | + x0 = at.matrix("x0", shape=(3, 3)) |
| 378 | + A = at.matrix("A", shape=(3, 3)) |
| 379 | + B = at.matrix("B", shape=(3, 3)) |
| 380 | + |
| 381 | + # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph |
| 382 | + xs, _ = scan( |
| 383 | + lambda xtm3, xtm1, A, B: A @ xtm3 + B @ xtm1, |
| 384 | + outputs_info=[{"initial": x0, "taps": [-3, -1]}], |
| 385 | + non_sequences=[A, B], |
| 386 | + n_steps=10, |
| 387 | + mode=get_mode("JAX"), |
| 388 | + ) |
| 389 | + |
| 390 | + fg = FunctionGraph([x0, A, B], [xs]) |
| 391 | + x0_val = np.arange(9, dtype=config.floatX).reshape(3, 3) |
| 392 | + A_val = np.eye(3, dtype=config.floatX) |
| 393 | + B_val = np.eye(3, dtype=config.floatX) |
| 394 | + |
| 395 | + test_input_vals = [x0_val, A_val, B_val] |
| 396 | + compare_jax_and_py(fg, test_input_vals) |
| 397 | + |
| 398 | + |
| 399 | +def test_nd_scan_sit_sot_with_carry(): |
| 400 | + x0 = at.vector("x0", shape=(3,)) |
| 401 | + A = at.matrix("A", shape=(3, 3)) |
| 402 | + |
| 403 | + def step(x, A): |
| 404 | + return A @ x, x.sum() |
| 405 | + |
| 406 | + # Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph |
| 407 | + xs, _ = scan( |
| 408 | + step, |
| 409 | + outputs_info=[x0, None], |
| 410 | + non_sequences=[A], |
| 411 | + n_steps=10, |
| 412 | + mode=get_mode("JAX"), |
| 413 | + ) |
| 414 | + |
| 415 | + fg = FunctionGraph([x0, A], xs) |
| 416 | + x0_val = np.arange(3, dtype=config.floatX) |
| 417 | + A_val = np.eye(3, dtype=config.floatX) |
| 418 | + |
| 419 | + test_input_vals = [x0_val, A_val] |
| 420 | + compare_jax_and_py(fg, test_input_vals) |
0 commit comments