Allow explicit RNG and Sparse input types in JAX functions #278

Open · wants to merge 2 commits into base: main
5 changes: 4 additions & 1 deletion pytensor/link/basic.py
@@ -609,6 +609,9 @@ def create_thunk_inputs(self, storage_map: Dict[Variable, List[Any]]) -> List[Any]:
def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``."""

def typify(self, var: Variable):
return var

def output_filter(self, var: Variable, out: Any) -> Any:
"""Apply a filter to the data output by a JITed function call."""
return out
@@ -735,7 +738,7 @@ def make_all(self, input_storage=None, output_storage=None, storage_map=None):
return (
fn,
[
Container(input, storage)
Container(self.typify(input), storage)
for input, storage in zip(fgraph.inputs, input_storage)
],
[
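The new `typify` hook gives JIT linkers a place to swap a variable's declared type for a backend-specific one before the input `Container`s are built, so runtime values are filtered through the backend type. A minimal sketch of how a subclass might override it; the subclass name and the `convert_type_for_backend` helper are illustrative, not code from this PR:

from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker


class MyJAXLinker(JITLinker):  # hypothetical subclass, for illustration only
    def typify(self, var: Variable) -> Variable:
        # Hypothetical helper: map the PyTensor type to a backend-only type,
        # e.g. RandomType -> RandomPRNGKeyType or SparseTensorType -> BCOOType.
        backend_type = convert_type_for_backend(var.type)
        # Return a fresh variable of the backend type; `make_all` wraps it in a
        # Container, so runtime inputs pass through the backend type's `filter`.
        return Variable(backend_type, None, None, name=var.name)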
8 changes: 4 additions & 4 deletions pytensor/link/jax/dispatch/basic.py
@@ -87,11 +87,11 @@ def assert_fn(x, *inputs):
def jnp_safe_copy(x):
try:
res = jnp.copy(x)
except NotImplementedError:
warnings.warn(
"`jnp.copy` is not implemented yet. Using the object's `copy` method."
)
except (NotImplementedError, TypeError):
if hasattr(x, "copy"):
warnings.warn(
"`jnp.copy` is not implemented yet. Using the object's `copy` method."
)
res = jnp.array(x.copy())
else:
warnings.warn(f"Object has no `copy` method: {x}")
110 changes: 56 additions & 54 deletions pytensor/link/jax/dispatch/random.py
@@ -10,6 +10,7 @@
import pytensor.tensor.random.basic as aer
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.shape import Shape, Shape_i


@@ -57,8 +58,7 @@ def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
return state["state"]["key"][0:2]


@jax_typify.register(Generator)
@@ -83,7 +83,36 @@ def jax_typify_Generator(rng, **kwargs):
state_32 = _coerce_to_uint32_array(state["state"]["state"])
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
state["state"]["state"] = state_32[0] << 32 | state_32[1]
return state
return state["jax_state"]


class RandomPRNGKeyType(RandomType[jax.random.PRNGKey]):
"""JAX-compatible PRNGKey type.

This type is not exposed to users directly.

It is introduced by the JIT linker in place of any RandomType input
variables used in the original function. Nodes in the function graph will
still show the original types as inputs and outputs.
"""

def filter(self, data, strict: bool = False, allow_downcast=None):
# PRNGs are just JAX Arrays, we assume this is a valid one!
if isinstance(data, jax.Array):
return data

if strict:
raise TypeError()

return jax_typify(data)


random_prng_key_type = RandomPRNGKeyType()


@jax_typify.register(RandomType)
def jax_typify_RandomType(type):
return random_prng_key_type()


@jax_funcify.register(aer.RandomVariable)
@@ -130,12 +159,10 @@ def jax_sample_fn_generic(op):
name = op.name
jax_op = getattr(jax.random, name)

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -157,13 +184,11 @@ def jax_sample_fn_loc_scale(op):
name = op.name
jax_op = getattr(jax.random, name)

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -175,12 +200,10 @@ def jax_sample_fn_no_dtype(op):
name = op.name
jax_op = getattr(jax.random, name)

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -201,15 +224,13 @@ def jax_sample_fn_uniform(op):
name = "randint"
jax_op = getattr(jax.random, name)

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
minval, maxval = parameters
sample = jax_op(
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
)
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -226,13 +247,11 @@ def jax_sample_fn_shape_rate(op):
name = op.name
jax_op = getattr(jax.random, name)

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
(shape, rate) = parameters
sample = jax_op(sampling_key, shape, size, dtype) / rate
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

Expand All @@ -241,13 +260,11 @@ def sample_fn(rng, size, dtype, *parameters):
def jax_sample_fn_exponential(op):
"""JAX implementation of `ExponentialRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters
sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -256,17 +273,15 @@ def sample_fn(rng, size, dtype, *parameters):
def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
df,
loc,
scale,
) = parameters
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -275,13 +290,11 @@ def sample_fn(rng, size, dtype, *parameters):
def jax_funcify_choice(op):
"""JAX implementation of `ChoiceRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
(a, p, replace) = parameters
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
rng["jax_state"] = rng_key
return (rng, smpl_value)
return (rng_key, smpl_value)

return sample_fn

@@ -290,13 +303,11 @@ def sample_fn(rng, size, dtype, *parameters):
def jax_sample_fn_permutation(op):
"""JAX implementation of `PermutationRV`."""

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, *parameters):
rng_key, sampling_key = jax.random.split(rng_key, 2)
(x,) = parameters
sample = jax.random.permutation(sampling_key, x)
rng["jax_state"] = rng_key
return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -311,15 +322,12 @@ def jax_sample_fn_binomial(op):

from numpyro.distributions.util import binomial

def sample_fn(rng, size, dtype, n, p):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, n, p):
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = binomial(key=sampling_key, n=n, p=p, shape=size)

rng["jax_state"] = rng_key

return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -334,15 +342,12 @@ def jax_sample_fn_multinomial(op):

from numpyro.distributions.util import multinomial

def sample_fn(rng, size, dtype, n, p):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, n, p):
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = multinomial(key=sampling_key, n=n, p=p, shape=size)

rng["jax_state"] = rng_key

return (rng, sample)
return (rng_key, sample)

return sample_fn

@@ -357,17 +362,14 @@ def jax_sample_fn_vonmises(op):

from numpyro.distributions.util import von_mises_centered

def sample_fn(rng, size, dtype, mu, kappa):
rng_key = rng["jax_state"]
def sample_fn(rng_key, size, dtype, mu, kappa):
rng_key, sampling_key = jax.random.split(rng_key, 2)

sample = von_mises_centered(
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
)
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi

rng["jax_state"] = rng_key

return (rng, sample)
return (rng_key, sample)

return sample_fn
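
To show what the changes in this file enable end to end, here is a hedged usage sketch. Only the idea that a raw `jax.random.PRNGKey` can be fed to a JAX-compiled function follows from this PR; the exact calling convention (e.g. whether the RNG input needs `updates`) is an assumption:

import jax
import pytensor
import pytensor.tensor.random.basic as ptr
from pytensor.tensor.random.type import random_generator_type

# A symbolic RNG input and a random variable that consumes it.
rng = random_generator_type("rng")
x = ptr.normal(0, 1, size=(3,), rng=rng)

# Compile with the JAX backend; the linker replaces the RandomType input
# with the JAX-only RandomPRNGKeyType defined above.
fn = pytensor.function([rng], x, mode="JAX")

# An explicit PRNG key is accepted directly; a NumPy Generator/RandomState
# would instead be converted by jax_typify on the way in.
key = jax.random.PRNGKey(0)
sample = fn(key)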
60 changes: 44 additions & 16 deletions pytensor/link/jax/dispatch/sparse.py
@@ -1,38 +1,66 @@
import jax.experimental.sparse as jsp
from scipy.sparse import spmatrix

from pytensor.graph.basic import Constant
from pytensor.graph.type import HasDataType
from pytensor.link.jax.dispatch import jax_funcify, jax_typify
from pytensor.sparse.basic import Dot, StructuredDot
from pytensor.sparse.basic import Dot, StructuredDot, Transpose
from pytensor.sparse.type import SparseTensorType
from pytensor.tensor import TensorType


@jax_typify.register(spmatrix)
def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
# Note: This changes the type of the constants from CSR/CSC to BCOO
# We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
# and it would break the premise of one graph -> multiple backends.
# The same situation happens with RandomGenerators...
return jsp.BCOO.from_scipy_sparse(matrix)


class BCOOType(TensorType, HasDataType):
"""JAX-compatible BCOO type.

This type is not exposed to users directly.

It is introduced by the JIT linker in place of any SparseTensorType input
variables used in the original function. Nodes in the function graph will
still show the original types as inputs and outputs.
"""

def filter(self, data, strict: bool = False, allow_downcast=None):
if isinstance(data, jsp.BCOO):
return data

if strict:
raise TypeError()

return jax_typify(data)


@jax_typify.register(SparseTensorType)
def jax_typify_SparseTensorType(type):

Review thread on this line:

Is the idea to have BCOO be the default sparse tensor type, or is it a stopgap? I think some algorithms prefer different types, so it'd be good long term to have different subclasses for SparseTensorType (BCOO, CSC, etc.)

ricardoV94 (Member, Author): From reading the JAX docs it seems they are pushing for BCOO only at the moment.

bwengals (Jul 17, 2023): JAX's sparse support isn't great though, I'm not sure they're the best lead to follow. Or, I guess this PR is about PyTensor's JAX support only and not necessarily other backends?

ricardoV94 (Jul 17, 2023, Member, Author): This is just the JAX backend. AFAICT BCOO is the only thing somewhat supported. Their other formats (CSC or CSR) don't allow any of the other JAX transformations (vmap, grad, jit?). They pushed a paper on BCOO, so I think it's really what they're planning publicly, at least.

PyTensor itself uses scipy formats, as does numba (haven't worked on that much though).

bwengals (Jul 18, 2023): Ah gotcha. I didn't know about the paper, will try and find that. And that will make it more difficult if JAX has a particular way of handling this vs scipy or numba.

jessegrabowski (Member): FWIW, PyTensor only supports a subset of scipy formats (CSR and CSC; scipy lists 7 formats). Numba supports the same formats PyTensor does, but that's not a coincidence. My point is that there's room to redefine what PyTensor's sparse formats should be, if it were advantageous to do so.

ricardoV94 (Member, Author): @jessegrabowski that's definitely true.

Still, it's unlikely that we will have a common set of types (RNG / Shared / Tuples / whatever) that works for all backends. This PR is more focused on how we can allow specialized backend-only types, not on deciding which specific types we want to provide to users as defaults in PyTensor.

return BCOOType(
dtype=type.dtype,
shape=type.shape,
name=type.name,
broadcastable=type.broadcastable,
)


@jax_funcify.register(Dot)
@jax_funcify.register(StructuredDot)
def jax_funcify_sparse_dot(op, node, **kwargs):
for input in node.inputs:
if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
raise NotImplementedError(
"JAX sparse dot only implemented for constant sparse inputs"
)

if isinstance(node.outputs[0].type, SparseTensorType):
raise NotImplementedError("JAX sparse dot only implemented for dense outputs")

@jsp.sparsify
def sparse_dot(x, y):
out = x @ y
if isinstance(out, jsp.BCOO):
if isinstance(out, jsp.BCOO) and not isinstance(
node.outputs[0].type, SparseTensorType
):
out = out.todense()
return out

return sparse_dot


@jax_funcify.register(Transpose)
def jax_funcify_sparse_transpose(op, **kwargs):
def sparse_transpose(x):
return x.T

return sparse_transpose
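
A similarly hedged sketch of the sparse side: with `SparseTensorType` inputs typified to `BCOOType`, a JAX-compiled function should accept a scipy sparse matrix (converted through `jax_typify_spmatrix` to a `BCOO`) or a `jax.experimental.sparse.BCOO` directly. Variable names and the exact construction below are illustrative:

import numpy as np
import scipy.sparse
import pytensor
import pytensor.sparse as ps
import pytensor.tensor as pt

# Symbolic CSR matrix and dense vector inputs.
A = ps.csr_matrix("A", dtype="float64")
v = pt.vector("v", dtype="float64")
out = ps.dot(A, v)  # sparse-dense dot with a dense output

fn = pytensor.function([A, v], out, mode="JAX")

# The scipy matrix passes through BCOOType.filter, which calls jax_typify and
# hands JAX a BCOO array; passing a jsp.BCOO directly should also be accepted.
A_val = scipy.sparse.random(4, 4, density=0.5, format="csr", dtype=np.float64)
result = fn(A_val, np.ones(4))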