from typing import Optional

import numpy as np

from pytensor import In, Out
from pytensor.compile import optdb, pfunc
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
from pytensor.tensor import NoneConst, and_, empty, scalar, set_subtensor
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.type_other import NoneTypeT


def validate_loop_update_types(update):
    """Check that an update graph returns a boolean condition followed by states that keep their input types."""
    assert update.outputs[0].type.dtype == "bool"
    for input_state, output_state in zip(update.inputs, update.outputs[1:]):
        assert input_state.type == output_state.type


class Loop(Op):
    """Represent a do-while loop.

    The `update` graph maps `(*state, *consts) -> (bool, *state)`. The body
    always runs at least once, and iteration stops when the boolean output
    is False.
    """

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        self.state_types = [out.type for out in update.outputs[1:]]
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        self._update_fn = None

    @property
    def update_fn(self):
        """Lazily compile the inner update function graph."""
        if self._update_fn is not None:
            return self._update_fn

        fgraph = self.update
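        # borrow=True lets the compiled function reuse the caller's input buffers,
        # while borrow=False forces fresh output buffers on every call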
        wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
        wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs]

        self._update_fn = pfunc(
            wrapped_inputs,
            wrapped_outputs,
            mode="FAST_RUN",  # TODO: Figure this out
            accept_inplace=False,
            on_unused_input="ignore",
            fgraph=fgraph,
        )
        return self._update_fn

    def make_node(self, *inputs):
        assert len(inputs) == len(self.state_types) + len(self.const_types)

        states = inputs[: len(self.state_types)]
        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, states)
        ]

        consts = inputs[len(self.state_types) :]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, consts)
        ]

        return Apply(
            self,
            [*states, *consts],
            [state_type() for state_type in self.state_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        return input_shapes[: len(self.state_types)]

    def perform(self, node, inputs, output_storage):
        update_fn = self.update_fn

        states = inputs[: len(self.state_types)]
        consts = inputs[len(self.state_types) :]
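        # Do-while semantics: the update function always runs at least once,
        # and the loop stops as soon as its first output (the condition) is False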
        while True:
            go_on, *states = update_fn(*states, *consts)
            if not go_on:
                break

        for i, state in enumerate(states):
            output_storage[i][0] = state

    def L_op(self, *args):
        if not self.reverse:
            raise NotImplementedError()
        # Use L_op of self.reverse.update
        ...

    def R_op(self, *args):
        # Use R_op of self.update
        ...


class Scan(Op):
    """Represent a scan, i.e. a Loop that also returns the trace of every intermediate state.

    This Op must always be converted to a Loop during compilation.
    """

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        self.state_types = [out.type for out in update.outputs[1:]]
        self.trace_types: list[Type] = []
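        # Each trace stacks the successive values of one state along a new
        # leading dimension of unknown static length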
        for state_type in self.state_types:
            # Accommodate SparseTensors and Scalars
            if isinstance(state_type, DenseTensorType):
                self.trace_types.append(
                    DenseTensorType(
                        shape=(None, *state_type.shape), dtype=state_type.dtype
                    )
                )
            else:
                # We can't concatenate all types of states, such as RandomTypes
                self.trace_types.append(NoneConst.type)
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        self._update_fn = None

    def make_node(self, n_steps, *inputs):
        assert len(inputs) == len(self.state_types) + len(self.const_types)

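        # n_steps is an explicit upper bound on the number of iterations; the
        # scan_to_loop rewrite uses it to preallocate the trace buffers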
        n_steps = TensorType(dtype="int64", shape=()).filter_variable(n_steps)

        states = inputs[: len(self.state_types)]
        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, states)
        ]

        consts = inputs[len(self.state_types) :]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, consts)
        ]

        return Apply(
            self,
            [n_steps, *states, *consts],
            [output_type() for output_type in self.state_types + self.trace_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        n_steps = node.inputs[0]
        state_shapes = input_shapes[1 : len(self.state_types) + 1]
        trace_shapes = [
            (n_steps, *state_shape) if state_shape is not None else None
            for state_shape in state_shapes
        ]
        return state_shapes + trace_shapes

    def perform(self, node, inputs, output_storage):
        raise RuntimeError("Scan Op should not be present in the compiled graph")

    def L_op(self, *args):
        # Use trace outputs
        ...

    def R_op(self, *args):
        # Use R_op of self.update
        ...


@node_rewriter([Scan])
def scan_to_loop(fgraph, node):
    """Rewrite a Scan Op into a Loop Op."""
    op: Scan = node.op  # type: ignore

    n_steps = node.inputs[0]

    n_state_vars = len(op.state_types)
    old_states = node.outputs[:n_state_vars]
    old_traces = node.outputs[n_state_vars:]

    # Only include the intermediate states that are used elsewhere
    used_traces_idxs = [
        i for i, trace in enumerate(old_traces) if fgraph.clients[trace]
    ]

    # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
    for trace_idx in used_traces_idxs:
        assert not isinstance(old_traces[trace_idx].type, NoneTypeT)

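    # Build the inner do-while graph: it carries an iteration index, the loop
    # states, and one preallocated buffer per used trace. Each step writes the
    # new state values into the buffers at the current index and continues
    # while the user condition holds and fewer than n_steps steps have run.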
    update_fg = op.update
    prev_idx = scalar(dtype="int64", name="prev_idx")
    prev_states = update_fg.inputs[:n_state_vars]
    prev_traces = [old_traces[i].type() for i in used_traces_idxs]
    consts = update_fg.inputs[n_state_vars:]

    go_on, *next_states = update_fg.outputs
    next_idx = prev_idx + 1
    next_idx.name = "next_idx"
    next_traces = [
        set_subtensor(prev_trace[prev_idx], next_states[trace_idx])
        for trace_idx, prev_trace in zip(used_traces_idxs, prev_traces)
    ]
    go_on = and_(go_on, next_idx < n_steps)
    go_on.name = "go_on"

    new_update_fg = FunctionGraph(
        inputs=[prev_idx, *prev_states, *prev_traces, *consts],
        outputs=[go_on, next_idx, *next_states, *next_traces],
    )

    # TODO: Implement Reverse?
    loop_op = Loop(update=new_update_fg)

    init_idx = constant(np.array(0, dtype="int64"))
    init_states = node.inputs[1 : n_state_vars + 1]
    outer_consts = node.inputs[n_state_vars + 1 :]
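    # Preallocate one empty trace buffer per used trace, with room for n_steps entries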
    init_traces = [
        empty((n_steps, *tuple(init_states[trace_idx].shape)))
        for trace_idx in used_traces_idxs
    ]
    final_idx, *new_outs = loop_op(init_idx, *init_states, *init_traces, *outer_consts)
    new_states = new_outs[:n_state_vars]
    new_traces = new_outs[n_state_vars:]

    replacements = dict(zip(old_states, new_states))
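    # The loop may stop before n_steps iterations, so keep only the part of
    # each trace buffer that was actually filled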
    for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
        replacements[old_traces[trace_idx]] = new_trace[:final_idx]
    return replacements


# TODO: Create new Loop rewrite database
optdb.register(
    "scan_to_loop",
    in2out(scan_to_loop),
    "fast_compile",
    "fast_run",
    position=-0.1,  # TODO: When?
)