Skip to content

Commit 3b0ae7f

Browse files
ricardoV94, aseyboldt, and jessegrabowski
committed
Add support for RandomVariable with Generators in Numba backend and drop support for RandomState
Co-authored-by: Adrian Seyboldt <[email protected]> Co-authored-by: Jesse Grabowski <[email protected]>
1 parent d3aa283 commit 3b0ae7f

File tree

16 files changed

+689
-661
lines changed

16 files changed

+689
-661
lines changed

pytensor/compile/builders.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pytensor.graph.replace import clone_replace
2828
from pytensor.graph.rewriting.basic import in2out, node_rewriter
2929
from pytensor.graph.utils import MissingInputError
30-
from pytensor.tensor.rewriting.shape import ShapeFeature
3130

3231

3332
def infer_shape(outs, inputs, input_shapes):
@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
4342
# inside. We don't use the full ShapeFeature interface, but we
4443
# let it initialize itself with an empty fgraph, otherwise we will
4544
# need to do it manually
45+
46+
# TODO: ShapeFeature should live elsewhere
47+
from pytensor.tensor.rewriting.shape import ShapeFeature
48+
4649
for inp, inp_shp in zip(inputs, input_shapes):
4750
if inp_shp is not None and len(inp_shp) != inp.type.ndim:
4851
assert len(inp_shp) == inp.type.ndim
@@ -307,6 +310,7 @@ def __init__(
307310
connection_pattern: list[list[bool]] | None = None,
308311
strict: bool = False,
309312
name: str | None = None,
313+
destroy_map: dict[int, tuple[int, ...]] | None = None,
310314
**kwargs,
311315
):
312316
"""
@@ -464,6 +468,7 @@ def __init__(
464468
if name is not None:
465469
assert isinstance(name, str), "name must be None or string object"
466470
self.name = name
471+
self.destroy_map = destroy_map if destroy_map is not None else {}
467472

468473
def __eq__(self, other):
469474
# TODO: recognize a copy
@@ -862,6 +867,7 @@ def make_node(self, *inputs):
862867
rop_overrides=self.rop_overrides,
863868
connection_pattern=self._connection_pattern,
864869
name=self.name,
870+
destroy_map=self.destroy_map,
865871
**self.kwargs,
866872
)
867873
new_inputs = (

pytensor/compile/mode.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
463463
NUMBA = Mode(
464464
NumbaLinker(),
465465
RewriteDatabaseQuery(
466-
include=["fast_run"],
466+
include=["fast_run", "numba"],
467467
exclude=["cxx_only", "BlasOpt", "local_careduce_fusion"],
468468
),
469469
)

pytensor/link/numba/dispatch/basic.py

+6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from numba.extending import box, overload
1919

2020
from pytensor import config
21+
from pytensor.compile import NUMBA
2122
from pytensor.compile.builders import OpFromGraph
2223
from pytensor.compile.ops import DeepCopyOp
2324
from pytensor.graph.basic import Apply
@@ -440,6 +441,11 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
440441
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
441442
_ = kwargs.pop("storage_map", None)
442443

444+
# Apply inner rewrites
445+
# TODO: Not sure this is the right place to do this, should we have a rewrite that
446+
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
447+
# The C-code defers it to the make_thunk phase
448+
NUMBA.optimizer(op.fgraph)
443449
fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
444450

445451
if len(op.fgraph.outputs) == 1:

0 commit comments

Comments (0)