Skip to content

Commit e6861d7

Browse files
committed
Simplify Scan string representation
1 parent 2f86f79 commit e6861d7

File tree

2 files changed

+33
-42
lines changed

2 files changed

+33
-42
lines changed

pytensor/scan/op.py

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1282,27 +1282,18 @@ def __eq__(self, other):
12821282
)
12831283

12841284
def __str__(self):
1285-
device_str = "cpu"
1286-
if self.info.as_while:
1287-
name = "do_while"
1288-
else:
1289-
name = "for"
1290-
aux_txt = "%s"
1285+
inplace = "none"
12911286
if len(self.destroy_map.keys()) > 0:
12921287
# Check if all outputs are inplace
12931288
if sorted(self.destroy_map.keys()) == sorted(
12941289
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
12951290
):
1296-
aux_txt += "all_inplace,%s,%s}"
1291+
inplace = "all"
12971292
else:
1298-
aux_txt += "{inplace{"
1299-
for k in self.destroy_map.keys():
1300-
aux_txt += str(k) + ","
1301-
aux_txt += "},%s,%s}"
1302-
else:
1303-
aux_txt += "{%s,%s}"
1304-
aux_txt = aux_txt % (name, device_str, str(self.name))
1305-
return aux_txt
1293+
inplace = str(list(self.destroy_map.keys()))
1294+
return (
1295+
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
1296+
)
13061297

13071298
def __hash__(self):
13081299
return hash(

tests/scan/test_printing.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
2828

2929
expected_output = """Subtensor{i} [id A]
3030
├─ Subtensor{start:} [id B]
31-
│ ├─ for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
31+
│ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
3232
│ │ ├─ k [id D] (n_steps)
3333
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
3434
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -59,7 +59,7 @@ def test_debugprint_sitsot():
5959
6060
Inner graphs:
6161
62-
for{cpu,scan_fn} [id C]
62+
Scan{scan_fn, while_loop=False, inplace=none} [id C]
6363
← Mul [id W] (inner_out_sit_sot-0)
6464
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
6565
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
@@ -86,7 +86,7 @@ def test_debugprint_sitsot_no_extra_info():
8686

8787
expected_output = """Subtensor{i} [id A]
8888
├─ Subtensor{start:} [id B]
89-
│ ├─ for{cpu,scan_fn} [id C]
89+
│ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C]
9090
│ │ ├─ k [id D]
9191
│ │ ├─ SetSubtensor{:stop} [id E]
9292
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
@@ -117,7 +117,7 @@ def test_debugprint_sitsot_no_extra_info():
117117
118118
Inner graphs:
119119
120-
for{cpu,scan_fn} [id C]
120+
Scan{scan_fn, while_loop=False, inplace=none} [id C]
121121
← Mul [id W]
122122
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
123123
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
@@ -148,7 +148,7 @@ def test_debugprint_nitsot():
148148
lines = output_str.split("\n")
149149

150150
expected_output = """Sum{axes=None} [id A]
151-
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
151+
└─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
152152
├─ Minimum [id C] (outer_in_nit_sot-0)
153153
│ ├─ Subtensor{i} [id D]
154154
│ │ ├─ Shape [id E]
@@ -183,7 +183,7 @@ def test_debugprint_nitsot():
183183
184184
Inner graphs:
185185
186-
for{cpu,scan_fn} [id B]
186+
Scan{scan_fn, while_loop=False, inplace=none} [id B]
187187
← Mul [id X] (inner_out_nit_sot-0)
188188
├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
189189
└─ Pow [id Z]
@@ -226,7 +226,7 @@ def compute_A_k(A, k):
226226
lines = output_str.split("\n")
227227

228228
expected_output = """Sum{axes=None} [id A]
229-
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
229+
└─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
230230
├─ Minimum [id C] (outer_in_nit_sot-0)
231231
│ ├─ Subtensor{i} [id D]
232232
│ │ ├─ Shape [id E]
@@ -262,14 +262,14 @@ def compute_A_k(A, k):
262262
263263
Inner graphs:
264264
265-
for{cpu,scan_fn} [id B]
265+
Scan{scan_fn, while_loop=False, inplace=none} [id B]
266266
← Mul [id Y] (inner_out_nit_sot-0)
267267
├─ ExpandDims{axis=0} [id Z]
268268
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
269269
└─ Pow [id BB]
270270
├─ Subtensor{i} [id BC]
271271
│ ├─ Subtensor{start:} [id BD]
272-
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
272+
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
273273
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
274274
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
275275
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
@@ -300,7 +300,7 @@ def compute_A_k(A, k):
300300
└─ ExpandDims{axis=0} [id BY]
301301
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
302302
303-
for{cpu,scan_fn} [id BE]
303+
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
304304
← Mul [id CA] (inner_out_sit_sot-0)
305305
├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306306
└─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
@@ -319,7 +319,7 @@ def compute_A_k(A, k):
319319
→ k [id B]
320320
→ A [id C]
321321
Sum{axes=None} [id D] 13
322-
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
322+
└─ Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0)
323323
├─ Minimum [id F] 7 (outer_in_nit_sot-0)
324324
│ ├─ Subtensor{i} [id G] 6
325325
│ │ ├─ Shape [id H] 5
@@ -355,7 +355,7 @@ def compute_A_k(A, k):
355355
356356
Inner graphs:
357357
358-
for{cpu,scan_fn} [id E]
358+
Scan{scan_fn, while_loop=False, inplace=none} [id E]
359359
→ *0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
360360
→ *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
361361
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
@@ -366,7 +366,7 @@ def compute_A_k(A, k):
366366
└─ Pow [id BE]
367367
├─ Subtensor{i} [id BF]
368368
│ ├─ Subtensor{start:} [id BG]
369-
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
369+
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
370370
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
371371
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
372372
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
@@ -397,7 +397,7 @@ def compute_A_k(A, k):
397397
└─ ExpandDims{axis=0} [id BZ]
398398
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
399399
400-
for{cpu,scan_fn} [id BH]
400+
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
401401
→ *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402402
→ *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
403403
← Mul [id CC] (inner_out_sit_sot-0)
@@ -431,7 +431,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
431431

432432
expected_output = """Add [id A]
433433
├─ Subtensor{start:} [id B]
434-
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
434+
│ ├─ Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
435435
│ │ ├─ TensorConstant{5} [id D] (n_steps)
436436
│ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
437437
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
@@ -465,13 +465,13 @@ def fn(a_m2, a_m1, b_m2, b_m1):
465465
│ │ └─ ···
466466
│ └─ ScalarConstant{2} [id Y]
467467
└─ Subtensor{start:} [id Z]
468-
├─ for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
468+
├─ Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
469469
│ └─ ···
470470
└─ ScalarConstant{2} [id BA]
471471
472472
Inner graphs:
473473
474-
for{cpu,scan_fn} [id C]
474+
Scan{scan_fn, while_loop=False, inplace=none} [id C]
475475
← Add [id BB] (inner_out_mit_sot-0)
476476
├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477477
└─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
@@ -502,11 +502,11 @@ def test_debugprint_mitmot():
502502
lines = output_str.split("\n")
503503

504504
expected_output = """Subtensor{i} [id A]
505-
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
505+
├─ Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
506506
│ ├─ Sub [id C] (n_steps)
507507
│ │ ├─ Subtensor{i} [id D]
508508
│ │ │ ├─ Shape [id E]
509-
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
509+
│ │ │ │ └─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
510510
│ │ │ │ ├─ k [id G] (n_steps)
511511
│ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
512512
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
@@ -537,7 +537,7 @@ def test_debugprint_mitmot():
537537
│ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
538538
│ │ ├─ Subtensor{::step} [id BA]
539539
│ │ │ ├─ Subtensor{:stop} [id BB]
540-
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
540+
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
541541
│ │ │ │ │ └─ ···
542542
│ │ │ │ └─ ScalarConstant{-1} [id BC]
543543
│ │ │ └─ ScalarConstant{-1} [id BD]
@@ -547,7 +547,7 @@ def test_debugprint_mitmot():
547547
│ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
548548
│ │ ├─ Subtensor{:stop} [id BG]
549549
│ │ │ ├─ Subtensor{::step} [id BH]
550-
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
550+
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
551551
│ │ │ │ │ └─ ···
552552
│ │ │ │ └─ ScalarConstant{-1} [id BI]
553553
│ │ │ └─ ScalarConstant{-1} [id BJ]
@@ -557,14 +557,14 @@ def test_debugprint_mitmot():
557557
│ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
558558
│ │ ├─ IncSubtensor{start:} [id BM]
559559
│ │ │ ├─ Second [id BN]
560-
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
560+
│ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
561561
│ │ │ │ │ └─ ···
562562
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
563563
│ │ │ │ └─ TensorConstant{0.0} [id BP]
564564
│ │ │ ├─ IncSubtensor{i} [id BQ]
565565
│ │ │ │ ├─ Second [id BR]
566566
│ │ │ │ │ ├─ Subtensor{start:} [id BS]
567-
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
567+
│ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
568568
│ │ │ │ │ │ │ └─ ···
569569
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
570570
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
@@ -598,7 +598,7 @@ def test_debugprint_mitmot():
598598
599599
Inner graphs:
600600
601-
for{cpu,grad_of_scan_fn} [id B]
601+
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
602602
← Add [id CM] (inner_out_mit_mot-0-0)
603603
├─ Mul [id CN]
604604
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
@@ -610,7 +610,7 @@ def test_debugprint_mitmot():
610610
│ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611611
└─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
612612
613-
for{cpu,scan_fn} [id F]
613+
Scan{scan_fn, while_loop=False, inplace=none} [id F]
614614
← Mul [id CV] (inner_out_sit_sot-0)
615615
├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616616
└─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
@@ -641,7 +641,7 @@ def no_shared_fn(n, x_tm1, M):
641641
# (i.e. from `Scan._fn`)
642642
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
643643

644-
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
644+
expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
645645
├─ TensorConstant{20000} [id B] (n_steps)
646646
├─ TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
647647
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
@@ -653,7 +653,7 @@ def no_shared_fn(n, x_tm1, M):
653653
654654
Inner graphs:
655655
656-
forall_inplace,cpu,scan_fn} [id A]
656+
Scan{scan_fn, while_loop=False, inplace=all} [id A]
657657
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658658
├─ TensorConstant{0} [id J]
659659
├─ Subtensor{i, j, k} [id K]

0 commit comments

Comments
 (0)