Skip to content

Commit ac47857

Browse files
committed
Specialize TensorType string representation
1 parent 67135bb commit ac47857

File tree

6 files changed

+69
-63
lines changed

6 files changed

+69
-63
lines changed

pytensor/tensor/type.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -386,21 +386,27 @@ def __str__(self):
386386
if self.name:
387387
return self.name
388388
else:
389+
shape = self.shape
390+
len_shape = len(shape)
389391

390392
def shape_str(s):
391393
if s is None:
392394
return "?"
393395
else:
394396
return str(s)
395397

396-
formatted_shape = ", ".join([shape_str(s) for s in self.shape])
397-
if len(self.shape) == 1:
398+
formatted_shape = ", ".join([shape_str(s) for s in shape])
399+
if len_shape == 1:
398400
formatted_shape += ","
399401

400-
return f"TensorType({self.dtype}, ({formatted_shape}))"
402+
if len_shape > 2:
403+
name = f"Tensor{len_shape}"
404+
else:
405+
name = ("Scalar", "Vector", "Matrix")[len_shape]
406+
return f"{name}({self.dtype}, shape=({formatted_shape}))"
401407

402408
def __repr__(self):
403-
return str(self)
409+
return f"TensorType({self.dtype}, shape={self.shape})"
404410

405411
@staticmethod
406412
def may_share_memory(a, b):

pytensor/tensor/var.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def __str__(self):
10301030
else:
10311031
val = f"{self.data}"
10321032
if len(val) > 20:
1033-
val = val[:10] + ".." + val[-10:]
1033+
val = val[:10] + " ... " + val[-10:]
10341034

10351035
if self.name is not None:
10361036
name = self.name

tests/compile/test_builders.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,10 @@ def test_debugprint():
580580
581581
OpFromGraph{inline=False} [id A]
582582
>Add [id E]
583-
> |*0-<TensorType(float64, (?, ?))> [id F]
583+
> |*0-<Matrix(float64, shape=(?, ?))> [id F]
584584
> |Mul [id G]
585-
> |*1-<TensorType(float64, (?, ?))> [id H]
586-
> |*2-<TensorType(float64, (?, ?))> [id I]
585+
> |*1-<Matrix(float64, shape=(?, ?))> [id H]
586+
> |*2-<Matrix(float64, shape=(?, ?))> [id I]
587587
"""
588588

589589
for truth, out in zip(exp_res.split("\n"), lines):

tests/scan/test_printing.py

+48-48
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def test_debugprint_sitsot():
5858
5959
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
6060
>Mul [id W] (inner_out_sit_sot-0)
61-
> |*0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
62-
> |*1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
61+
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E] (inner_in_sit_sot-0)
62+
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
6363

6464
for truth, out in zip(expected_output.split("\n"), lines):
6565
assert truth.strip() == out.strip()
@@ -113,8 +113,8 @@ def test_debugprint_sitsot_no_extra_info():
113113
114114
for{cpu,scan_fn} [id C]
115115
>Mul [id W]
116-
> |*0-<TensorType(float64, (?,))> [id X] -> [id E]
117-
> |*1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
116+
> |*0-<Vector(float64, shape=(?,))> [id X] -> [id E]
117+
> |*1-<Vector(float64, shape=(?,))> [id Y] -> [id M]"""
118118

119119
for truth, out in zip(expected_output.split("\n"), lines):
120120
assert truth.strip() == out.strip()
@@ -174,10 +174,10 @@ def test_debugprint_nitsot():
174174
175175
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
176176
>Mul [id X] (inner_out_nit_sot-0)
177-
> |*0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
177+
> |*0-<Scalar(float64, shape=())> [id Y] -> [id S] (inner_in_seqs-0)
178178
> |Pow [id Z]
179-
> |*2-<TensorType(float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
180-
> |*1-<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
179+
> |*2-<Scalar(float64, shape=())> [id BA] -> [id W] (inner_in_non_seqs-0)
180+
> |*1-<Scalar(int64, shape=())> [id BB] -> [id U] (inner_in_seqs-1)"""
181181

182182
for truth, out in zip(expected_output.split("\n"), lines):
183183
assert truth.strip() == out.strip()
@@ -249,22 +249,22 @@ def compute_A_k(A, k):
249249
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
250250
>Mul [id Y] (inner_out_nit_sot-0)
251251
> |ExpandDims{axis=0} [id Z]
252-
> | |*0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
252+
> | |*0-<Scalar(float64, shape=())> [id BA] -> [id S] (inner_in_seqs-0)
253253
> |Pow [id BB]
254254
> |Subtensor{i} [id BC]
255255
> | |Subtensor{start:} [id BD]
256256
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
257-
> | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
257+
> | | | |*3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
258258
> | | | |SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
259259
> | | | | |AllocEmpty{dtype='float64'} [id BH]
260260
> | | | | | |Add [id BI]
261-
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
261+
> | | | | | | |*3-<Scalar(int32, shape=())> [id BF] -> [id X] (inner_in_non_seqs-1)
262262
> | | | | | | |Subtensor{i} [id BJ]
263263
> | | | | | | |Shape [id BK]
264264
> | | | | | | | |Unbroadcast{0} [id BL]
265265
> | | | | | | | |ExpandDims{axis=0} [id BM]
266266
> | | | | | | | |Second [id BN]
267-
> | | | | | | | |*2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
267+
> | | | | | | | |*2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
268268
> | | | | | | | |ExpandDims{axis=0} [id BP]
269269
> | | | | | | | |TensorConstant{1.0} [id BQ]
270270
> | | | | | | |ScalarConstant{0} [id BR]
@@ -275,16 +275,16 @@ def compute_A_k(A, k):
275275
> | | | | |Unbroadcast{0} [id BL]
276276
> | | | | |ScalarFromTensor [id BV]
277277
> | | | | |Subtensor{i} [id BJ]
278-
> | | | |*2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
278+
> | | | |*2-<Vector(float64, shape=(?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
279279
> | | |ScalarConstant{1} [id BW]
280280
> | |ScalarConstant{-1} [id BX]
281281
> |ExpandDims{axis=0} [id BY]
282-
> |*1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
282+
> |*1-<Scalar(int64, shape=())> [id BZ] -> [id U] (inner_in_seqs-1)
283283
284284
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
285285
>Mul [id CA] (inner_out_sit_sot-0)
286-
> |*0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
287-
> |*1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
286+
> |*0-<Vector(float64, shape=(?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
287+
> |*1-<Vector(float64, shape=(?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
288288

289289
for truth, out in zip(expected_output.split("\n"), lines):
290290
assert truth.strip() == out.strip()
@@ -332,28 +332,28 @@ def compute_A_k(A, k):
332332
Inner graphs:
333333
334334
for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
335-
-*0-<TensorType(float64, ())> [id Y] -> [id U] (inner_in_seqs-0)
336-
-*1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
337-
-*2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
338-
-*3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
335+
-*0-<Scalar(float64, shape=())> [id Y] -> [id U] (inner_in_seqs-0)
336+
-*1-<Scalar(int64, shape=())> [id Z] -> [id W] (inner_in_seqs-1)
337+
-*2-<Vector(float64, shape=(?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
338+
-*3-<Scalar(int32, shape=())> [id BB] -> [id B] (inner_in_non_seqs-1)
339339
>Mul [id BC] (inner_out_nit_sot-0)
340340
> |ExpandDims{axis=0} [id BD]
341-
> | |*0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
341+
> | |*0-<Scalar(float64, shape=())> [id Y] (inner_in_seqs-0)
342342
> |Pow [id BE]
343343
> |Subtensor{i} [id BF]
344344
> | |Subtensor{start:} [id BG]
345345
> | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
346-
> | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
346+
> | | | |*3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1) (n_steps)
347347
> | | | |SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
348348
> | | | | |AllocEmpty{dtype='float64'} [id BJ]
349349
> | | | | | |Add [id BK]
350-
> | | | | | | |*3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
350+
> | | | | | | |*3-<Scalar(int32, shape=())> [id BB] (inner_in_non_seqs-1)
351351
> | | | | | | |Subtensor{i} [id BL]
352352
> | | | | | | |Shape [id BM]
353353
> | | | | | | | |Unbroadcast{0} [id BN]
354354
> | | | | | | | |ExpandDims{axis=0} [id BO]
355355
> | | | | | | | |Second [id BP]
356-
> | | | | | | | |*2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0)
356+
> | | | | | | | |*2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0)
357357
> | | | | | | | |ExpandDims{axis=0} [id BQ]
358358
> | | | | | | | |TensorConstant{1.0} [id BR]
359359
> | | | | | | |ScalarConstant{0} [id BS]
@@ -364,18 +364,18 @@ def compute_A_k(A, k):
364364
> | | | | |Unbroadcast{0} [id BN]
365365
> | | | | |ScalarFromTensor [id BW]
366366
> | | | | |Subtensor{i} [id BL]
367-
> | | | |*2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
367+
> | | | |*2-<Vector(float64, shape=(?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
368368
> | | |ScalarConstant{1} [id BX]
369369
> | |ScalarConstant{-1} [id BY]
370370
> |ExpandDims{axis=0} [id BZ]
371-
> |*1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
371+
> |*1-<Scalar(int64, shape=())> [id Z] (inner_in_seqs-1)
372372
373373
for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
374-
-*0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
375-
-*1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
374+
-*0-<Vector(float64, shape=(?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
375+
-*1-<Vector(float64, shape=(?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
376376
>Mul [id CC] (inner_out_sit_sot-0)
377-
> |*0-<TensorType(float64, (?,))> [id CA] (inner_in_sit_sot-0)
378-
> |*1-<TensorType(float64, (?,))> [id CB] (inner_in_non_seqs-0)"""
377+
> |*0-<Vector(float64, shape=(?,))> [id CA] (inner_in_sit_sot-0)
378+
> |*1-<Vector(float64, shape=(?,))> [id CB] (inner_in_non_seqs-0)"""
379379

380380
for truth, out in zip(expected_output.split("\n"), lines):
381381
assert truth.strip() == out.strip()
@@ -413,7 +413,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
413413
| | | | |Subtensor{i} [id H]
414414
| | | | |Shape [id I]
415415
| | | | | |Subtensor{:stop} [id J]
416-
| | | | | |<TensorType(int64, (?,))> [id K]
416+
| | | | | |<Vector(int64, shape=(?,))> [id K]
417417
| | | | | |ScalarConstant{2} [id L]
418418
| | | | |ScalarConstant{0} [id M]
419419
| | | |Subtensor{:stop} [id J]
@@ -426,7 +426,7 @@ def fn(a_m2, a_m1, b_m2, b_m1):
426426
| | | |Subtensor{i} [id R]
427427
| | | |Shape [id S]
428428
| | | | |Subtensor{:stop} [id T]
429-
| | | | |<TensorType(int64, (?,))> [id U]
429+
| | | | |<Vector(int64, shape=(?,))> [id U]
430430
| | | | |ScalarConstant{2} [id V]
431431
| | | |ScalarConstant{0} [id W]
432432
| | |Subtensor{:stop} [id T]
@@ -441,11 +441,11 @@ def fn(a_m2, a_m1, b_m2, b_m1):
441441
442442
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
443443
>Add [id BB] (inner_out_mit_sot-0)
444-
> |*1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
445-
> |*0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
444+
> |*1-<Scalar(int64, shape=())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
445+
> |*0-<Scalar(int64, shape=())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
446446
>Add [id BE] (inner_out_mit_sot-1)
447-
> |*3-<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
448-
> |*2-<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
447+
> |*3-<Scalar(int64, shape=())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
448+
> |*2-<Scalar(int64, shape=())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
449449
450450
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
451451
>Add [id BB] (inner_out_mit_sot-0)
@@ -561,19 +561,19 @@ def test_debugprint_mitmot():
561561
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
562562
>Add [id CM] (inner_out_mit_mot-0-0)
563563
> |Mul [id CN]
564-
> | |*2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
565-
> | |*5-<TensorType(float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
566-
> |*3-<TensorType(float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
564+
> | |*2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
565+
> | |*5-<Vector(float64, shape=(?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
566+
> |*3-<Vector(float64, shape=(?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
567567
>Add [id CR] (inner_out_sit_sot-0)
568568
> |Mul [id CS]
569-
> | |*2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
570-
> | |*0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
571-
> |*4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
569+
> | |*2-<Vector(float64, shape=(?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
570+
> | |*0-<Vector(float64, shape=(?,))> [id CT] -> [id Z] (inner_in_seqs-0)
571+
> |*4-<Vector(float64, shape=(?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
572572
573573
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
574574
>Mul [id CV] (inner_out_sit_sot-0)
575-
> |*0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
576-
> |*1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
575+
> |*0-<Vector(float64, shape=(?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
576+
> |*1-<Vector(float64, shape=(?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
577577

578578
for truth, out in zip(expected_output.split("\n"), lines):
579579
assert truth.strip() == out.strip()
@@ -603,25 +603,25 @@ def no_shared_fn(n, x_tm1, M):
603603

604604
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
605605
|TensorConstant{20000} [id B] (n_steps)
606-
|TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
606+
|TensorConstant{[ 0 ... 998 19999]} [id C] (outer_in_seqs-0)
607607
|SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
608608
| |AllocEmpty{dtype='int64'} [id E] 0
609609
| | |TensorConstant{20000} [id B]
610610
| |TensorConstant{(1,) of 0} [id F]
611611
| |ScalarConstant{1} [id G]
612-
|<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
612+
|<Tensor3(float64, shape=(20000, 2, 2))> [id H] (outer_in_non_seqs-0)
613613
614614
Inner graphs:
615615
616616
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
617617
>Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
618618
> |TensorConstant{0} [id J]
619619
> |Subtensor{i, j, k} [id K]
620-
> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
620+
> | |*2-<Tensor3(float64, shape=(20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
621621
> | |ScalarFromTensor [id M]
622-
> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
622+
> | | |*0-<Scalar(int64, shape=())> [id N] -> [id C] (inner_in_seqs-0)
623623
> | |ScalarFromTensor [id O]
624-
> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
624+
> | | |*1-<Scalar(int64, shape=())> [id P] -> [id D] (inner_in_sit_sot-0)
625625
> | |ScalarConstant{0} [id Q]
626626
> |TensorConstant{1} [id R]
627627

tests/tensor/test_type.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_fixed_shape_basic():
252252
assert t1.shape == (2, 3)
253253
assert t1.broadcastable == (False, False)
254254

255-
assert str(t1) == "TensorType(float64, (2, 3))"
255+
assert str(t1) == "Matrix(float64, shape=(2, 3))"
256256

257257
t1 = TensorType("float64", shape=(1,))
258258
assert t1.shape == (1,)

tests/test_printing.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,7 @@ def test_debugprint():
281281
| | |B
282282
| |TensorConstant{1.0}
283283
| |B
284-
| |<TensorType(float64, (?,))>
284+
| |<Vector(float64, shape=(?,))>
285285
| |TensorConstant{0.0}
286286
|D
287287
|A
@@ -304,7 +304,7 @@ def test_debugprint():
304304

305305
def test_debugprint_id_type():
306306
a_at = dvector()
307-
b_at = dmatrix()
307+
b_at = matrix(dtype="float64", shape=(50, None))
308308

309309
d_at = b_at.dot(a_at)
310310
e_at = d_at + a_at
@@ -315,9 +315,9 @@ def test_debugprint_id_type():
315315

316316
exp_res = f"""Add [id {e_at.auto_name}]
317317
|dot [id {d_at.auto_name}]
318-
| |<TensorType(float64, (?, ?))> [id {b_at.auto_name}]
319-
| |<TensorType(float64, (?,))> [id {a_at.auto_name}]
320-
|<TensorType(float64, (?,))> [id {a_at.auto_name}]
318+
| |<Matrix(float64, shape=(50, ?))> [id {b_at.auto_name}]
319+
| |<Vector(float64, shape=(?,))> [id {a_at.auto_name}]
320+
|<Vector(float64, shape=(?,))> [id {a_at.auto_name}]
321321
"""
322322

323323
assert [l.strip() for l in s.split("\n")] == [
@@ -328,7 +328,7 @@ def test_debugprint_id_type():
328328
def test_pprint():
329329
x = dvector()
330330
y = x[1]
331-
assert pp(y) == "<TensorType(float64, (?,))>[1]"
331+
assert pp(y) == "<Vector(float64, shape=(?,))>[1]"
332332

333333

334334
def test_debugprint_inner_graph():

0 commit comments

Comments
 (0)