Skip to content

Commit c73f5cc

Browse files
committed
Lower precision in TestFusion parametrization
This was not an issue in my local machine, but failed on the Github CI. It could be due to compiler optimizations. Case 69 used to look like this: ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i1)}} [id C] |x [id A] |x [id A] ``` And now looks like this ```python Elemwise{Composite{(i0 * tan(i0) * tan(i0) * i0)}} [id C] |x [id A] [None] ```
1 parent 716c50e commit c73f5cc

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

tests/tensor/rewriting/test_elemwise.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,7 @@ def my_init(dtype="float64", num=0):
882882
1,
883883
fxv * np.tan(fxv) * np.tan(fxv) * fxv,
884884
"float32",
885+
1e-5,
885886
),
886887
(
887888
mul(ftanx, ftanx, fx + fy),
@@ -890,6 +891,7 @@ def my_init(dtype="float64", num=0):
890891
1,
891892
np.tan(fxv) * np.tan(fxv) * (fxv + fyv),
892893
"float32",
894+
1e-5,
893895
), # 70
894896
# Cases with different broadcast pattern. They should not
895897
# be merged as this would duplicate computation
@@ -973,7 +975,11 @@ def my_init(dtype="float64", num=0):
973975
def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
974976
"""Verify that `Elemwise` fusion works."""
975977

976-
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
978+
if len(case) == 6:
979+
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype = case
980+
atol = None
981+
else:
982+
g, sym_inputs, val_inputs, nb_elemwise, answer, out_dtype, atol = case
977983

978984
if isinstance(out_dtype, dict):
979985
out_dtype = out_dtype[config.cast_policy]
@@ -1000,9 +1006,10 @@ def test_elemwise_fusion(self, case, nb_repeat=1, assert_len_topo=True):
10001006
f(*val_inputs)
10011007
out = [o.get_value() for o in out]
10021008

1003-
atol = 1e-8
1004-
if any(o == "float32" for o in out_dtype):
1005-
atol = 1e-6
1009+
if atol is None:
1010+
atol = 1e-8
1011+
if any(o == "float32" for o in out_dtype):
1012+
atol = 1e-6
10061013

10071014
for o, a in zip(out, answer):
10081015
np.testing.assert_allclose(o, a * nb_repeat, atol=atol)

0 commit comments

Comments
 (0)