|
| 1 | +from typing import List, Tuple |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +from pytensor import Variable, as_symbolic |
| 6 | +from pytensor.graph import FunctionGraph |
| 7 | +from pytensor.loop.op import Scan |
| 8 | +from pytensor.scan.utils import until |
| 9 | +from pytensor.tensor import as_tensor, empty_like |
| 10 | + |
| 11 | + |
| 12 | +def scan( |
| 13 | + fn, |
| 14 | + init_states=None, |
| 15 | + sequences=None, |
| 16 | + non_sequences=None, |
| 17 | + n_steps=None, |
| 18 | + go_backwards=False, |
| 19 | +) -> Tuple[List[Variable], List[Variable]]: |
| 20 | + if sequences is None and n_steps is None: |
| 21 | + raise ValueError("Must provide n_steps when scanning without sequences") |
| 22 | + |
| 23 | + if init_states is None: |
| 24 | + init_states = [] |
| 25 | + else: |
| 26 | + if not isinstance(init_states, (tuple, list)): |
| 27 | + init_states = [init_states] |
| 28 | + init_states = [as_symbolic(i) for i in init_states] |
| 29 | + |
| 30 | + if sequences is None: |
| 31 | + sequences = [] |
| 32 | + else: |
| 33 | + if not isinstance(sequences, (tuple, list)): |
| 34 | + sequences = [sequences] |
| 35 | + sequences = [as_tensor(s) for s in sequences] |
| 36 | + |
| 37 | + if non_sequences is None: |
| 38 | + non_sequences = [] |
| 39 | + else: |
| 40 | + if not isinstance(non_sequences, (tuple, list)): |
| 41 | + non_sequences = [non_sequences] |
| 42 | + non_sequences = [as_symbolic(n) for n in non_sequences] |
| 43 | + |
| 44 | + # Note: Old scan order is sequences + init + non_sequences |
| 45 | + inner_sequences = [s[0] for s in sequences] |
| 46 | + inner_inputs = [i.type() for i in init_states + inner_sequences + non_sequences] |
| 47 | + inner_outputs = fn(*inner_inputs) |
| 48 | + if not isinstance(inner_outputs, (tuple, list)): |
| 49 | + inner_outputs = [inner_outputs] |
| 50 | + next_states = [out for out in inner_outputs if not isinstance(out, until)] |
| 51 | + |
| 52 | + if len(next_states) > len(init_states): |
| 53 | + if not init_states: |
| 54 | + init_states = [None] * len(next_states) |
| 55 | + else: |
| 56 | + raise ValueError( |
| 57 | + "Please provide None as `init` for any output that is not carried over (i.e. it behaves like a map) " |
| 58 | + ) |
| 59 | + |
| 60 | + # Replace None init by dummy empty tensors |
| 61 | + prev_states = [] |
| 62 | + for i, (init_state, next_state) in enumerate(zip(init_states, next_states)): |
| 63 | + if init_state is None: |
| 64 | + init_state = empty_like(next_state) |
| 65 | + init_state.name = "empty_init_state" |
| 66 | + inner_inputs.insert(i, init_state.type()) |
| 67 | + prev_states.append(init_state) |
| 68 | + |
| 69 | + until_condition = [out.condition for out in inner_outputs if isinstance(out, until)] |
| 70 | + if not until_condition: |
| 71 | + until_condition = [as_tensor(np.array(True))] |
| 72 | + if len(until_condition) > 1: |
| 73 | + raise ValueError("Only one until condition can be returned") |
| 74 | + |
| 75 | + update_fg = FunctionGraph( |
| 76 | + inputs=inner_inputs, outputs=until_condition + next_states |
| 77 | + ) |
| 78 | + scan_op = Scan(update_fg=update_fg, n_sequences=len(sequences)) |
| 79 | + scan_outs = scan_op(n_steps, *prev_states, *sequences, *non_sequences) |
| 80 | + assert isinstance(scan_outs, list) |
| 81 | + last_states = scan_outs[: scan_op.n_states] |
| 82 | + traces = scan_outs[scan_op.n_states :] |
| 83 | + |
| 84 | + return last_states, traces |
| 85 | + |
| 86 | + |
| 87 | +def map( |
| 88 | + fn, |
| 89 | + sequences, |
| 90 | + non_sequences=None, |
| 91 | + go_backwards=False, |
| 92 | +): |
| 93 | + _, traces = scan( |
| 94 | + fn=fn, |
| 95 | + sequences=sequences, |
| 96 | + non_sequences=non_sequences, |
| 97 | + go_backwards=go_backwards, |
| 98 | + ) |
| 99 | + if len(traces) == 1: |
| 100 | + return traces[0] |
| 101 | + return traces |
| 102 | + |
| 103 | + |
| 104 | +def reduce( |
| 105 | + fn, |
| 106 | + init_states, |
| 107 | + sequences, |
| 108 | + non_sequences=None, |
| 109 | + go_backwards=False, |
| 110 | +): |
| 111 | + final_states, _ = scan( |
| 112 | + fn=fn, |
| 113 | + init_states=init_states, |
| 114 | + sequences=sequences, |
| 115 | + non_sequences=non_sequences, |
| 116 | + go_backwards=go_backwards, |
| 117 | + ) |
| 118 | + if len(final_states) == 1: |
| 119 | + return final_states[0] |
| 120 | + return final_states |
0 commit comments