Skip to content

Commit d41c2cf

Browse files
committed
Lower precision in TestFusion parametrization
This was not an issue in my local machine, but failed on the Github CI. It could be due to compiler optimizations. Case 69 used to look like this: ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i1)}} [id C] |x [id A] |x [id A] ``` And now looks like this ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i0)}} [id C] |x [id A] [None] ```
1 parent 941d4cf commit d41c2cf

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/tensor/rewriting/test_elemwise.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ def my_init(dtype="float64", num=0):
882882
1,
883883
fxv * np.tan(fxv) * np.tan(fxv) * fxv,
884884
"float32",
885+
1e-5,
885886
),
886887
(
887888
mul(ftanx, ftanx, fx + fy),
@@ -890,6 +891,7 @@ def my_init(dtype="float64", num=0):
890891
1,
891892
np.tan(fxv) * np.tan(fxv) * (fxv + fyv),
892893
"float32",
894+
1e-5,
893895
), # 70
894896
# Cases with different broadcast pattern. They should not
895897
# be merged as this would duplicate computation
@@ -978,7 +980,11 @@ def my_init(dtype="float64", num=0):
978980
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
979981
"""Verify that `Elemwise` fusion works."""
980982

981-
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
983+
if len(case) == 6:
984+
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
985+
atol = None
986+
else:
987+
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype, atol = case
982988

983989
if isinstance(out_dtype, dict):
984990
out_dtype = out_dtype[config.cast_policy]
@@ -1005,9 +1011,10 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
10051011
f(*val_inputs)
10061012
out = [o.get_value() for o in out]
10071013

1008-
atol = 1e-8
1009-
if any(o == "float32" for o in out_dtype):
1010-
atol = 1e-6
1014+
if atol is None:
1015+
atol = 1e-8
1016+
if any(o == "float32" for o in out_dtype):
1017+
atol = 1e-6
10111018

10121019
for o, a in zip(out, answer):
10131020
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)

0 commit comments

Comments
 (0)