Skip to content

Commit 29a630b

Browse files
committed
Simplify Scan string representation
1 parent c2cee45 commit 29a630b

File tree

2 files changed

+34
-43
lines changed

2 files changed

+34
-43
lines changed

pytensor/scan/op.py

+6-15
Original file line numberDiff line numberDiff line change
@@ -1282,27 +1282,18 @@ def __eq__(self, other):
12821282
)
12831283

12841284
def __str__(self):
1285-
device_str = "cpu"
1286-
if self.info.as_while:
1287-
name = "do_while"
1288-
else:
1289-
name = "for"
1290-
aux_txt = "%s"
1285+
inplace = "none"
12911286
if len(self.destroy_map.keys()) > 0:
12921287
# Check if all outputs are inplace
12931288
if sorted(self.destroy_map.keys()) == sorted(
12941289
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
12951290
):
1296-
aux_txt += "all_inplace,%s,%s}"
1291+
inplace = "all"
12971292
else:
1298-
aux_txt += "{inplace{"
1299-
for k in self.destroy_map.keys():
1300-
aux_txt += str(k) + ","
1301-
aux_txt += "},%s,%s}"
1302-
else:
1303-
aux_txt += "{%s,%s}"
1304-
aux_txt = aux_txt % (name, device_str, str(self.name))
1305-
return aux_txt
1293+
inplace = str(list(self.destroy_map.keys()))
1294+
return (
1295+
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
1296+
)
13061297

13071298
def __hash__(self):
13081299
return hash(

tests/scan/test_printing.py

+28-28
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def test_debugprint_sitsot():
2828

2929
expected_output = """Subtensor{i} [id A]
3030
|Subtensor{start:} [id B]
31-
| |for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
31+
| |Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
3232
| | |k [id D] (n_steps)
3333
| | |SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
3434
| | | |AllocEmpty{dtype='float64'} [id F]
@@ -56,7 +56,7 @@ def test_debugprint_sitsot():
5656
5757
Inner graphs:
5858
59-
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
59+
Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
6060
>Mul [id W] (inner_out_sit_sot-0)
6161
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0)
6262
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
@@ -83,7 +83,7 @@ def test_debugprint_sitsot_no_extra_info():
8383

8484
expected_output = """Subtensor{i} [id A]
8585
|Subtensor{start:} [id B]
86-
| |for{cpu,scan_fn} [id C]
86+
| |Scan{scan_fn, while_loop=False, inplace=none} [id C]
8787
| | |k [id D]
8888
| | |SetSubtensor{:stop} [id E]
8989
| | | |AllocEmpty{dtype='float64'} [id F]
@@ -111,7 +111,7 @@ def test_debugprint_sitsot_no_extra_info():
111111
112112
Inner graphs:
113113
114-
for{cpu,scan_fn} [id C]
114+
Scan{scan_fn, while_loop=False, inplace=none} [id C]
115115
>Mul [id W]
116116
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E]
117117
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M]"""
@@ -142,7 +142,7 @@ def test_debugprint_nitsot():
142142
lines = output_str.split("\n")
143143

144144
expected_output = """Sum{axes=None} [id A]
145-
|for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
145+
|Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
146146
|Minimum [id C] (outer_in_nit_sot-0)
147147
| |Subtensor{i} [id D]
148148
| | |Shape [id E]
@@ -172,7 +172,7 @@ def test_debugprint_nitsot():
172172
173173
Inner graphs:
174174
175-
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
175+
Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
176176
>Mul [id X] (inner_out_nit_sot-0)
177177
> |*0-<Scalar(float64, shape=())> [id Y] -> [id S] (inner_in_seqs-0)
178178
> |Pow [id Z]
@@ -215,7 +215,7 @@ def compute_A_k(A, k):
215215
lines = output_str.split("\n")
216216

217217
expected_output = """Sum{axes=None} [id A]
218-
|for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
218+
|Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
219219
|Minimum [id C] (outer_in_nit_sot-0)
220220
| |Subtensor{i} [id D]
221221
| | |Shape [id E]
@@ -246,14 +246,14 @@ def compute_A_k(A, k):
246246
247247
Inner graphs:
248248
249-
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
249+
Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
250250
>Mul [id Y] (inner_out_nit_sot-0)
251251
> |ExpandDims{axis=0} [id Z]
252252
> | |*0-<Scalar(float64, shape=())> [id BA] -> [id S] (inner_in_seqs-0)
253253
> |Pow [id BB]
254254
> |Subtensor{i} [id BC]
255255
> | |Subtensor{start:} [id BD]
256-
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
256+
> | | |Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
257257
> | | | |*3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
258258
> | | | |SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
259259
> | | | | |AllocEmpty{dtype='float64'} [id BH]
@@ -281,7 +281,7 @@ def compute_A_k(A, k):
281281
> |ExpandDims{axis=0} [id BY]
282282
> |*1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1)
283283
284-
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
284+
Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
285285
>Mul [id CA] (inner_out_sit_sot-0)
286286
> |*0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
287287
> |*1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
@@ -300,7 +300,7 @@ def compute_A_k(A, k):
300300
-k [id B]
301301
-A [id C]
302302
Sum{axes=None} [id D] 13
303-
|for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
303+
|Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0)
304304
|Minimum [id F] 7 (outer_in_nit_sot-0)
305305
| |Subtensor{i} [id G] 6
306306
| | |Shape [id H] 5
@@ -331,7 +331,7 @@ def compute_A_k(A, k):
331331
332332
Inner graphs:
333333
334-
for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
334+
Scan{scan_fn, while_loop=False, inplace=none} [id E] (outer_out_nit_sot-0)
335335
-*0-<Scalar(float64, shape=())> [id Y] -> [id U] (inner_in_seqs-0)
336336
-*1-<Scalar(int64, shape=())> [id Z] -> [id W] (inner_in_seqs-1)
337337
-*2-<Vector(float64, shape=(?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
@@ -342,7 +342,7 @@ def compute_A_k(A, k):
342342
> |Pow [id BE]
343343
> |Subtensor{i} [id BF]
344344
> | |Subtensor{start:} [id BG]
345-
> | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
345+
> | | |Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
346346
> | | | |*3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1) (n_steps)
347347
> | | | |SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
348348
> | | | | |AllocEmpty{dtype='float64'} [id BJ]
@@ -370,7 +370,7 @@ def compute_A_k(A, k):
370370
> |ExpandDims{axis=0} [id BZ]
371371
> |*1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
372372
373-
for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
373+
Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
374374
-*0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
375375
-*1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
376376
>Mul [id CC] (inner_out_sit_sot-0)
@@ -404,7 +404,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
404404

405405
expected_output = """Add [id A]
406406
|Subtensor{start:} [id B]
407-
| |for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
407+
| |Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
408408
| | |TensorConstant{5} [id D] (n_steps)
409409
| | |SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
410410
| | | |AllocEmpty{dtype='int64'} [id F]
@@ -434,20 +434,20 @@ def fn(a_m2, a_m1, b_m2, b_m1):
434434
| | |Subtensor{i} [id R]
435435
| |ScalarConstant{2} [id Y]
436436
|Subtensor{start:} [id Z]
437-
|for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
437+
|Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
438438
|ScalarConstant{2} [id BA]
439439
440440
Inner graphs:
441441
442-
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
442+
Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
443443
>Add [id BB] (inner_out_mit_sot-0)
444444
> |*1-<Scalar(int64, shape=())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
445445
> |*0-<Scalar(int64, shape=())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
446446
>Add [id BE] (inner_out_mit_sot-1)
447447
> |*3-<Scalar(int64, shape=())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
448448
> |*2-<Scalar(int64, shape=())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
449449
450-
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
450+
Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
451451
>Add [id BB] (inner_out_mit_sot-0)
452452
>Add [id BE] (inner_out_mit_sot-1)"""
453453

@@ -474,11 +474,11 @@ def test_debugprint_mitmot():
474474
lines = output_str.split("\n")
475475

476476
expected_output = """Subtensor{i} [id A]
477-
|for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
477+
|Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
478478
| |Sub [id C] (n_steps)
479479
| | |Subtensor{i} [id D]
480480
| | | |Shape [id E]
481-
| | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
481+
| | | | |Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
482482
| | | | |k [id G] (n_steps)
483483
| | | | |SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
484484
| | | | | |AllocEmpty{dtype='float64'} [id I]
@@ -506,29 +506,29 @@ def test_debugprint_mitmot():
506506
| |Subtensor{:stop} [id Z] (outer_in_seqs-0)
507507
| | |Subtensor{::step} [id BA]
508508
| | | |Subtensor{:stop} [id BB]
509-
| | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
509+
| | | | |Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
510510
| | | | |ScalarConstant{-1} [id BC]
511511
| | | |ScalarConstant{-1} [id BD]
512512
| | |ScalarFromTensor [id BE]
513513
| | |Sub [id C]
514514
| |Subtensor{:stop} [id BF] (outer_in_seqs-1)
515515
| | |Subtensor{:stop} [id BG]
516516
| | | |Subtensor{::step} [id BH]
517-
| | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
517+
| | | | |Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
518518
| | | | |ScalarConstant{-1} [id BI]
519519
| | | |ScalarConstant{-1} [id BJ]
520520
| | |ScalarFromTensor [id BK]
521521
| | |Sub [id C]
522522
| |Subtensor{::step} [id BL] (outer_in_mit_mot-0)
523523
| | |IncSubtensor{start:} [id BM]
524524
| | | |Second [id BN]
525-
| | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
525+
| | | | |Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
526526
| | | | |ExpandDims{axes=[0, 1]} [id BO]
527527
| | | | |TensorConstant{0.0} [id BP]
528528
| | | |IncSubtensor{i} [id BQ]
529529
| | | | |Second [id BR]
530530
| | | | | |Subtensor{start:} [id BS]
531-
| | | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
531+
| | | | | | |Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
532532
| | | | | | |ScalarConstant{1} [id BT]
533533
| | | | | |ExpandDims{axes=[0, 1]} [id BU]
534534
| | | | | |TensorConstant{0.0} [id BV]
@@ -558,7 +558,7 @@ def test_debugprint_mitmot():
558558
559559
Inner graphs:
560560
561-
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
561+
Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
562562
>Add [id CM] (inner_out_mit_mot-0-0)
563563
> |Mul [id CN]
564564
> | |*2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
@@ -570,7 +570,7 @@ def test_debugprint_mitmot():
570570
> | |*0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0)
571571
> |*4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
572572
573-
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
573+
Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
574574
>Mul [id CV] (inner_out_sit_sot-0)
575575
> |*0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
576576
> |*1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
@@ -601,7 +601,7 @@ def no_shared_fn(n, x_tm1, M):
601601
# (i.e. from `Scan._fn`)
602602
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
603603

604-
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
604+
expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
605605
|TensorConstant{20000} [id B] (n_steps)
606606
|TensorConstant{[ 0 ... 998 19999]} [id C] (outer_in_seqs-0)
607607
|SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
@@ -613,7 +613,7 @@ def no_shared_fn(n, x_tm1, M):
613613
614614
Inner graphs:
615615
616-
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
616+
Scan{scan_fn, while_loop=False, inplace=all} [id A] (outer_out_sit_sot-0)
617617
>Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
618618
> |TensorConstant{0} [id J]
619619
> |Subtensor{i, j, k} [id K]

0 commit comments

Comments
 (0)