Skip to content

Fix Scan JAX dispatcher #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 156 additions & 122 deletions pytensor/link/jax/dispatch/scan.py
Original file line number Diff line number Diff line change
@@ -1,159 +1,193 @@
import jax
import jax.numpy as jnp

from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.scan.op import Scan
from pytensor.scan.utils import ScanArgs


@jax_funcify.register(Scan)
def jax_funcify_Scan(op, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
def jax_funcify_Scan(op: Scan, **kwargs):
info = op.info

def scan(*outer_inputs):
scan_args = ScanArgs(
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
if info.as_while:
raise NotImplementedError("While Scan cannot yet be converted to JAX")

if info.n_mit_mot:
raise NotImplementedError(
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
)

# `outer_inputs` is a list with the following composite form:
# [n_steps]
# + outer_in_seqs
# + outer_in_mit_mot
# + outer_in_mit_sot
# + outer_in_sit_sot
# + outer_in_shared
# + outer_in_nit_sot
# + outer_in_non_seqs
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs

# TODO: mit_mots
mit_mot_in_slices = []

mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
pos_taps = [abs(t) for t in tap if t > 0]
max_neg = max(neg_taps) if neg_taps else 0
max_pos = max(pos_taps) if pos_taps else 0
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)

sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
# Optimize inner graph
rewriter = op.mode_instance.optimizer
rewriter(op.fgraph)
scan_inner_func = jax_funcify(op.fgraph, **kwargs)

def scan(*outer_inputs):
# Extract JAX scan inputs
outer_inputs = list(outer_inputs)
n_steps = outer_inputs[0] # JAX `length`
seqs = op.outer_seqs(outer_inputs) # JAX `xs`

mit_sot_init = []
for tap, seq in zip(op.info.mit_sot_in_slices, op.outer_mitsot(outer_inputs)):
init_slice = seq[: abs(min(tap))]
mit_sot_init.append(init_slice)

sit_sot_init = [seq[0] for seq in op.outer_sitsot(outer_inputs)]

init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
mit_sot_init,
sit_sot_init,
op.outer_shared(outer_inputs),
op.outer_non_seqs(outer_inputs),
) # JAX `init`

def jax_args_to_inner_func_args(carry, x):
"""Convert JAX scan arguments into format expected by scan_inner_func.

scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
"""

def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
# terms
# `carry` contains all inner taps, shared terms, and non_seqs
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_non_seqs,
) = carry

# `x` contains the in_seqs
inner_in_seqs = x

# `inner_scan_inputs` is a list with the following composite form:
# inner_in_seqs
# + sum(inner_in_mit_mot, [])
# + sum(inner_in_mit_sot, [])
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])

inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
# `x` contains the inner sequences
inner_seqs = x

mit_sot_flatten = []
for array, index in zip(inner_mit_sot, op.info.mit_sot_in_slices):
mit_sot_flatten.extend(array[jnp.array(index)])

inner_scan_inputs = [
*inner_seqs,
*mit_sot_flatten,
*inner_sit_sot,
*inner_shared,
*inner_non_seqs,
]

return inner_scan_inputs

def inner_scan_outs_to_jax_outs(
op,
def inner_func_outs_to_jax_outs(
old_carry,
inner_scan_outs,
):
"""Convert inner_scan_func outputs into format expected by JAX scan.

old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
"""
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
inner_mit_sot,
inner_sit_sot,
inner_shared,
inner_non_seqs,
) = old_carry

def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)

inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
inner_mit_sot_outs = op.inner_mitsot_outs(inner_scan_outs)
inner_sit_sot_outs = op.inner_sitsot_outs(inner_scan_outs)
inner_nit_sot_outs = op.inner_nitsot_outs(inner_scan_outs)
inner_shared_outs = op.inner_shared_outs(inner_scan_outs)

# Replace the oldest mit_sot tap by the newest value
inner_mit_sot_new = [
jnp.concatenate([old_mit_sot[1:], new_val[None, ...]], axis=0)
for old_mit_sot, new_val in zip(
inner_mit_sot,
inner_mit_sot_outs,
)
]

# This should contain all inner-output taps, non_seqs, and shared
# terms
if not inner_in_sit_sot:
inner_out_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
# Nothing needs to be done with sit_sot
inner_sit_sot_new = inner_sit_sot_outs

inner_shared_new = inner_shared
# Replace old shared inputs by new shared outputs
inner_shared_new[: len(inner_shared_outs)] = inner_shared_outs

new_carry = (
inner_in_mit_mot,
inner_mit_sot_new,
inner_sit_sot_new,
inner_shared_new,
inner_non_seqs,
)

# Shared variables and non_seqs are not traced
traced_outs = [
*inner_mit_sot_outs,
*inner_sit_sot_outs,
*inner_nit_sot_outs,
]

return new_carry, traced_outs

def jax_inner_func(carry, x):
    """Run one scan step in the `(carry, x) -> (carry, y)` form used by `jax.lax.scan`.

    The carry and the current sequence slice are flattened into the inner
    function's argument list, the inner function is called, and its outputs
    are regrouped into the next carry plus the values to be traced.
    """
    flat_inputs = jax_args_to_inner_func_args(carry, x)
    step_outs = list(scan_inner_func(*flat_inputs))
    next_carry, step_traces = inner_func_outs_to_jax_outs(carry, step_outs)
    return next_carry, step_traces

# Extract PyTensor scan outputs
final_carry, traces = jax.lax.scan(
jax_inner_func, init_carry, seqs, length=n_steps
)

def get_partial_traces(traces):
    """Turn the traces returned by the JAX scan into PyTensor-style traces.

    For each traced output:
    1. The initial state is prepended (MIT-SOT and SIT-SOT outputs only).
    2. Only the trailing `buffer_size` entries of the result are kept.
    """
    # NIT-SOT outputs have no initial state; `None` marks them for the helper below
    init_states = mit_sot_init + sit_sot_init + [None] * op.info.n_nit_sot
    buffers = (
        op.outer_mitsot(outer_inputs)
        + op.outer_sitsot(outer_inputs)
        + op.outer_nitsot(outer_inputs)
    )

    def trim_to_buffer(init_state, trace, buffer):
        if init_state is None:
            # NIT-SOT: the buffer is just the number of entries to return
            full_trace = jnp.atleast_1d(trace)
            n_keep = buffer
        else:
            # MIT-SOT and SIT-SOT: the output is as long as the input buffer
            full_trace = jnp.concatenate(
                [jnp.atleast_1d(init_state), jnp.atleast_1d(trace)],
                axis=0,
            )
            n_keep = buffer.shape[0]
        return full_trace[-n_keep:]

    return [
        trim_to_buffer(init_state, trace, buffer)
        for init_state, trace, buffer in zip(init_states, traces, buffers)
    ]

def get_shared_outs(final_carry):
"""Retrieve the last state of shared_outs from final_carry.

These outputs cannot be traced in PyTensor Scan
"""
(
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_out_shared,
inner_in_non_seqs,
)
) = final_carry

return new_carry
shared_outs = inner_out_shared[: info.n_shared_outs]
return list(shared_outs)

def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = list(jax_at_inner_func(*inner_args))
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs

_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)

# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)

if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
scan_outs_final = get_partial_traces(traces) + get_shared_outs(final_carry)

if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
if len(scan_outs_final) == 1:
scan_outs_final = scan_outs_final[0]
return scan_outs_final

return scan
17 changes: 7 additions & 10 deletions tests/link/jax/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,13 @@
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op, get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.ifelse import ifelse
from pytensor.link.jax import JAXLinker
from pytensor.raise_op import assert_op
from pytensor.tensor.type import dscalar, scalar, vector

Expand All @@ -27,19 +25,18 @@ def set_pytensor_flags():
jax = pytest.importorskip("jax")


jax_mode = Mode(
JAXLinker(), RewriteDatabaseQuery(include=["jax"], exclude=["cxx_only", "BlasOpt"])
)
py_mode = Mode(
"py", RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
)
# We assume that the JAX mode includes all the rewrites needed to transpile JAX graphs
jax_mode = get_mode("JAX")
py_mode = get_mode("FAST_COMPILE")


def compare_jax_and_py(
fgraph: FunctionGraph,
test_inputs: Iterable,
assert_fn: Optional[Callable] = None,
must_be_device_array: bool = True,
jax_mode=jax_mode,
py_mode=py_mode,
):
"""Function to compare python graph output and jax compiled output for testing equality

Expand Down Expand Up @@ -87,7 +84,7 @@ def compare_jax_and_py(
else:
assert_fn(jax_res, py_res)

return jax_res
return pytensor_jax_fn, jax_res


def test_jax_FunctionGraph_once():
Expand Down
Loading