Skip to content

Commit cb417fe

Browse files
ricardoV94 authored and twiecki committed
Numba scan: reuse scalar arrays for taps from vector inputs
Indexing vector inputs to create taps during scan yields numeric variables, which must be wrapped again into scalar arrays before being passed into the internal function. This commit pre-allocates such arrays and reuses them during looping.
1 parent 3f5b76b commit cb417fe

File tree

2 files changed

+97
-12
lines changed

2 files changed

+97
-12
lines changed

pytensor/link/numba/dispatch/scan.py

+44-11
Original file line numberDiff line numberDiff line change
@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
112112
# Inner-inputs are ordered as follows:
113113
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
114114
# shared-inputs + non-sequences.
115+
temp_scalar_storage_alloc_stmts: List[str] = []
116+
inner_in_exprs_scalar: List[str] = []
115117
inner_in_exprs: List[str] = []
116118

117119
def add_inner_in_expr(
118-
outer_in_name: str, tap_offset: Optional[int], storage_size_var: Optional[str]
120+
outer_in_name: str,
121+
tap_offset: Optional[int],
122+
storage_size_var: Optional[str],
123+
vector_slice_opt: bool,
119124
):
120125
"""Construct an inner-input expression."""
121126
storage_name = outer_in_to_storage_name.get(outer_in_name, outer_in_name)
122-
indexed_inner_in_str = (
123-
storage_name
124-
if tap_offset is None
125-
else idx_to_str(
126-
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
127+
if vector_slice_opt:
128+
indexed_inner_in_str_scalar = idx_to_str(
129+
storage_name, tap_offset, size=storage_size_var, allow_scalar=True
130+
)
131+
temp_storage = f"{storage_name}_temp_scalar_{tap_offset}"
132+
storage_dtype = outer_in_var.type.numpy_dtype.name
133+
temp_scalar_storage_alloc_stmts.append(
134+
f"{temp_storage} = np.empty((), dtype=np.{storage_dtype})"
135+
)
136+
inner_in_exprs_scalar.append(
137+
f"{temp_storage}[()] = {indexed_inner_in_str_scalar}"
138+
)
139+
indexed_inner_in_str = temp_storage
140+
else:
141+
indexed_inner_in_str = (
142+
storage_name
143+
if tap_offset is None
144+
else idx_to_str(
145+
storage_name, tap_offset, size=storage_size_var, allow_scalar=False
146+
)
127147
)
128-
)
129148
inner_in_exprs.append(indexed_inner_in_str)
130149

131150
for outer_in_name in outer_in_seqs_names:
132151
# These outer-inputs are indexed without offsets or storage wrap-around
133-
add_inner_in_expr(outer_in_name, 0, None)
152+
outer_in_var = outer_in_names_to_vars[outer_in_name]
153+
is_vector = outer_in_var.ndim == 1
154+
add_inner_in_expr(outer_in_name, 0, None, vector_slice_opt=is_vector)
134155

135156
inner_in_names_to_input_taps: Dict[str, Tuple[int, ...]] = dict(
136157
zip(
@@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt(
232253
for in_tap in input_taps:
233254
tap_offset = in_tap + tap_storage_size
234255
assert tap_offset >= 0
235-
add_inner_in_expr(outer_in_name, tap_offset, storage_size_name)
256+
is_vector = outer_in_var.ndim == 1
257+
add_inner_in_expr(
258+
outer_in_name,
259+
tap_offset,
260+
storage_size_name,
261+
vector_slice_opt=is_vector,
262+
)
236263

237264
output_taps = inner_in_names_to_output_taps.get(
238265
outer_in_name, [tap_storage_size]
@@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt(
253280

254281
else:
255282
storage_size_stmt = ""
256-
add_inner_in_expr(outer_in_name, None, None)
283+
add_inner_in_expr(outer_in_name, None, None, vector_slice_opt=False)
257284
inner_out_to_outer_in_stmts.append(storage_name)
258285

259286
output_idx = outer_output_names.index(storage_name)
@@ -325,17 +352,19 @@ def add_output_storage_post_proc_stmt(
325352
)
326353

327354
for name in outer_in_non_seqs_names:
328-
add_inner_in_expr(name, None, None)
355+
add_inner_in_expr(name, None, None, vector_slice_opt=False)
329356

330357
if op.info.as_while:
331358
# The inner function will return a boolean as the last value
332359
inner_out_to_outer_in_stmts.append("cond")
333360

334361
assert len(inner_in_exprs) == len(op.fgraph.inputs)
335362

363+
inner_scalar_in_args_to_temp_storage = "\n".join(inner_in_exprs_scalar)
336364
inner_in_args = create_arg_string(inner_in_exprs)
337365
inner_outputs = create_tuple_string(inner_output_names)
338366
input_storage_block = "\n".join(storage_alloc_stmts)
367+
input_temp_scalar_storage_block = "\n".join(temp_scalar_storage_alloc_stmts)
339368
output_storage_post_processing_block = "\n".join(output_storage_post_proc_stmts)
340369
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
341370

@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):
348377
349378
{indent(input_storage_block, " " * 4)}
350379
380+
{indent(input_temp_scalar_storage_block, " " * 4)}
381+
351382
i = 0
352383
cond = np.array(False)
353384
while i < n_steps and not cond.item():
385+
{indent(inner_scalar_in_args_to_temp_storage, " " * 8)}
386+
354387
{inner_outputs} = scan_inner_func({inner_in_args})
355388
{indent(inner_out_post_processing_block, " " * 8)}
356389
{indent(inner_out_to_outer_out_stmts, " " * 8)}

tests/link/numba/test_scan.py

+53-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pytest
33

4+
import pytensor
45
import pytensor.tensor as at
56
from pytensor import config, function, grad
67
from pytensor.compile.mode import Mode, get_mode
@@ -9,7 +10,7 @@
910
from pytensor.scan.basic import scan
1011
from pytensor.scan.op import Scan
1112
from pytensor.scan.utils import until
12-
from pytensor.tensor import log, vector
13+
from pytensor.tensor import log, scalar, vector
1314
from pytensor.tensor.elemwise import Elemwise
1415
from pytensor.tensor.random.utils import RandomStream
1516
from tests import unittest_tools as utt
@@ -442,3 +443,54 @@ def test_inner_graph_optimized():
442443
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
443444
inner_scan_node.op.scalar_op, Log1p
444445
)
446+
447+
448+
def test_vector_taps_benchmark(benchmark):
449+
"""Test vector taps performance.
450+
451+
Vector taps get indexed into numeric types, that must be wrapped back into
452+
scalar arrays. The numba Scan implementation has an optimization to reuse
453+
these scalar arrays instead of allocating them in every iteration.
454+
"""
455+
n_steps = 1000
456+
457+
seq1 = vector("seq1", dtype="float64", shape=(n_steps,))
458+
seq2 = vector("seq2", dtype="float64", shape=(n_steps,))
459+
mitsot_init = vector("mitsot_init", dtype="float64", shape=(2,))
460+
sitsot_init = scalar("sitsot_init", dtype="float64")
461+
462+
def step(seq1, seq2, mitsot1, mitsot2, sitsot1):
463+
mitsot3 = mitsot1 + seq2 + mitsot2 + seq1
464+
sitsot2 = sitsot1 + mitsot3
465+
return mitsot3, sitsot2
466+
467+
outs, _ = scan(
468+
fn=step,
469+
sequences=[seq1, seq2],
470+
outputs_info=[
471+
dict(initial=mitsot_init, taps=[-2, -1]),
472+
dict(initial=sitsot_init, taps=[-1]),
473+
],
474+
)
475+
476+
rng = np.random.default_rng(474)
477+
test = {
478+
seq1: rng.normal(size=n_steps),
479+
seq2: rng.normal(size=n_steps),
480+
mitsot_init: rng.normal(size=(2,)),
481+
sitsot_init: rng.normal(),
482+
}
483+
484+
numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA"))
485+
scan_nodes = [
486+
node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
487+
]
488+
assert len(scan_nodes) == 1
489+
numba_res = numba_fn(*test.values())
490+
491+
ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE"))
492+
ref_res = ref_fn(*test.values())
493+
for numba_r, ref_r in zip(numba_res, ref_res):
494+
np.testing.assert_array_almost_equal(numba_r, ref_r)
495+
496+
benchmark(numba_fn, *test.values())

0 commit comments

Comments
 (0)