@@ -643,35 +643,37 @@ def no_shared_fn(n, x_tm1, M):
643
643
# (i.e. from `Scan._fn`)
644
644
out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
645
645
646
- expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
647
- ├─ 20000 [id B] (n_steps)
648
- ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0)
649
- ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
650
- │ ├─ AllocEmpty{dtype='int64'} [id E] 0
651
- │ │ └─ 20000 [id B]
652
- │ ├─ [0] [id F]
653
- │ └─ 1 [id G]
654
- └─ <Tensor3(float64, shape=(20000, 2, 2))> [id H] (outer_in_non_seqs-0)
655
-
656
- Inner graphs:
657
-
658
- Scan{scan_fn, while_loop=False, inplace=all} [id A]
659
- ← Composite{switch(lt(0, i0), 1, 0)} [id I] (inner_out_sit_sot-0)
660
- └─ Subtensor{i, j, k} [id J]
661
- ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id K] -> [id H] (inner_in_non_seqs-0)
662
- ├─ ScalarFromTensor [id L]
663
- │ └─ *0-<Scalar(int64, shape=())> [id M] -> [id C] (inner_in_seqs-0)
664
- ├─ ScalarFromTensor [id N]
665
- │ └─ *1-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_sit_sot-0)
666
- └─ 0 [id P]
667
-
668
- Composite{switch(lt(0, i0), 1, 0)} [id I]
669
- ← Switch [id Q] 'o0'
670
- ├─ LT [id R]
671
- │ ├─ 0 [id S]
672
- │ └─ i0 [id T]
673
- ├─ 1 [id U]
674
- └─ 0 [id S]
646
+ expected_output = """Subtensor{start:} [id A] 3
647
+ ├─ Scan{scan_fn, while_loop=False, inplace=all} [id B] 2 (outer_out_sit_sot-0)
648
+ │ ├─ 20000 [id C] (n_steps)
649
+ │ ├─ [ 0 ... 998 19999] [id D] (outer_in_seqs-0)
650
+ │ ├─ SetSubtensor{:stop} [id E] 1 (outer_in_sit_sot-0)
651
+ │ │ ├─ AllocEmpty{dtype='int64'} [id F] 0
652
+ │ │ │ └─ 20001 [id G]
653
+ │ │ ├─ [0] [id H]
654
+ │ │ └─ 1 [id I]
655
+ │ └─ <Tensor3(float64, shape=(20000, 2, 2))> [id J] (outer_in_non_seqs-0)
656
+ └─ 1 [id I]
657
+
658
+ Inner graphs:
659
+
660
+ Scan{scan_fn, while_loop=False, inplace=all} [id B]
661
+ ← Composite{switch(lt(0, i0), 1, 0)} [id K] (inner_out_sit_sot-0)
662
+ └─ Subtensor{i, j, k} [id L]
663
+ ├─ *2-<Tensor3(float64, shape=(20000, 2, 2))> [id M] -> [id J] (inner_in_non_seqs-0)
664
+ ├─ ScalarFromTensor [id N]
665
+ │ └─ *0-<Scalar(int64, shape=())> [id O] -> [id D] (inner_in_seqs-0)
666
+ ├─ ScalarFromTensor [id P]
667
+ │ └─ *1-<Scalar(int64, shape=())> [id Q] -> [id E] (inner_in_sit_sot-0)
668
+ └─ 0 [id R]
669
+
670
+ Composite{switch(lt(0, i0), 1, 0)} [id K]
671
+ ← Switch [id S] 'o0'
672
+ ├─ LT [id T]
673
+ │ ├─ 0 [id U]
674
+ │ └─ i0 [id V]
675
+ ├─ 1 [id W]
676
+ └─ 0 [id U]
675
677
"""
676
678
677
679
output_str = debugprint (out , file = "str" , print_op_info = True )
0 commit comments