Skip to content

Commit f4de2fd

Browse files
michaelosthegeferrine
authored andcommitted
Run pydot/graphviz tests in CI
Closes #151
1 parent f92e109 commit f4de2fd

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

environment.yml

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,5 @@ dependencies:
5050
- typing_extensions
5151
# optional
5252
- cython
53-
53+
- graphviz
54+
- pydot

tests/d3viz/test_d3viz.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor import compile
1010
from pytensor.compile.function import function
1111
from pytensor.configdefaults import config
12-
from pytensor.d3viz.formatting import pydot_imported, pydot_imported_msg
12+
from pytensor.printing import pydot_imported, pydot_imported_msg
1313
from tests.d3viz import models
1414

1515

tests/d3viz/test_formatting.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
import pytest
33

44
from pytensor import config, function
5-
from pytensor.d3viz.formatting import PyDotFormatter, pydot_imported, pydot_imported_msg
5+
from pytensor.d3viz.formatting import PyDotFormatter
6+
from pytensor.printing import pydot_imported, pydot_imported_msg
67

78

89
if not pydot_imported:
@@ -21,21 +22,23 @@ def node_counts(self, graph):
2122
nc = dict(zip(a, b))
2223
return nc
2324

24-
def test_mlp(self):
25+
@pytest.mark.parametrize("mode", ["FAST_RUN", "FAST_COMPILE"])
26+
def test_mlp(self, mode):
2527
m = models.Mlp()
26-
f = function(m.inputs, m.outputs)
28+
f = function(m.inputs, m.outputs, mode=mode)
2729
pdf = PyDotFormatter()
2830
graph = pdf(f)
29-
expected = 11
30-
if config.mode == "FAST_COMPILE":
31-
expected = 12
31+
if mode == "FAST_RUN":
32+
expected = 13
33+
elif mode == "FAST_COMPILE":
34+
expected = 14
3235
assert len(graph.get_nodes()) == expected
3336
nc = self.node_counts(graph)
3437

35-
if config.mode == "FAST_COMPILE":
36-
assert nc["apply"] == 6
37-
else:
38-
assert nc["apply"] == 5
38+
if mode == "FAST_RUN":
39+
assert nc["apply"] == 7
40+
elif mode == "FAST_COMPILE":
41+
assert nc["apply"] == 8
3942
assert nc["output"] == 1
4043

4144
def test_ofg(self):

0 commit comments

Comments
 (0)