@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
28
28
29
29
expected_output = """Subtensor{i} [id A]
30
30
|Subtensor{start:} [id B]
31
- | |for{cpu, scan_fn} [id C] (outer_out_sit_sot-0)
31
+ | |Scan{ scan_fn, while_loop=False, inplace=none } [id C] (outer_out_sit_sot-0)
32
32
| | |k [id D] (n_steps)
33
33
| | |SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
34
34
| | | |AllocEmpty{dtype='float64'} [id F]
@@ -56,7 +56,7 @@ def test_debugprint_sitsot():
56
56
57
57
Inner graphs:
58
58
59
- for{cpu, scan_fn} [id C] (outer_out_sit_sot-0)
59
+ Scan{ scan_fn, while_loop=False, inplace=none } [id C] (outer_out_sit_sot-0)
60
60
>Mul [id W] (inner_out_sit_sot-0)
61
61
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0)
62
62
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
@@ -83,7 +83,7 @@ def test_debugprint_sitsot_no_extra_info():
83
83
84
84
expected_output = """Subtensor{i} [id A]
85
85
|Subtensor{start:} [id B]
86
- | |for{cpu, scan_fn} [id C]
86
+ | |Scan{ scan_fn, while_loop=False, inplace=none } [id C]
87
87
| | |k [id D]
88
88
| | |SetSubtensor{:stop} [id E]
89
89
| | | |AllocEmpty{dtype='float64'} [id F]
@@ -111,7 +111,7 @@ def test_debugprint_sitsot_no_extra_info():
111
111
112
112
Inner graphs:
113
113
114
- for{cpu, scan_fn} [id C]
114
+ Scan{ scan_fn, while_loop=False, inplace=none } [id C]
115
115
>Mul [id W]
116
116
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E]
117
117
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M]"""
@@ -142,7 +142,7 @@ def test_debugprint_nitsot():
142
142
lines = output_str .split ("\n " )
143
143
144
144
expected_output = """Sum{axes=None} [id A]
145
- |for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
145
+ |Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
146
146
|Minimum [id C] (outer_in_nit_sot-0)
147
147
| |Subtensor{i} [id D]
148
148
| | |Shape [id E]
@@ -172,7 +172,7 @@ def test_debugprint_nitsot():
172
172
173
173
Inner graphs:
174
174
175
- for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
175
+ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
176
176
>Mul [id X] (inner_out_nit_sot-0)
177
177
> |*0-<Scalar(float64, shape=())> [id Y] -> [id S] (inner_in_seqs-0)
178
178
> |Pow [id Z]
@@ -215,7 +215,7 @@ def compute_A_k(A, k):
215
215
lines = output_str .split ("\n " )
216
216
217
217
expected_output = """Sum{axes=None} [id A]
218
- |for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
218
+ |Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
219
219
|Minimum [id C] (outer_in_nit_sot-0)
220
220
| |Subtensor{i} [id D]
221
221
| | |Shape [id E]
@@ -246,14 +246,14 @@ def compute_A_k(A, k):
246
246
247
247
Inner graphs:
248
248
249
- for{cpu, scan_fn} [id B] (outer_out_nit_sot-0)
249
+ Scan{ scan_fn, while_loop=False, inplace=none } [id B] (outer_out_nit_sot-0)
250
250
>Mul [id Y] (inner_out_nit_sot-0)
251
251
> |ExpandDims{axis=0} [id Z]
252
252
> | |*0-<Scalar(float64, shape=())> [id BA] -> [id S] (inner_in_seqs-0)
253
253
> |Pow [id BB]
254
254
> |Subtensor{i} [id BC]
255
255
> | |Subtensor{start:} [id BD]
256
- > | | |for{cpu, scan_fn} [id BE] (outer_out_sit_sot-0)
256
+ > | | |Scan{ scan_fn, while_loop=False, inplace=none } [id BE] (outer_out_sit_sot-0)
257
257
> | | | |*3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
258
258
> | | | |SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
259
259
> | | | | |AllocEmpty{dtype='float64'} [id BH]
@@ -281,7 +281,7 @@ def compute_A_k(A, k):
281
281
> |ExpandDims{axis=0} [id BY]
282
282
> |*1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1)
283
283
284
- for{cpu, scan_fn} [id BE] (outer_out_sit_sot-0)
284
+ Scan{ scan_fn, while_loop=False, inplace=none } [id BE] (outer_out_sit_sot-0)
285
285
>Mul [id CA] (inner_out_sit_sot-0)
286
286
> |*0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
287
287
> |*1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
@@ -300,7 +300,7 @@ def compute_A_k(A, k):
300
300
-k [id B]
301
301
-A [id C]
302
302
Sum{axes=None} [id D] 13
303
- |for{cpu, scan_fn} [id E] 12 (outer_out_nit_sot-0)
303
+ |Scan{ scan_fn, while_loop=False, inplace=none } [id E] 12 (outer_out_nit_sot-0)
304
304
|Minimum [id F] 7 (outer_in_nit_sot-0)
305
305
| |Subtensor{i} [id G] 6
306
306
| | |Shape [id H] 5
@@ -331,7 +331,7 @@ def compute_A_k(A, k):
331
331
332
332
Inner graphs:
333
333
334
- for{cpu, scan_fn} [id E] (outer_out_nit_sot-0)
334
+ Scan{ scan_fn, while_loop=False, inplace=none } [id E] (outer_out_nit_sot-0)
335
335
-*0-<Scalar(float64, shape=())> [id Y] -> [id U] (inner_in_seqs-0)
336
336
-*1-<Scalar(int64, shape=())> [id Z] -> [id W] (inner_in_seqs-1)
337
337
-*2-<Vector(float64, shape=(?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
@@ -342,7 +342,7 @@ def compute_A_k(A, k):
342
342
> |Pow [id BE]
343
343
> |Subtensor{i} [id BF]
344
344
> | |Subtensor{start:} [id BG]
345
- > | | |for{cpu, scan_fn} [id BH] (outer_out_sit_sot-0)
345
+ > | | |Scan{ scan_fn, while_loop=False, inplace=none } [id BH] (outer_out_sit_sot-0)
346
346
> | | | |*3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1) (n_steps)
347
347
> | | | |SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
348
348
> | | | | |AllocEmpty{dtype='float64'} [id BJ]
@@ -370,7 +370,7 @@ def compute_A_k(A, k):
370
370
> |ExpandDims{axis=0} [id BZ]
371
371
> |*1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
372
372
373
- for{cpu, scan_fn} [id BH] (outer_out_sit_sot-0)
373
+ Scan{ scan_fn, while_loop=False, inplace=none } [id BH] (outer_out_sit_sot-0)
374
374
-*0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
375
375
-*1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
376
376
>Mul [id CC] (inner_out_sit_sot-0)
@@ -404,7 +404,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
404
404
405
405
expected_output = """Add [id A]
406
406
|Subtensor{start:} [id B]
407
- | |for{cpu, scan_fn}.0 [id C] (outer_out_mit_sot-0)
407
+ | |Scan{ scan_fn, while_loop=False, inplace=none }.0 [id C] (outer_out_mit_sot-0)
408
408
| | |TensorConstant{5} [id D] (n_steps)
409
409
| | |SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
410
410
| | | |AllocEmpty{dtype='int64'} [id F]
@@ -434,20 +434,20 @@ def fn(a_m2, a_m1, b_m2, b_m1):
434
434
| | |Subtensor{i} [id R]
435
435
| |ScalarConstant{2} [id Y]
436
436
|Subtensor{start:} [id Z]
437
- |for{cpu, scan_fn}.1 [id C] (outer_out_mit_sot-1)
437
+ |Scan{ scan_fn, while_loop=False, inplace=none }.1 [id C] (outer_out_mit_sot-1)
438
438
|ScalarConstant{2} [id BA]
439
439
440
440
Inner graphs:
441
441
442
- for{cpu, scan_fn}.0 [id C] (outer_out_mit_sot-0)
442
+ Scan{ scan_fn, while_loop=False, inplace=none }.0 [id C] (outer_out_mit_sot-0)
443
443
>Add [id BB] (inner_out_mit_sot-0)
444
444
> |*1-<Scalar(int64, shape=())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
445
445
> |*0-<Scalar(int64, shape=())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
446
446
>Add [id BE] (inner_out_mit_sot-1)
447
447
> |*3-<Scalar(int64, shape=())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
448
448
> |*2-<Scalar(int64, shape=())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
449
449
450
- for{cpu, scan_fn}.1 [id C] (outer_out_mit_sot-1)
450
+ Scan{ scan_fn, while_loop=False, inplace=none }.1 [id C] (outer_out_mit_sot-1)
451
451
>Add [id BB] (inner_out_mit_sot-0)
452
452
>Add [id BE] (inner_out_mit_sot-1)"""
453
453
@@ -474,11 +474,11 @@ def test_debugprint_mitmot():
474
474
lines = output_str .split ("\n " )
475
475
476
476
expected_output = """Subtensor{i} [id A]
477
- |for{cpu, grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
477
+ |Scan{ grad_of_scan_fn, while_loop=False, inplace=none }.1 [id B] (outer_out_sit_sot-0)
478
478
| |Sub [id C] (n_steps)
479
479
| | |Subtensor{i} [id D]
480
480
| | | |Shape [id E]
481
- | | | | |for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
481
+ | | | | |Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
482
482
| | | | |k [id G] (n_steps)
483
483
| | | | |SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
484
484
| | | | | |AllocEmpty{dtype='float64'} [id I]
@@ -506,29 +506,29 @@ def test_debugprint_mitmot():
506
506
| |Subtensor{:stop} [id Z] (outer_in_seqs-0)
507
507
| | |Subtensor{::step} [id BA]
508
508
| | | |Subtensor{:stop} [id BB]
509
- | | | | |for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
509
+ | | | | |Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
510
510
| | | | |ScalarConstant{-1} [id BC]
511
511
| | | |ScalarConstant{-1} [id BD]
512
512
| | |ScalarFromTensor [id BE]
513
513
| | |Sub [id C]
514
514
| |Subtensor{:stop} [id BF] (outer_in_seqs-1)
515
515
| | |Subtensor{:stop} [id BG]
516
516
| | | |Subtensor{::step} [id BH]
517
- | | | | |for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
517
+ | | | | |Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
518
518
| | | | |ScalarConstant{-1} [id BI]
519
519
| | | |ScalarConstant{-1} [id BJ]
520
520
| | |ScalarFromTensor [id BK]
521
521
| | |Sub [id C]
522
522
| |Subtensor{::step} [id BL] (outer_in_mit_mot-0)
523
523
| | |IncSubtensor{start:} [id BM]
524
524
| | | |Second [id BN]
525
- | | | | |for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
525
+ | | | | |Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
526
526
| | | | |ExpandDims{axes=[0, 1]} [id BO]
527
527
| | | | |TensorConstant{0.0} [id BP]
528
528
| | | |IncSubtensor{i} [id BQ]
529
529
| | | | |Second [id BR]
530
530
| | | | | |Subtensor{start:} [id BS]
531
- | | | | | | |for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
531
+ | | | | | | |Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
532
532
| | | | | | |ScalarConstant{1} [id BT]
533
533
| | | | | |ExpandDims{axes=[0, 1]} [id BU]
534
534
| | | | | |TensorConstant{0.0} [id BV]
@@ -558,7 +558,7 @@ def test_debugprint_mitmot():
558
558
559
559
Inner graphs:
560
560
561
- for{cpu, grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
561
+ Scan{ grad_of_scan_fn, while_loop=False, inplace=none }.1 [id B] (outer_out_sit_sot-0)
562
562
>Add [id CM] (inner_out_mit_mot-0-0)
563
563
> |Mul [id CN]
564
564
> | |*2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
@@ -570,7 +570,7 @@ def test_debugprint_mitmot():
570
570
> | |*0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0)
571
571
> |*4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
572
572
573
- for{cpu, scan_fn} [id F] (outer_out_sit_sot-0)
573
+ Scan{ scan_fn, while_loop=False, inplace=none } [id F] (outer_out_sit_sot-0)
574
574
>Mul [id CV] (inner_out_sit_sot-0)
575
575
> |*0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
576
576
> |*1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
@@ -601,7 +601,7 @@ def no_shared_fn(n, x_tm1, M):
601
601
# (i.e. from `Scan._fn`)
602
602
out = pytensor .function ([M ], out , updates = updates , mode = "FAST_RUN" )
603
603
604
- expected_output = """forall_inplace,cpu,scan_fn } [id A] 2 (outer_out_sit_sot-0)
604
+ expected_output = """Scan{scan_fn, while_loop=False, inplace=all } [id A] 2 (outer_out_sit_sot-0)
605
605
|TensorConstant{20000} [id B] (n_steps)
606
606
|TensorConstant{[ 0 ... 998 19999]} [id C] (outer_in_seqs-0)
607
607
|SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
@@ -613,7 +613,7 @@ def no_shared_fn(n, x_tm1, M):
613
613
614
614
Inner graphs:
615
615
616
- forall_inplace,cpu,scan_fn } [id A] (outer_out_sit_sot-0)
616
+ Scan{scan_fn, while_loop=False, inplace=all } [id A] (outer_out_sit_sot-0)
617
617
>Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
618
618
> |TensorConstant{0} [id J]
619
619
> |Subtensor{i, j, k} [id K]
0 commit comments