Skip to content

Commit 762bcad

Browse files
ricardoV94jessegrabowskiaseyboldt
committed
Adapt Elemwise iterator for Numba Generators
Also drops support for RandomState Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent bae694d commit 762bcad

File tree

14 files changed

+752
-894
lines changed

14 files changed

+752
-894
lines changed

pytensor/link/jax/dispatch/random.py

+1-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import jax
44
import numpy as np
5-
from numpy.random import Generator, RandomState
5+
from numpy.random import Generator
66
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
77
_coerce_to_uint32_array,
88
)
@@ -52,15 +52,6 @@ def assert_size_argument_jax_compatible(node):
5252
raise NotImplementedError(SIZE_NOT_COMPATIBLE)
5353

5454

55-
@jax_typify.register(RandomState)
56-
def jax_typify_RandomState(state, **kwargs):
57-
state = state.get_state(legacy=False)
58-
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
59-
# XXX: Is this a reasonable approach?
60-
state["jax_state"] = state["state"]["key"][0:2]
61-
return state
62-
63-
6455
@jax_typify.register(Generator)
6556
def jax_typify_Generator(rng, **kwargs):
6657
state = rng.__getstate__()
@@ -185,7 +176,6 @@ def sample_fn(rng, size, dtype, *parameters):
185176
return sample_fn
186177

187178

188-
@jax_sample_fn.register(ptr.RandIntRV)
189179
@jax_sample_fn.register(ptr.IntegersRV)
190180
@jax_sample_fn.register(ptr.UniformRV)
191181
def jax_sample_fn_uniform(op):

pytensor/link/numba/dispatch/elemwise.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.link.numba.dispatch.vectorize_codegen import (
2424
_vectorized,
2525
encode_literals,
26+
store_core_outputs,
2627
)
2728
from pytensor.link.utils import compile_function_src, get_name_for_object
2829
from pytensor.scalar.basic import (
@@ -483,10 +484,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
483484
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
484485
)
485486

487+
nin = len(node.inputs)
488+
nout = len(node.outputs)
489+
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
490+
486491
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
487492
output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])
488493
output_dtypes = tuple(out.type.dtype for out in node.outputs)
489494
inplace_pattern = tuple(op.inplace_pattern.items())
495+
core_output_shapes = tuple(() for _ in range(nout))
490496

491497
# numba doesn't support nested literals right now...
492498
input_bc_patterns_enc = encode_literals(input_bc_patterns)
@@ -496,12 +502,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
496502

497503
def elemwise_wrapper(*inputs):
498504
return _vectorized(
499-
scalar_op_fn,
505+
core_op_fn,
500506
input_bc_patterns_enc,
501507
output_bc_patterns_enc,
502508
output_dtypes_enc,
503509
inplace_pattern_enc,
510+
(), # constant_inputs
504511
inputs,
512+
core_output_shapes, # core_shapes
513+
None, # size
505514
)
506515

507516
# Pure python implementation, that will be used in tests

0 commit comments

Comments
 (0)