Skip to content

Commit 6135962

Browse files
committed
Benchmark Scan buffer optimization in Numba
1 parent e1f77bf commit 6135962

File tree

1 file changed

+117
-33
lines changed

1 file changed

+117
-33
lines changed

tests/link/numba/test_scan.py

Lines changed: 117 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -339,39 +339,6 @@ def power_step(prior_result, x):
339339
compare_numba_and_py([A], result, test_input_vals)
340340

341341

342-
@pytest.mark.parametrize("n_steps_val", [1, 5])
def test_scan_save_mem_basic(n_steps_val):
    """Compare Numba and Python results for a two-tap Scan with `scan_save_mem` enabled.

    The recurrence reads taps ``-2`` and ``-1`` of a single output and runs
    for a symbolic number of steps; both compilation modes explicitly include
    the `scan_save_mem` rewrite.
    """

    def f_pow2(x_tm2, x_tm1):
        return 2 * x_tm1 + x_tm2

    init_x = pt.dvector("init_x")
    n_steps = pt.iscalar("n_steps")

    # Only the first returned value (the output sequence) is needed here.
    output, _ = scan(
        f_pow2,
        sequences=[],
        outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
        non_sequences=[],
        n_steps=n_steps,
    )

    # Two initial values, one per tap, followed by the parametrized step count.
    test_input_vals = (np.array([1.0, 2.0]), n_steps_val)

    compare_numba_and_py(
        [init_x, n_steps],
        [output],
        test_input_vals,
        numba_mode=get_mode("NUMBA").including("scan_save_mem"),
        py_mode=Mode("py").including("scan_save_mem"),
    )
373-
374-
375342
def test_grad_sitsot():
376343
def get_sum_of_grad(inp):
377344
scan_outputs, updates = scan(
@@ -482,3 +449,120 @@ def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
482449
np.testing.assert_array_almost_equal(numba_r, ref_r)
483450

484451
benchmark(numba_fn, *test.values())
452+
453+
454+
@pytest.mark.parametrize(
    "buffer_size", ("unit", "aligned", "misaligned", "whole", "whole+init")
)
@pytest.mark.parametrize("n_steps, op_size", [(10, 2), (512, 2), (512, 256)])
class TestScanSITSOTBuffer:
    """Check (and optionally benchmark) the Scan buffer size of a one-tap recurrence.

    The recurrence is ``xtm1 -> xtm1 + 1`` over a float64 vector of length
    ``op_size``. ``buffer_size`` selects which part of the Scan output is
    requested, and with it the buffer length expected in the compiled graph.
    """

    def buffer_tester(self, n_steps, op_size, buffer_size, benchmark=None):
        """Build the Scan, compare Numba against Python, and check the buffer length.

        Parameters
        ----------
        n_steps
            The Scan itself is run for ``n_steps - 1`` iterations.
        op_size
            Static length of the state vector.
        buffer_size
            One of ``"unit"``, ``"aligned"``, ``"misaligned"``, ``"whole"`` or
            ``"whole+init"``.
        benchmark
            Optional benchmark callable; when given, the compiled Numba
            function is benchmarked on the test input.

        Raises
        ------
        ValueError
            If ``buffer_size`` is not one of the recognised names.
        """
        x0 = pt.vector(shape=(op_size,), dtype="float64")
        xs, _ = pytensor.scan(
            fn=lambda xtm1: (xtm1 + 1),
            outputs_info=[x0],
            n_steps=n_steps - 1,  # The -1 makes it easier to align/misalign
        )
        if buffer_size == "unit":
            xs_kept = xs[-1]  # Only last state is used
            expected_buffer_size = 2
        elif buffer_size == "aligned":
            xs_kept = xs[-2:]  # The buffer will be aligned at the end of the 9 steps
            expected_buffer_size = 2
        elif buffer_size == "misaligned":
            xs_kept = xs[-3:]  # The buffer will be misaligned at the end of the 9 steps
            expected_buffer_size = 3
        elif buffer_size == "whole":
            xs_kept = xs  # What users think is the whole buffer
            expected_buffer_size = n_steps - 1
        elif buffer_size == "whole+init":
            xs_kept = xs.owner.inputs[0]  # Whole buffer actually used by Scan
            expected_buffer_size = n_steps
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # `xs_kept` further down.
            raise ValueError(f"Unknown buffer_size: {buffer_size!r}")

        x_test = np.zeros(x0.type.shape)
        numba_fn, _ = compare_numba_and_py(
            [x0],
            [xs_kept],
            test_inputs=[x_test],
            numba_mode="NUMBA",  # Default doesn't include optimizations
            eval_obj_mode=False,
        )
        # The unpacking doubles as a check that exactly one Scan node exists.
        [scan_node] = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        # NOTE(review): inputs[1] is presumably the outer buffer of the single
        # recurrent state (with inputs[0] being the step count) -- confirm
        # against Scan's input ordering.
        buffer = scan_node.inputs[1]
        assert buffer.type.shape[0] == expected_buffer_size

        if benchmark is not None:
            numba_fn.trust_input = True
            benchmark(numba_fn, x_test)

    def test_sit_sot_buffer(self, n_steps, op_size, buffer_size):
        """Run the correctness and buffer-size checks without benchmarking."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=None)

    def test_sit_sot_buffer_benchmark(self, n_steps, op_size, buffer_size, benchmark):
        """Run the same checks, then benchmark the compiled Numba function."""
        self.buffer_tester(n_steps, op_size, buffer_size, benchmark=benchmark)
507+
508+
509+
@pytest.mark.parametrize("constant_n_steps", [False, True])
@pytest.mark.parametrize("n_steps_val", [1, 1000])
class TestScanMITSOTBuffer:
    """Check (and optionally benchmark) the buffer of a two-tap Scan recurrence.

    The step count is either baked into the graph as a Python constant or
    supplied at run time through a symbolic scalar, depending on
    ``constant_n_steps``.
    """

    def buffer_tester(self, constant_n_steps, n_steps_val, benchmark=None):
        """Make sure we can handle storage changes caused by the `scan_save_mem` rewrite."""

        def f_pow2(x_tm2, x_tm1):
            return 2 * x_tm1 + x_tm2

        init_x = pt.vector("init_x", shape=(2,))
        n_steps = pt.iscalar("n_steps")
        init_x_val = np.array([1.0, 2.0], dtype=init_x.type.dtype)

        # Decide once how the step count enters the graph and which inputs
        # and test values go with that choice.
        if constant_n_steps:
            scan_n_steps = n_steps_val
            graph_inputs = [init_x]
            test_vals = [init_x_val]
        else:
            scan_n_steps = n_steps
            graph_inputs = [init_x, n_steps]
            test_vals = [
                init_x_val,
                np.asarray(n_steps_val, dtype=n_steps.type.dtype),
            ]

        output, _ = scan(
            f_pow2,
            sequences=[],
            outputs_info=[{"initial": init_x, "taps": [-2, -1]}],
            non_sequences=[],
            n_steps=scan_n_steps,
        )

        # Only the final state is requested from the Scan output.
        numba_fn, _ = compare_numba_and_py(
            graph_inputs,
            [output[-1]],
            test_vals,
            numba_mode="NUMBA",
            eval_obj_mode=False,
        )

        if constant_n_steps and n_steps_val == 1:
            # There's no Scan in the graph when nsteps=constant(1)
            return

        # Check the buffer size has been optimized
        scan_nodes = [
            node
            for node in numba_fn.maker.fgraph.toposort()
            if isinstance(node.op, Scan)
        ]
        # Unpacking fails unless there is exactly one Scan node / one buffer.
        (scan_node,) = scan_nodes
        (mitsot_buffer,) = scan_node.op.outer_mitsot(scan_node.inputs)
        mitsot_buffer_shape = mitsot_buffer.shape.eval(
            {init_x: init_x_val, n_steps: n_steps_val},
            accept_inplace=True,
            on_unused_input="ignore",
        )
        assert tuple(mitsot_buffer_shape) == (3,)

        if benchmark is None:
            return

        numba_fn.trust_input = True
        benchmark(numba_fn, *test_vals)

    def test_mit_sot_buffer(self, constant_n_steps, n_steps_val):
        """Run the correctness and buffer-shape checks without benchmarking."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=None)

    def test_mit_sot_buffer_benchmark(self, constant_n_steps, n_steps_val, benchmark):
        """Run the same checks, then benchmark the compiled Numba function."""
        self.buffer_tester(constant_n_steps, n_steps_val, benchmark=benchmark)

0 commit comments

Comments
 (0)