File tree 3 files changed +16
-12
lines changed
3 files changed +16
-12
lines changed Original file line number Diff line number Diff line change @@ -50,4 +50,5 @@ dependencies:
50
50
- typing_extensions
51
51
# optional
52
52
- cython
53
-
53
+ - graphviz
54
+ - pydot
Original file line number Diff line number Diff line change 9
9
from pytensor import compile
10
10
from pytensor .compile .function import function
11
11
from pytensor .configdefaults import config
12
- from pytensor .d3viz . formatting import pydot_imported , pydot_imported_msg
12
+ from pytensor .printing import pydot_imported , pydot_imported_msg
13
13
from tests .d3viz import models
14
14
15
15
Original file line number Diff line number Diff line change 2
2
import pytest
3
3
4
4
from pytensor import config , function
5
- from pytensor .d3viz .formatting import PyDotFormatter , pydot_imported , pydot_imported_msg
5
+ from pytensor .d3viz .formatting import PyDotFormatter
6
+ from pytensor .printing import pydot_imported , pydot_imported_msg
6
7
7
8
8
9
if not pydot_imported :
@@ -21,21 +22,23 @@ def node_counts(self, graph):
21
22
nc = dict (zip (a , b ))
22
23
return nc
23
24
24
- def test_mlp (self ):
25
+ @pytest .mark .parametrize ("mode" , ["FAST_RUN" , "FAST_COMPILE" ])
26
+ def test_mlp (self , mode ):
25
27
m = models .Mlp ()
26
- f = function (m .inputs , m .outputs )
28
+ f = function (m .inputs , m .outputs , mode = mode )
27
29
pdf = PyDotFormatter ()
28
30
graph = pdf (f )
29
- expected = 11
30
- if config .mode == "FAST_COMPILE" :
31
- expected = 12
31
+ if mode == "FAST_RUN" :
32
+ expected = 13
33
+ elif mode == "FAST_COMPILE" :
34
+ expected = 14
32
35
assert len (graph .get_nodes ()) == expected
33
36
nc = self .node_counts (graph )
34
37
35
- if config . mode == "FAST_COMPILE " :
36
- assert nc ["apply" ] == 6
37
- else :
38
- assert nc ["apply" ] == 5
38
+ if mode == "FAST_RUN " :
39
+ assert nc ["apply" ] == 7
40
+ elif mode == "FAST_COMPILE" :
41
+ assert nc ["apply" ] == 8
39
42
assert nc ["output" ] == 1
40
43
41
44
def test_ofg (self ):
You can’t perform that action at this time.
0 commit comments