@@ -61,8 +61,8 @@ def test_debugprint_sitsot():
61
61
62
62
Scan{scan_fn, while_loop=False, inplace=none} [id C]
63
63
← Mul [id W] (inner_out_sit_sot-0)
64
- ├─ *0-<TensorType (float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65
- └─ *1-<TensorType (float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
64
+ ├─ *0-<Vector (float64, shape= (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
65
+ └─ *1-<Vector (float64, shape= (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
66
66
67
67
for truth , out in zip (expected_output .split ("\n " ), lines ):
68
68
assert truth .strip () == out .strip ()
@@ -119,8 +119,8 @@ def test_debugprint_sitsot_no_extra_info():
119
119
120
120
Scan{scan_fn, while_loop=False, inplace=none} [id C]
121
121
← Mul [id W]
122
- ├─ *0-<TensorType (float64, (?,))> [id X] -> [id E]
123
- └─ *1-<TensorType (float64, (?,))> [id Y] -> [id M]"""
122
+ ├─ *0-<Vector (float64, shape= (?,))> [id X] -> [id E]
123
+ └─ *1-<Vector (float64, shape= (?,))> [id Y] -> [id M]"""
124
124
125
125
for truth , out in zip (expected_output .split ("\n " ), lines ):
126
126
assert truth .strip () == out .strip ()
@@ -185,10 +185,10 @@ def test_debugprint_nitsot():
185
185
186
186
Scan{scan_fn, while_loop=False, inplace=none} [id B]
187
187
← Mul [id X] (inner_out_nit_sot-0)
188
- ├─ *0-<TensorType (float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
188
+ ├─ *0-<Scalar (float64, shape= ())> [id Y] -> [id S] (inner_in_seqs-0)
189
189
└─ Pow [id Z]
190
- ├─ *2-<TensorType (float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
191
- └─ *1-<TensorType (int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
190
+ ├─ *2-<Scalar (float64, shape= ())> [id BA] -> [id W] (inner_in_non_seqs-0)
191
+ └─ *1-<Scalar (int64, shape= ())> [id BB] -> [id U] (inner_in_seqs-1)"""
192
192
193
193
for truth , out in zip (expected_output .split ("\n " ), lines ):
194
194
assert truth .strip () == out .strip ()
@@ -265,22 +265,22 @@ def compute_A_k(A, k):
265
265
Scan{scan_fn, while_loop=False, inplace=none} [id B]
266
266
← Mul [id Y] (inner_out_nit_sot-0)
267
267
├─ ExpandDims{axis=0} [id Z]
268
- │ └─ *0-<TensorType (float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
268
+ │ └─ *0-<Scalar (float64, shape= ())> [id BA] -> [id S] (inner_in_seqs-0)
269
269
└─ Pow [id BB]
270
270
├─ Subtensor{i} [id BC]
271
271
│ ├─ Subtensor{start:} [id BD]
272
272
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
273
- │ │ │ ├─ *3-<TensorType (int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
273
+ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
274
274
│ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
275
275
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
276
276
│ │ │ │ │ ├─ Add [id BI]
277
- │ │ │ │ │ │ ├─ *3-<TensorType (int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
277
+ │ │ │ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BF] -> [id X] (inner_in_non_seqs-1)
278
278
│ │ │ │ │ │ └─ Subtensor{i} [id BJ]
279
279
│ │ │ │ │ │ ├─ Shape [id BK]
280
280
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
281
281
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
282
282
│ │ │ │ │ │ │ └─ Second [id BN]
283
- │ │ │ │ │ │ │ ├─ *2-<TensorType (float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
283
+ │ │ │ │ │ │ │ ├─ *2-<Vector (float64, shape= (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
284
284
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
285
285
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
286
286
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
@@ -294,16 +294,16 @@ def compute_A_k(A, k):
294
294
│ │ │ │ └─ ScalarFromTensor [id BV]
295
295
│ │ │ │ └─ Subtensor{i} [id BJ]
296
296
│ │ │ │ └─ ···
297
- │ │ │ └─ *2-<TensorType (float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
297
+ │ │ │ └─ *2-<Vector (float64, shape= (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
298
298
│ │ └─ ScalarConstant{1} [id BW]
299
299
│ └─ ScalarConstant{-1} [id BX]
300
300
└─ ExpandDims{axis=0} [id BY]
301
- └─ *1-<TensorType (int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
301
+ └─ *1-<Scalar (int64, shape= ())> [id BZ] -> [id U] (inner_in_seqs-1)
302
302
303
303
Scan{scan_fn, while_loop=False, inplace=none} [id BE]
304
304
← Mul [id CA] (inner_out_sit_sot-0)
305
- ├─ *0-<TensorType (float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306
- └─ *1-<TensorType (float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
305
+ ├─ *0-<Vector (float64, shape= (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
306
+ └─ *1-<Vector (float64, shape= (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
307
307
308
308
for truth , out in zip (expected_output .split ("\n " ), lines ):
309
309
assert truth .strip () == out .strip ()
@@ -356,28 +356,28 @@ def compute_A_k(A, k):
356
356
Inner graphs:
357
357
358
358
Scan{scan_fn, while_loop=False, inplace=none} [id E]
359
- → *0-<TensorType (float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
360
- → *1-<TensorType (int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
361
- → *2-<TensorType (float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
362
- → *3-<TensorType (int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
359
+ → *0-<Scalar (float64, shape= ())> [id Y] -> [id U] (inner_in_seqs-0)
360
+ → *1-<Scalar (int64, shape= ())> [id Z] -> [id W] (inner_in_seqs-1)
361
+ → *2-<Vector (float64, shape= (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
362
+ → *3-<Scalar (int32, shape= ())> [id BB] -> [id B] (inner_in_non_seqs-1)
363
363
← Mul [id BC] (inner_out_nit_sot-0)
364
364
├─ ExpandDims{axis=0} [id BD]
365
- │ └─ *0-<TensorType (float64, ())> [id Y] (inner_in_seqs-0)
365
+ │ └─ *0-<Scalar (float64, shape= ())> [id Y] (inner_in_seqs-0)
366
366
└─ Pow [id BE]
367
367
├─ Subtensor{i} [id BF]
368
368
│ ├─ Subtensor{start:} [id BG]
369
369
│ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
370
- │ │ │ ├─ *3-<TensorType (int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
370
+ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BB] (inner_in_non_seqs-1) (n_steps)
371
371
│ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
372
372
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
373
373
│ │ │ │ │ ├─ Add [id BK]
374
- │ │ │ │ │ │ ├─ *3-<TensorType (int32, ())> [id BB] (inner_in_non_seqs-1)
374
+ │ │ │ │ │ │ ├─ *3-<Scalar (int32, shape= ())> [id BB] (inner_in_non_seqs-1)
375
375
│ │ │ │ │ │ └─ Subtensor{i} [id BL]
376
376
│ │ │ │ │ │ ├─ Shape [id BM]
377
377
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
378
378
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
379
379
│ │ │ │ │ │ │ └─ Second [id BP]
380
- │ │ │ │ │ │ │ ├─ *2-<TensorType (float64, (?,))> [id BA] (inner_in_non_seqs-0)
380
+ │ │ │ │ │ │ │ ├─ *2-<Vector (float64, shape= (?,))> [id BA] (inner_in_non_seqs-0)
381
381
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
382
382
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
383
383
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
@@ -391,18 +391,18 @@ def compute_A_k(A, k):
391
391
│ │ │ │ └─ ScalarFromTensor [id BW]
392
392
│ │ │ │ └─ Subtensor{i} [id BL]
393
393
│ │ │ │ └─ ···
394
- │ │ │ └─ *2-<TensorType (float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
394
+ │ │ │ └─ *2-<Vector (float64, shape= (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
395
395
│ │ └─ ScalarConstant{1} [id BX]
396
396
│ └─ ScalarConstant{-1} [id BY]
397
397
└─ ExpandDims{axis=0} [id BZ]
398
- └─ *1-<TensorType (int64, ())> [id Z] (inner_in_seqs-1)
398
+ └─ *1-<Scalar (int64, shape= ())> [id Z] (inner_in_seqs-1)
399
399
400
400
Scan{scan_fn, while_loop=False, inplace=none} [id BH]
401
- → *0-<TensorType (float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402
- → *1-<TensorType (float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
401
+ → *0-<Vector (float64, shape= (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
402
+ → *1-<Vector (float64, shape= (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
403
403
← Mul [id CC] (inner_out_sit_sot-0)
404
- ├─ *0-<TensorType (float64, (?,))> [id CA] (inner_in_sit_sot-0)
405
- └─ *1-<TensorType (float64, (?,))> [id CB] (inner_in_non_seqs-0)"""
404
+ ├─ *0-<Vector (float64, shape= (?,))> [id CA] (inner_in_sit_sot-0)
405
+ └─ *1-<Vector (float64, shape= (?,))> [id CB] (inner_in_non_seqs-0)"""
406
406
407
407
for truth , out in zip (expected_output .split ("\n " ), lines ):
408
408
assert truth .strip () == out .strip ()
@@ -440,7 +440,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
440
440
│ │ │ │ └─ Subtensor{i} [id H]
441
441
│ │ │ │ ├─ Shape [id I]
442
442
│ │ │ │ │ └─ Subtensor{:stop} [id J]
443
- │ │ │ │ │ ├─ <TensorType (int64, (?,))> [id K]
443
+ │ │ │ │ │ ├─ <Vector (int64, shape= (?,))> [id K]
444
444
│ │ │ │ │ └─ ScalarConstant{2} [id L]
445
445
│ │ │ │ └─ ScalarConstant{0} [id M]
446
446
│ │ │ ├─ Subtensor{:stop} [id J]
@@ -455,7 +455,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
455
455
│ │ │ └─ Subtensor{i} [id R]
456
456
│ │ │ ├─ Shape [id S]
457
457
│ │ │ │ └─ Subtensor{:stop} [id T]
458
- │ │ │ │ ├─ <TensorType (int64, (?,))> [id U]
458
+ │ │ │ │ ├─ <Vector (int64, shape= (?,))> [id U]
459
459
│ │ │ │ └─ ScalarConstant{2} [id V]
460
460
│ │ │ └─ ScalarConstant{0} [id W]
461
461
│ │ ├─ Subtensor{:stop} [id T]
@@ -473,11 +473,11 @@ def fn(a_m2, a_m1, b_m2, b_m1):
473
473
474
474
Scan{scan_fn, while_loop=False, inplace=none} [id C]
475
475
← Add [id BB] (inner_out_mit_sot-0)
476
- ├─ *1-<TensorType (int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477
- └─ *0-<TensorType (int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
476
+ ├─ *1-<Scalar (int64, shape= ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
477
+ └─ *0-<Scalar (int64, shape= ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
478
478
← Add [id BE] (inner_out_mit_sot-1)
479
- ├─ *3-<TensorType (int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
480
- └─ *2-<TensorType (int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
479
+ ├─ *3-<Scalar (int64, shape= ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
480
+ └─ *2-<Scalar (int64, shape= ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
481
481
482
482
for truth , out in zip (expected_output .split ("\n " ), lines ):
483
483
assert truth .strip () == out .strip ()
@@ -601,19 +601,19 @@ def test_debugprint_mitmot():
601
601
Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
602
602
← Add [id CM] (inner_out_mit_mot-0-0)
603
603
├─ Mul [id CN]
604
- │ ├─ *2-<TensorType (float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
605
- │ └─ *5-<TensorType (float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
606
- └─ *3-<TensorType (float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
604
+ │ ├─ *2-<Vector (float64, shape= (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
605
+ │ └─ *5-<Vector (float64, shape= (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
606
+ └─ *3-<Vector (float64, shape= (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
607
607
← Add [id CR] (inner_out_sit_sot-0)
608
608
├─ Mul [id CS]
609
- │ ├─ *2-<TensorType (float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
610
- │ └─ *0-<TensorType (float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611
- └─ *4-<TensorType (float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
609
+ │ ├─ *2-<Vector (float64, shape= (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
610
+ │ └─ *0-<Vector (float64, shape= (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
611
+ └─ *4-<Vector (float64, shape= (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
612
612
613
613
Scan{scan_fn, while_loop=False, inplace=none} [id F]
614
614
← Mul [id CV] (inner_out_sit_sot-0)
615
- ├─ *0-<TensorType (float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616
- └─ *1-<TensorType (float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
615
+ ├─ *0-<Vector (float64, shape= (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
616
+ └─ *1-<Vector (float64, shape= (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
617
617
618
618
for truth , out in zip (expected_output .split ("\n " ), lines ):
619
619
assert truth .strip () == out .strip ()
@@ -643,25 +643,25 @@ def no_shared_fn(n, x_tm1, M):
643
643
644
644
expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
645
645
├─ TensorConstant{20000} [id B] (n_steps)
646
- ├─ TensorConstant{[ 0 .. 998 19999]} [id C] (outer_in_seqs-0)
646
+ ├─ TensorConstant{[ 0 ... 998 19999]} [id C] (outer_in_seqs-0)
647
647
├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
648
648
│ ├─ AllocEmpty{dtype='int64'} [id E] 0
649
649
│ │ └─ TensorConstant{20000} [id B]
650
650
│ ├─ TensorConstant{(1,) of 0} [id F]
651
651
│ └─ ScalarConstant{1} [id G]
652
- └─ <TensorType (float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
652
+ └─ <Tensor3 (float64, shape= (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
653
653
654
654
Inner graphs:
655
655
656
656
Scan{scan_fn, while_loop=False, inplace=all} [id A]
657
657
← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
658
658
├─ TensorConstant{0} [id J]
659
659
├─ Subtensor{i, j, k} [id K]
660
- │ ├─ *2-<TensorType (float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
660
+ │ ├─ *2-<Tensor3 (float64, shape= (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
661
661
│ ├─ ScalarFromTensor [id M]
662
- │ │ └─ *0-<TensorType (int64, ())> [id N] -> [id C] (inner_in_seqs-0)
662
+ │ │ └─ *0-<Scalar (int64, shape= ())> [id N] -> [id C] (inner_in_seqs-0)
663
663
│ ├─ ScalarFromTensor [id O]
664
- │ │ └─ *1-<TensorType (int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
664
+ │ │ └─ *1-<Scalar (int64, shape= ())> [id P] -> [id D] (inner_in_sit_sot-0)
665
665
│ └─ ScalarConstant{0} [id Q]
666
666
└─ TensorConstant{1} [id R]
667
667
0 commit comments