Skip to content

Commit 8cbd984

Browse files
Add basic optimization canonicalizations to "fast_compile" mode by default
1 parent 842d3bc commit 8cbd984

File tree

7 files changed

+15
-65
lines changed

7 files changed

+15
-65
lines changed

aesara/tensor/basic_opt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,9 @@ def register(inner_lopt):
507507
return register
508508
else:
509509
name = kwargs.pop("name", None) or lopt.__name__
510-
compile.optdb["canonicalize"].register(name, lopt, "fast_run", *tags, **kwargs)
510+
compile.optdb["canonicalize"].register(
511+
name, lopt, "fast_run", "fast_compile", *tags, **kwargs
512+
)
511513
return lopt
512514

513515

tests/d3viz/test_formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def test_ofg(self):
4848
assert len(sub_graphs) == 2
4949
ofg1, ofg2 = sub_graphs
5050
if config.mode == "FAST_COMPILE":
51-
assert len(ofg1.get_nodes()) == 9
51+
assert len(ofg1.get_nodes()) == 8
5252
else:
5353
assert len(ofg1.get_nodes()) == 5
5454
assert len(ofg1.get_nodes()) == len(ofg2.get_nodes())

tests/sparse/test_basic.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,18 +1563,8 @@ def f_b(x, y):
15631563

15641564
assert np.all(f_a(vx, vy) == f_b(vx, vy))
15651565
topo = f_a.maker.fgraph.toposort()
1566-
if aesara.config.mode != "FAST_COMPILE":
1567-
nb = 0
1568-
else:
1569-
nb = 1
1570-
assert (
1571-
sum(
1572-
[
1573-
isinstance(node.op, (Dot, Usmm, UsmmCscDense))
1574-
for node in topo
1575-
]
1576-
)
1577-
== nb
1566+
assert not any(
1567+
isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
15781568
)
15791569

15801570
def test_int32_dtype(self):
@@ -1822,13 +1812,8 @@ def f_b(z, a, x, y):
18221812
)
18231813
assert all(f_shape(a_data, x_data, y_data) == f_b_out.shape)
18241814
topo = f_shape.maker.fgraph.toposort()
1825-
if aesara.config.mode != "FAST_COMPILE":
1826-
nb = 0
1827-
else:
1828-
nb = 1
1829-
assert (
1830-
sum([isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo])
1831-
== nb
1815+
assert not any(
1816+
isinstance(node.op, (Dot, Usmm, UsmmCscDense)) for node in topo
18321817
)
18331818

18341819

tests/tensor/test_basic.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,11 +1028,6 @@ def setup_method(self):
10281028
self.split_op_class = Split
10291029
self.make_vector_op = MakeVector()
10301030
self.floatX = config.floatX
1031-
self.hide_error = config.mode not in [
1032-
"DebugMode",
1033-
"DEBUG_MODE",
1034-
"FAST_COMPILE",
1035-
]
10361031
self.shared = shared
10371032

10381033
def eval_outputs_and_check_join(self, outputs):
@@ -1712,17 +1707,6 @@ def get_mat(s1, s2):
17121707
for node in topo:
17131708
assert not isinstance(node.op, type(self.join_op))
17141709

1715-
with config.change_flags(compute_test_value="off"):
1716-
# Test hide error
1717-
x1.set_value(get_mat(3, 4))
1718-
x2.set_value(get_mat(3, 4))
1719-
x3.set_value(get_mat(2, 5))
1720-
if not self.hide_error:
1721-
with pytest.raises(ValueError):
1722-
f()
1723-
else:
1724-
f()
1725-
17261710
def test_rebroadcast(self):
17271711
# Regression test for a crash that used to happen when rebroadcasting.
17281712
x = TensorType(self.floatX, [False, False, True])()

tests/tensor/test_sharedvar.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -439,14 +439,6 @@ def test_specify_shape(self):
439439
with pytest.raises(AssertionError):
440440
specify_shape_fct()
441441

442-
# No assertion will be raised as the Op is removed from the graph
443-
# when there is optimization
444-
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
445-
shape_constant_fct()
446-
else:
447-
with pytest.raises(AssertionError):
448-
shape_constant_fct()
449-
450442
def test_specify_shape_partial(self):
451443
dtype = self.dtype
452444
if dtype is None:
@@ -502,13 +494,6 @@ def test_specify_shape_partial(self):
502494
with pytest.raises(AssertionError):
503495
specify_shape_fct()
504496

505-
# No assertion will be raised as the Op is removed from the graph
506-
if aesara.config.mode not in ["FAST_COMPILE", "DebugMode", "DEBUG_MODE"]:
507-
shape_constant_fct()
508-
else:
509-
with pytest.raises(AssertionError):
510-
shape_constant_fct()
511-
512497
def test_specify_shape_inplace(self):
513498
# test that specify_shape don't break inserting inplace op
514499

tests/tensor/test_subtensor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,8 @@ def fct2(t):
10771077
else:
10781078
ops = subtensor_ops
10791079
if idx is idxs[0]:
1080-
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=2)
1080+
# TODO FIXME: This is a very poorly specified test.
1081+
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0)
10811082
f()
10821083

10831084
def test_wrong_exception_regression(self):
@@ -1129,7 +1130,7 @@ def test_shape_list(self):
11291130
data = np.asarray(data, dtype=self.dtype)
11301131
n = self.shared(data)
11311132
t = n[idx]
1132-
f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=1)
1133+
f = self.function([], t.shape, op=subtensor_ops, N=0, N_fast=0)
11331134
val = f()
11341135
assert np.allclose(val, data[idx].shape)
11351136

tests/tensor/test_subtensor_opt.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,18 +1174,14 @@ def test_const6(self):
11741174
data = self.rng.uniform(size=(8, 8, 8)).astype(config.floatX)
11751175
x = tensor3("x")
11761176

1177-
nops = 1
1178-
if config.mode == "FAST_COMPILE":
1179-
nops = 2
1180-
11811177
# test 1)
11821178
y = x[3:6, 2:6, 1:7][1]
11831179
fun = function([x], y)
11841180
val = fun(data)
11851181
assert np.all(val == data[3:6, 2:6, 1:7][1])
11861182
assert (
11871183
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
1188-
== nops
1184+
== 1
11891185
)
11901186

11911187
# test 2)
@@ -1195,7 +1191,7 @@ def test_const6(self):
11951191
assert np.all(val == data[2, 3][1])
11961192
assert (
11971193
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
1198-
== nops
1194+
== 1
11991195
)
12001196

12011197
# test 3)
@@ -1205,7 +1201,7 @@ def test_const6(self):
12051201
assert np.all(val == data[3:6, 2, 1:7][1])
12061202
assert (
12071203
len([n for n in fun.maker.fgraph.toposort() if isinstance(n.op, Subtensor)])
1208-
== nops
1204+
== 1
12091205
)
12101206

12111207
def test_scalar6(self):
@@ -1590,10 +1586,7 @@ def test_incsubtensor_x_zeros(self):
15901586

15911587
assert len(inc_nodes) == 1
15921588
node_is_set_instead_of_inc = inc_nodes[0].op.set_instead_of_inc
1593-
mode = config.mode
1594-
assert (mode != "FAST_COMPILE" and node_is_set_instead_of_inc) or (
1595-
mode == "FAST_COMPILE" and not node_is_set_instead_of_inc
1596-
)
1589+
assert node_is_set_instead_of_inc
15971590
test_X = np.random.random((4, 4)).astype(config.floatX)
15981591
utt.assert_allclose(f(test_X), test_X)
15991592

0 commit comments

Comments
 (0)