27
27
from pytensor .graph .replace import clone_replace
28
28
from pytensor .graph .rewriting .basic import in2out , node_rewriter
29
29
from pytensor .graph .utils import MissingInputError
30
- from pytensor .tensor .rewriting .shape import ShapeFeature
31
30
32
31
33
32
def infer_shape (outs , inputs , input_shapes ):
@@ -43,6 +42,10 @@ def infer_shape(outs, inputs, input_shapes):
43
42
# inside. We don't use the full ShapeFeature interface, but we
44
43
# let it initialize itself with an empty fgraph, otherwise we will
45
44
# need to do it manually
45
+
46
+ # TODO: ShapeFeature should live elsewhere
47
+ from pytensor .tensor .rewriting .shape import ShapeFeature
48
+
46
49
for inp , inp_shp in zip (inputs , input_shapes ):
47
50
if inp_shp is not None and len (inp_shp ) != inp .type .ndim :
48
51
assert len (inp_shp ) == inp .type .ndim
@@ -307,6 +310,7 @@ def __init__(
307
310
connection_pattern : list [list [bool ]] | None = None ,
308
311
strict : bool = False ,
309
312
name : str | None = None ,
313
+ destroy_map : dict [int , tuple [int , ...]] | None = None ,
310
314
** kwargs ,
311
315
):
312
316
"""
@@ -464,6 +468,7 @@ def __init__(
464
468
if name is not None :
465
469
assert isinstance (name , str ), "name must be None or string object"
466
470
self .name = name
471
+ self .destroy_map = destroy_map if destroy_map is not None else {}
467
472
468
473
def __eq__ (self , other ):
469
474
# TODO: recognize a copy
@@ -862,6 +867,7 @@ def make_node(self, *inputs):
862
867
rop_overrides = self .rop_overrides ,
863
868
connection_pattern = self ._connection_pattern ,
864
869
name = self .name ,
870
+ destroy_map = self .destroy_map ,
865
871
** self .kwargs ,
866
872
)
867
873
new_inputs = (
0 commit comments