Add support for random Generators in Numba backend #691

Merged
merged 15 commits on May 29, 2024
8 changes: 7 additions & 1 deletion pytensor/compile/builders.py
@@ -27,7 +27,6 @@
from pytensor.graph.replace import clone_replace
from pytensor.graph.rewriting.basic import in2out, node_rewriter
from pytensor.graph.utils import MissingInputError
from pytensor.tensor.rewriting.shape import ShapeFeature


def infer_shape(outs, inputs, input_shapes):
@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
# inside. We don't use the full ShapeFeature interface, but we
# let it initialize itself with an empty fgraph, otherwise we will
# need to do it manually

# TODO: ShapeFeature should live elsewhere
from pytensor.tensor.rewriting.shape import ShapeFeature

for inp, inp_shp in zip(inputs, input_shapes):
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
assert len(inp_shp) == inp.type.ndim
@@ -307,6 +310,7 @@ def __init__(
connection_pattern: list[list[bool]] | None = None,
strict: bool = False,
name: str | None = None,
destroy_map: dict[int, tuple[int, ...]] | None = None,
**kwargs,
):
"""
@@ -464,6 +468,7 @@ def __init__(
if name is not None:
assert isinstance(name, str), "name must be None or string object"
self.name = name
self.destroy_map = destroy_map if destroy_map is not None else {}

def __eq__(self, other):
# TODO: recognize a copy
@@ -862,6 +867,7 @@ def make_node(self, *inputs):
rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern,
name=self.name,
destroy_map=self.destroy_map,
**self.kwargs,
)
new_inputs = (
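The `destroy_map` argument added above lets an `OpFromGraph` declare in-place behaviour at construction time, following the usual `Op.destroy_map` convention. A minimal sketch of the new keyword; the graph and the `{0: (0,)}` mapping are illustrative, not taken from this PR:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.vector("x")
# Declare that output 0 may reuse (destroy) the storage of input 0.
ofg = OpFromGraph([x], [x * 2], destroy_map={0: (0,)})
assert ofg.destroy_map == {0: (0,)}
```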
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
@@ -463,7 +463,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
include=["fast_run"],
include=["fast_run", "numba"],
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
),
)
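With `"numba"` added to the rewrite query, rewrites registered under that tag now run whenever this mode is used. A small sketch of selecting the mode by name; the toy graph is illustrative and assumes Numba is installed:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

# mode="NUMBA" resolves to the Mode defined above, so "numba"-tagged
# rewrites participate in the standard rewrite pipeline.
fn = pytensor.function([x], y, mode="NUMBA")
print(fn(np.arange(3, dtype=pytensor.config.floatX)))
```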
62 changes: 19 additions & 43 deletions pytensor/link/jax/dispatch/random.py
@@ -2,7 +2,7 @@

import jax
import numpy as np
from numpy.random import Generator, RandomState
from numpy.random import Generator
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
_coerce_to_uint32_array,
)
@@ -12,6 +12,7 @@
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
from pytensor.tensor.type_other import NoneTypeT


try:
@@ -53,15 +54,6 @@
raise NotImplementedError(SIZE_NOT_COMPATIBLE)


@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state


@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
@@ -88,41 +80,36 @@


@jax_funcify.register(ptr.RandomVariable)
def jax_funcify_RandomVariable(op, node, **kwargs):
def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
"""JAX implementation of random variables."""
rv = node.outputs[1]
out_dtype = rv.type.dtype
static_shape = rv.type.shape

batch_ndim = op.batch_ndim(node)

# Try to pass static size directly to JAX
static_size = static_shape[:batch_ndim]
if None in static_size:
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param = node.inputs[1]
if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None
if len(size_tuple):
static_size = tuple(size_param.data)
size_param = op.size_param(node)
if isinstance(size_param, Constant) and not isinstance(
size_param.type, NoneTypeT
):
static_size = tuple(size_param.data)

# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if None in static_size:
assert_size_argument_jax_compatible(node)

def sample_fn(rng, size, dtype, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
def sample_fn(rng, size, *parameters):
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)

else:

def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, *parameters):
return jax_sample_fn(op, node=node)(
rng, static_size, out_dtype, *parameters
)
@@ -162,7 +149,6 @@
@jax_sample_fn.register(ptr.LaplaceRV)
@jax_sample_fn.register(ptr.LogisticRV)
@jax_sample_fn.register(ptr.NormalRV)
@jax_sample_fn.register(ptr.StandardNormalRV)
def jax_sample_fn_loc_scale(op, node):
"""JAX implementation of random variables in the loc-scale families.

@@ -219,7 +205,6 @@
return sample_fn


@jax_sample_fn.register(ptr.RandIntRV)
@jax_sample_fn.register(ptr.IntegersRV)
@jax_sample_fn.register(ptr.UniformRV)
def jax_sample_fn_uniform(op, node):
@@ -305,11 +290,10 @@


@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
def jax_funcify_choice(op, node):
def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"""JAX implementation of `ChoiceRV`."""

batch_ndim = op.batch_ndim(node)
a, *p, core_shape = node.inputs[3:]
a_core_ndim, *p_core_ndim, _ = op.ndims_params

if batch_ndim and a_core_ndim == 0:
@@ -318,12 +302,6 @@
"A default JAX rewrite should have materialized the implicit arange"
)

a_batch_ndim = a.type.ndim - a_core_ndim
if op.has_p_param:
[p] = p
[p_core_ndim] = p_core_ndim
p_batch_ndim = p.type.ndim - p_core_ndim

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
Expand All @@ -333,7 +311,7 @@
else:
a, core_shape = parameters
p = None
core_shape = tuple(np.asarray(core_shape))
core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])

if batch_ndim == 0:
sample = jax.random.choice(
Expand All @@ -343,16 +321,16 @@
else:
if size is None:
if p is None:
size = a.shape[:a_batch_ndim]
size = a.shape[:batch_ndim]
else:
size = jax.numpy.broadcast_shapes(
a.shape[:a_batch_ndim],
p.shape[:p_batch_ndim],
a.shape[:batch_ndim],
p.shape[:batch_ndim],
)

a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
a = jax.numpy.broadcast_to(a, size + a.shape[batch_ndim:])
if p is not None:
p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])

batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))

@@ -386,19 +364,17 @@
"""JAX implementation of `PermutationRV`."""

batch_ndim = op.batch_ndim(node)
x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(x,) = parameters
if batch_ndim:
# jax.random.permutation has no concept of batch dims
x_core_shape = x.shape[x_batch_ndim:]
if size is None:
size = x.shape[:x_batch_ndim]
size = x.shape[:batch_ndim]
else:
x = jax.numpy.broadcast_to(x, size + x_core_shape)
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])

batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
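For reference, the dispatch above runs when a graph containing a `RandomVariable` is compiled with the JAX backend. A minimal sketch, assuming JAX is installed; the seed and size are illustrative:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(2024), name="rng")
x = pt.random.normal(0, 1, size=(3,), rng=rng)

# jax_typify converts the Generator's state into a JAX PRNG key, and the
# funcify logic above resolves a static size where possible.
fn = pytensor.function([], x, mode="JAX")
print(fn())
```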
16 changes: 14 additions & 2 deletions pytensor/link/numba/dispatch/basic.py
@@ -18,6 +18,7 @@
from numba.extending import box, overload

from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Apply
@@ -62,10 +63,16 @@ def numba_njit(*args, **kwargs):
kwargs.setdefault("no_cpython_wrapper", True)
kwargs.setdefault("no_cfunc_wrapper", True)

# Supress caching warnings
# Suppress cache warning for internal functions
# We have to add an ansi escape code for optional bold text by numba
warnings.filterwarnings(
"ignore",
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
message=(
"(\x1b\\[1m)*" # ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs)" '
"as it uses dynamic globals"
),
category=NumbaWarning,
)

@@ -434,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
_ = kwargs.pop("storage_map", None)

# Apply inner rewrites
# TODO: Not sure this is the right place to do this, should we have a rewrite that
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
NUMBA.optimizer(op.fgraph)
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

if len(op.fgraph.outputs) == 1:
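Putting the pieces together, this is the feature the PR targets: graphs drawing from a shared `numpy.random.Generator` can now be compiled with the Numba backend. A minimal end-to-end sketch, assuming Numba is installed; the seed, size, and update wiring are illustrative:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

rng = pytensor.shared(np.random.default_rng(691), name="rng")
next_rng, x = pt.random.normal(0, 1, size=(3,), rng=rng).owner.outputs

# Update the shared Generator each call so successive draws differ.
fn = pytensor.function([], x, updates={rng: next_rng}, mode="NUMBA")
print(fn(), fn())
```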