
Improve Op string representation and debug_print formatting #319


Merged: 12 commits into pymc-devs:main on Jun 8, 2023

Conversation

@ricardoV94 (Member) commented May 24, 2023

Please, if you can, try it with some examples of your own and give feedback, especially if you have suggestions for Ops that have awful names that could be rendered better. I only tackled the ones I could think of from memory.

Closes #313
Closes #196

Example from the issue:

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.vector("y")
z = x + pt.exp(y)

pytensor.dprint(z)

Before:

Elemwise{add,no_inplace} [id A]
 |InplaceDimShuffle{x} [id B]
 | |x [id C]
 |Elemwise{exp,no_inplace} [id D]
   |y [id E]

After:

Add [id A]
 ├─ ExpandDims{axis=0} [id B]
 │  └─ x [id C]
 └─ Exp [id D]
    └─ y [id E]
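The new output draws the graph with Unicode box-drawing connectors instead of `|` bars. As an illustration only (this is not PyTensor's implementation), a minimal sketch of rendering such a tree from nested `(label, children)` tuples:

```python
# Minimal sketch of the connector layout used by the new dprint output.
# Nodes are (label, children) tuples; this is NOT PyTensor's actual code.
def render(label, children):
    lines = [label]
    for i, (child_label, child_children) in enumerate(children):
        last = i == len(children) - 1
        head = "└─ " if last else "├─ "  # connector for the child's first line
        pad = "   " if last else "│  "   # continuation prefix for deeper lines
        sub = render(child_label, child_children)
        lines.append(head + sub[0])
        lines.extend(pad + line for line in sub[1:])
    return lines


tree = ("Add", [("ExpandDims{axis=0}", [("x", [])]), ("Exp", [("y", [])])])
print("\n".join(render(*tree)))
```

Run on the example graph's shape, this reproduces the connector layout of the "After" listing above (minus the `[id …]` annotations).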

Before with print_type=True:

Elemwise{add,no_inplace} [id A] <TensorType(float64, (?,))>
 |InplaceDimShuffle{x} [id B] <TensorType(float64, (1,))>
 | |x [id C] <TensorType(float64, ())>
 |Elemwise{exp,no_inplace} [id D] <TensorType(float64, (?,))>
   |y [id E] <TensorType(float64, (?,))>

After with print_type=True:

Add [id A] <Vector(float64, shape=(?,))>
 ├─ ExpandDims{axis=0} [id B] <Vector(float64, shape=(1,))>
 │  └─ x [id C] <Scalar(float64, shape=())>
 └─ Exp [id D] <Vector(float64, shape=(?,))>
    └─ y [id E] <Vector(float64, shape=(?,))>
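The shorter names (`Add`, `Exp` instead of `Elemwise{add,no_inplace}`) come from improved string representations of the Ops themselves. A hypothetical sketch of the idea, with a plain string standing in for the wrapped scalar Op (real PyTensor classes differ):

```python
# Hypothetical sketch (not PyTensor's actual classes) of the idea behind
# the friendlier Op names: each Op's __str__ returns a concise name
# instead of the verbose "Elemwise{add,no_inplace}" form.
class Op:
    def __str__(self):
        return type(self).__name__


class Elemwise(Op):
    def __init__(self, scalar_op_name):
        # Stand-in: the real Elemwise wraps a scalar Op object, not a string.
        self.scalar_op_name = scalar_op_name

    def __str__(self):
        # "add" -> "Add", "exp" -> "Exp"
        return self.scalar_op_name.capitalize()


print(Elemwise("add"))  # Add
print(Elemwise("exp"))  # Exp
```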

A crazier one:

import pytensor
import pytensor.tensor as pt

x = pt.vector()
y = pt.matrix("y")
z = x + pt.exp(y)
out1 = pt.squeeze(z[None, None, None, None][:1, 0, None, 0, 1, 2])
out2 = z.T
out3 = z.dimshuffle((1, "x", 0)).sum((0, 1, 2))
outs = [out1, out2, out3]

pytensor.dprint(outs)

Before:

InplaceDimShuffle{2} [id A]
 |Subtensor{:int64:, int64, ::, int64, int64, int64} [id B]
   |InplaceDimShuffle{0,1,x,2,3,4,5} [id C]
   | |InplaceDimShuffle{x,x,x,x,0,1} [id D]
   |   |Elemwise{add,no_inplace} [id E]
   |     |InplaceDimShuffle{x,0} [id F]
   |     | |<TensorType(float64, (?,))> [id G]
   |     |Elemwise{exp,no_inplace} [id H]
   |       |y [id I]
   |ScalarConstant{1} [id J]
   |ScalarConstant{0} [id K]
   |ScalarConstant{0} [id L]
   |ScalarConstant{1} [id M]
   |ScalarConstant{2} [id N]
InplaceDimShuffle{1,0} [id O]
 |Elemwise{add,no_inplace} [id E]
Sum{axis=[0, 1, 2], acc_dtype=float64} [id P]
 |InplaceDimShuffle{1,x,0} [id Q]
   |Elemwise{add,no_inplace} [id E]

After:

DropDims{axes=[0, 1]} [id A]
 └─ Subtensor{:stop, i, :, j, k, ii} [id B]
    ├─ ExpandDims{axis=2} [id C]
    │  └─ ExpandDims{axes=[0, 1, 2, 3]} [id D]
    │     └─ Add [id E]
    │        ├─ ExpandDims{axis=0} [id F]
    │        │  └─ <Vector(float64, shape=(?,))> [id G]
    │        └─ Exp [id H]
    │           └─ y [id I]
    ├─ 1 [id J]
    ├─ 0 [id K]
    ├─ 0 [id L]
    ├─ 1 [id M]
    └─ 2 [id N]
Transpose{axes=[1, 0]} [id O]
 └─ Add [id E]
    └─ ···
Sum{axes=[0, 1, 2]} [id P]
 └─ DimShuffle{order=[1,x,0]} [id Q]
    └─ Add [id E]
       └─ ···
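Note the `···` lines: because `Add [id E]` feeds all three outputs, its subgraph is printed in full only once and elided afterwards. A minimal sketch of that elision, using the label string as node identity (PyTensor tracks the actual node objects, not labels):

```python
# Sketch of eliding already-printed subgraphs with "···", as in the new
# dprint output. Identity here is the label string; NOT PyTensor's code.
def render_once(label, children, seen=None):
    if seen is None:
        seen = set()
    lines = [label]
    if label in seen:
        if children:
            lines.append("└─ ···")  # subgraph was already printed in full
        return lines
    seen.add(label)
    for i, (child_label, child_children) in enumerate(children):
        last = i == len(children) - 1
        head = "└─ " if last else "├─ "
        pad = "   " if last else "│  "
        sub = render_once(child_label, child_children, seen)
        lines.append(head + sub[0])
        lines.extend(pad + line for line in sub[1:])
    return lines


# The shared subgraph appears under both roots, but is expanded only once.
shared = ("Add", [("y", [])])
seen = set()
print("\n".join(render_once("Transpose{axes=[1, 0]}", [shared], seen)))
print("\n".join(render_once("Sum{axes=[0, 1, 2]}", [shared], seen)))
```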

Something with inner graphs:

import pytensor
import pytensor.tensor as pt

x = pt.scalar("x")
y = pt.vector("y")
z = x + pt.exp(y)
w = z + 2

fn = pytensor.function([x, y], [z, w])
pytensor.dprint(fn)

Before:

Elemwise{Composite}.0 [id A] 1
 |y [id B]
 |InplaceDimShuffle{x} [id C] 0
 | |x [id D]
 |TensorConstant{(1,) of 2.0} [id E]
Elemwise{Composite}.1 [id A] 1

Inner graphs:

Elemwise{Composite}.0 [id A]
 >add [id F]
 > |<float64> [id G]
 > |exp [id H]
 >   |<float64> [id I]
 >add [id J]
 > |<float64> [id K]
 > |<float64> [id G]
 > |exp [id H]
Elemwise{Composite}.1 [id A]
 >add [id F]
 >add [id J]

After:

Composite{...}.0 [id A] 1
 ├─ y [id B]
 ├─ ExpandDims{axis=0} [id C] 0
 │  └─ x [id D]
 └─ [2.] [id E]
Composite{...}.1 [id A] 1
 └─ ···
Inner graphs:
Composite{...} [id A]
 ← add [id F] 'o0'
    ├─ i1 [id G]
    └─ exp [id H] 't0'
       └─ i0 [id I]
 ← add [id J] 'o1'
    ├─ i2 [id K]
    ├─ i1 [id G]
    └─ exp [id H] 't0'
       └─ ···

@ricardoV94 force-pushed the better_dprint branch 4 times, most recently from 2441661 to 3ddc622 on May 25, 2023 14:13
@ricardoV94 (Member, Author) commented May 25, 2023

Apparently someone decided to test that rewrites were working from the dprint output -.- ...

@ricardoV94 force-pushed the better_dprint branch 5 times, most recently from 947a10b to 538b507 on May 26, 2023 16:26
@ricardoV94 added the labels "graph objects" and "enhancement" (New feature or request) on May 26, 2023
@ricardoV94 marked this pull request as ready for review on May 26, 2023 16:42
@ricardoV94 (Member, Author)

Tests are passing, so this is ready for review.

@twiecki (Member) commented May 27, 2023

Does it need to stay ScalarConstant{1} or can it just be 1?

@ricardoV94 (Member, Author) commented May 30, 2023

> Does it need to stay ScalarConstant{1} or can it just be 1?

Updated. Reverted, as tests started to fail; investigating.

@ricardoV94 force-pushed the better_dprint branch 2 times, most recently from 0b610ea to e3e0424 on May 30, 2023 11:37
@ricardoV94 marked this pull request as draft on May 30, 2023 11:37
@codecov-commenter commented:

Codecov Report

Merging #319 (96cf4ac) into main (236a3df) will increase coverage by 0.05%.
The diff coverage is 99.20%.

@@            Coverage Diff             @@
##             main     #319      +/-   ##
==========================================
+ Coverage   80.33%   80.38%   +0.05%     
==========================================
  Files         156      156              
  Lines       45403    45412       +9     
  Branches    11108    11110       +2     
==========================================
+ Hits        36475    36506      +31     
+ Misses       6712     6699      -13     
+ Partials     2216     2207       -9     
Impacted Files Coverage Δ
pytensor/tensor/rewriting/math.py 86.05% <ø> (ø)
pytensor/tensor/var.py 87.42% <ø> (-0.27%) ⬇️
pytensor/scan/op.py 84.71% <75.00%> (+0.11%) ⬆️
pytensor/graph/basic.py 89.19% <100.00%> (+0.35%) ⬆️
pytensor/graph/op.py 87.06% <100.00%> (+0.26%) ⬆️
pytensor/printing.py 51.38% <100.00%> (+2.44%) ⬆️
pytensor/scalar/basic.py 80.03% <100.00%> (+0.12%) ⬆️
pytensor/tensor/elemwise.py 88.02% <100.00%> (ø)
pytensor/tensor/math.py 90.49% <100.00%> (-0.05%) ⬇️
pytensor/tensor/subtensor.py 89.60% <100.00%> (-0.05%) ⬇️
... and 1 more

... and 1 file with indirect coverage changes

@ricardoV94 marked this pull request as ready for review on June 2, 2023 18:00
@ricardoV94 (Member, Author)

Ready for review again.

@ricardoV94 requested review from michaelosthege, aseyboldt, and ferrine and removed the request for aseyboldt on June 6, 2023 13:37
@michaelosthege (Member) left a comment:


Much nicer than before!

@ricardoV94 merged commit ec6a315 into pymc-devs:main on Jun 8, 2023
@ricardoV94 deleted the better_dprint branch on June 21, 2023 08:55
Successfully merging this pull request may close these issues:

- Make debug_print more readable
- Don't repeat inner graphs with multiple outputs in debugprint
4 participants