@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
28
28
29
29
expected_output = """Subtensor{i} [id A]
30
30
├─ Subtensor{start:} [id B]
31
- │ ├─ for{cpu, scan_fn} [id C] (outer_out_sit_sot-0)
31
+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id C] (outer_out_sit_sot-0)
32
32
│ │ ├─ k [id D] (n_steps)
33
33
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
34
34
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -59,7 +59,7 @@ def test_debugprint_sitsot():
59
59
60
60
Inner graphs:
61
61
62
- for{cpu, scan_fn} [id C]
62
+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
63
63
← Mul [id W] (inner_out_sit_sot-0)
64
64
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65
65
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info():
86
86
87
87
expected_output = """Subtensor{i} [id A]
88
88
├─ Subtensor{start:} [id B]
89
- │ ├─ for{cpu, scan_fn} [id C]
89
+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
90
90
│ │ ├─ k [id D]
91
91
│ │ ├─ SetSubtensor{:stop} [id E]
92
92
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info():
117
117
118
118
Inner graphs:
119
119
120
- for{cpu, scan_fn} [id C]
120
+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
121
121
← Mul [id W]
122
122
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
123
123
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
@@ -148,7 +148,7 @@ def test_debugprint_nitsot():
148
148
lines = output_str .split ("\n " )
149
149
150
150
expected_output = """Sum{axes=None} [id A]
151
- └─ for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
151
+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
152
152
├─ Minimum [id C] (outer_in_nit_sot-0)
153
153
│ ├─ Subtensor{i} [id D]
154
154
│ │ ├─ Shape [id E]
@@ -183,7 +183,7 @@ def test_debugprint_nitsot():
183
183
184
184
Inner graphs:
185
185
186
- for{cpu, scan_fn} [id B]
186
+ Scan{ scan_fn, while_loop=False, inplace=none } [id B]
187
187
← Mul [id X] (inner_out_nit_sot-0)
188
188
├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
189
189
└─ Pow [id Z]
@@ -226,7 +226,7 @@ def compute_A_k(A, k):
226
226
lines = output_str .split ("\n " )
227
227
228
228
expected_output = """Sum{axes=None} [id A]
229
- └─ for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
229
+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
230
230
├─ Minimum [id C] (outer_in_nit_sot-0)
231
231
│ ├─ Subtensor{i} [id D]
232
232
│ │ ├─ Shape [id E]
@@ -262,14 +262,14 @@ def compute_A_k(A, k):
262
262
263
263
Inner graphs:
264
264
265
- for{cpu, scan_fn} [id B]
265
+ Scan{ scan_fn, while_loop=False, inplace=none } [id B]
266
266
← Mul [id Y] (inner_out_nit_sot-0)
267
267
├─ ExpandDims{axis=0} [id Z]
268
268
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
269
269
└─ Pow [id BB]
270
270
├─ Subtensor{i} [id BC]
271
271
│ ├─ Subtensor{start:} [id BD]
272
- │ │ ├─ for{cpu, scan_fn} [id BE] (outer_out_sit_sot-0)
272
+ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id BE] (outer_out_sit_sot-0)
273
273
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
274
274
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
275
275
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
@@ -300,7 +300,7 @@ def compute_A_k(A, k):
300
300
└─ ExpandDims{axis=0} [id BY]
301
301
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
302
302
303
- for{cpu, scan_fn} [id BE]
303
+ Scan{ scan_fn, while_loop=False, inplace=none } [id BE]
304
304
← Mul [id CA] (inner_out_sit_sot-0)
305
305
├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306
306
└─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
@@ -319,7 +319,7 @@ def compute_A_k(A, k):
319
319
→ k [id B]
320
320
→ A [id C]
321
321
Sum{axes=None} [id D] 13
322
- └─ for{cpu, scan_fn} [id E] 12 (outer_out_nit_sot-0)
322
+ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id E] 12 (outer_out_nit_sot-0)
323
323
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
324
324
│ ├─ Subtensor{i} [id G] 6
325
325
│ │ ├─ Shape [id H] 5
@@ -355,7 +355,7 @@ def compute_A_k(A, k):
355
355
356
356
Inner graphs:
357
357
358
- for{cpu, scan_fn} [id E]
358
+ Scan{ scan_fn, while_loop=False, inplace=none } [id E]
359
359
→ *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
360
360
→ *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
361
361
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
@@ -366,7 +366,7 @@ def compute_A_k(A, k):
366
366
└─ Pow [id BE]
367
367
├─ Subtensor{i} [id BF]
368
368
│ ├─ Subtensor{start:} [id BG]
369
- │ │ ├─ for{cpu, scan_fn} [id BH] (outer_out_sit_sot-0)
369
+ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id BH] (outer_out_sit_sot-0)
370
370
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
371
371
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
372
372
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
@@ -397,7 +397,7 @@ def compute_A_k(A, k):
397
397
└─ ExpandDims{axis=0} [id BZ]
398
398
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
399
399
400
- for{cpu, scan_fn} [id BH]
400
+ Scan{ scan_fn, while_loop=False, inplace=none } [id BH]
401
401
→ *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402
402
→ *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
403
403
← Mul [id CC] (inner_out_sit_sot-0)
@@ -431,7 +431,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
431
431
432
432
expected_output = """Add [id A]
433
433
├─ Subtensor{start:} [id B]
434
- │ ├─ for{cpu, scan_fn}.0 [id C] (outer_out_mit_sot-0)
434
+ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none }.0 [id C] (outer_out_mit_sot-0)
435
435
│ │ ├─ TensorConstant{5} [id D] (n_steps)
436
436
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
437
437
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
@@ -465,13 +465,13 @@ def fn(a_m2, a_m1, b_m2, b_m1):
465
465
│ │ └─ ···
466
466
│ └─ ScalarConstant{2} [id Y]
467
467
└─ Subtensor{start:} [id Z]
468
- ├─ for{cpu, scan_fn}.1 [id C] (outer_out_mit_sot-1)
468
+ ├─ Scan{ scan_fn, while_loop=False, inplace=none }.1 [id C] (outer_out_mit_sot-1)
469
469
│ └─ ···
470
470
└─ ScalarConstant{2} [id BA]
471
471
472
472
Inner graphs:
473
473
474
- for{cpu, scan_fn} [id C]
474
+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
475
475
← Add [id BB] (inner_out_mit_sot-0)
476
476
├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477
477
└─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
@@ -502,11 +502,11 @@ def test_debugprint_mitmot():
502
502
lines = output_str .split ("\n " )
503
503
504
504
expected_output = """Subtensor{i} [id A]
505
- ├─ for{cpu, grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
505
+ ├─ Scan{ grad_of_scan_fn, while_loop=False, inplace=none }.1 [id B] (outer_out_sit_sot-0)
506
506
│ ├─ Sub [id C] (n_steps)
507
507
│ │ ├─ Subtensor{i} [id D]
508
508
│ │ │ ├─ Shape [id E]
509
- │ │ │ │ └─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
509
+ │ │ │ │ └─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
510
510
│ │ │ │ ├─ k [id G] (n_steps)
511
511
│ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
512
512
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
@@ -537,7 +537,7 @@ def test_debugprint_mitmot():
537
537
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
538
538
│ │ ├─ Subtensor{::step} [id BA]
539
539
│ │ │ ├─ Subtensor{:stop} [id BB]
540
- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
540
+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
541
541
│ │ │ │ │ └─ ···
542
542
│ │ │ │ └─ ScalarConstant{-1} [id BC]
543
543
│ │ │ └─ ScalarConstant{-1} [id BD]
@@ -547,7 +547,7 @@ def test_debugprint_mitmot():
547
547
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
548
548
│ │ ├─ Subtensor{:stop} [id BG]
549
549
│ │ │ ├─ Subtensor{::step} [id BH]
550
- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
550
+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
551
551
│ │ │ │ │ └─ ···
552
552
│ │ │ │ └─ ScalarConstant{-1} [id BI]
553
553
│ │ │ └─ ScalarConstant{-1} [id BJ]
@@ -557,14 +557,14 @@ def test_debugprint_mitmot():
557
557
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
558
558
│ │ ├─ IncSubtensor{start:} [id BM]
559
559
│ │ │ ├─ Second [id BN]
560
- │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
560
+ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
561
561
│ │ │ │ │ └─ ···
562
562
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
563
563
│ │ │ │ └─ TensorConstant{0.0} [id BP]
564
564
│ │ │ ├─ IncSubtensor{i} [id BQ]
565
565
│ │ │ │ ├─ Second [id BR]
566
566
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
567
- │ │ │ │ │ │ ├─ for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
567
+ │ │ │ │ │ │ ├─ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
568
568
│ │ │ │ │ │ │ └─ ···
569
569
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
570
570
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
@@ -598,7 +598,7 @@ def test_debugprint_mitmot():
598
598
599
599
Inner graphs:
600
600
601
- for{cpu, grad_of_scan_fn} [id B]
601
+ Scan{ grad_of_scan_fn, while_loop=False, inplace=none } [id B]
602
602
← Add [id CM] (inner_out_mit_mot-0-0)
603
603
├─ Mul [id CN]
604
604
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
@@ -610,7 +610,7 @@ def test_debugprint_mitmot():
610
610
│ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611
611
└─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
612
612
613
- for{cpu, scan_fn} [id F]
613
+ Scan{ scan_fn, while_loop=False, inplace=none } [id F]
614
614
← Mul [id CV] (inner_out_sit_sot-0)
615
615
├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616
616
└─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
@@ -641,7 +641,7 @@ def no_shared_fn(n, x_tm1, M):
641
641
# (i.e. from `Scan._fn`)
642
642
out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
643
643
644
- expected_output = """forall_inplace,cpu,scan_fn } [id A] 2 (outer_out_sit_sot-0)
644
+ expected_output = """Scan{scan_fn, while_loop=False, inplace=all } [id A] 2 (outer_out_sit_sot-0)
645
645
├─ TensorConstant{20000} [id B] (n_steps)
646
646
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
647
647
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
@@ -653,7 +653,7 @@ def no_shared_fn(n, x_tm1, M):
653
653
654
654
Inner graphs:
655
655
656
- forall_inplace,cpu,scan_fn } [id A]
656
+ Scan{scan_fn, while_loop=False, inplace=all } [id A]
657
657
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658
658
├─ TensorConstant{0} [id J]
659
659
├─ Subtensor{i, j, k} [id K]
0 commit comments