Skip to content

Commit abc44ab

Browse files
Switch to numpy assert_allclose
1 parent 035139e commit abc44ab

File tree

1 file changed

+17
-17
lines changed

1 file changed

+17
-17
lines changed

tests/tensor/rewriting/test_math.py

+17-17
Original file line numberDiff line numberDiff line change
@@ -2511,23 +2511,23 @@ def test_local_sum_prod_all_to_none(self):
25112511
# test sum
25122512
f = function([a], a.sum(), mode=self.mode)
25132513
assert len(f.maker.fgraph.apply_nodes) == 1
2514-
utt.assert_allclose(f(input), input.sum())
2514+
np.testing.assert_allclose(f(input), input.sum())
25152515
# test prod
25162516
f = function([a], a.prod(), mode=self.mode)
25172517
assert len(f.maker.fgraph.apply_nodes) == 1
2518-
utt.assert_allclose(f(input), input.prod())
2518+
np.testing.assert_allclose(f(input), input.prod())
25192519
# test sum
25202520
f = function([a], a.sum([0, 1, 2]), mode=self.mode)
25212521
assert len(f.maker.fgraph.apply_nodes) == 1
2522-
utt.assert_allclose(f(input), input.sum())
2522+
np.testing.assert_allclose(f(input), input.sum())
25232523
# test prod
25242524
f = function([a], a.prod([0, 1, 2]), mode=self.mode)
25252525
assert len(f.maker.fgraph.apply_nodes) == 1
2526-
utt.assert_allclose(f(input), input.prod())
2526+
np.testing.assert_allclose(f(input), input.prod())
25272527

25282528
f = function([a], a.sum(0).sum(0).sum(0), mode=self.mode)
25292529
assert len(f.maker.fgraph.apply_nodes) == 1
2530-
utt.assert_allclose(f(input), input.sum())
2530+
np.testing.assert_allclose(f(input), input.sum())
25312531

25322532
def test_local_sum_sum_prod_prod(self):
25332533
a = tensor3()
@@ -2582,54 +2582,54 @@ def my_sum_prod(data, d, dd):
25822582
for d, dd in dims:
25832583
expected = my_sum(input, d, dd)
25842584
f = function([a], a.sum(d).sum(dd), mode=self.mode)
2585-
utt.assert_allclose(f(input), expected)
2585+
np.testing.assert_allclose(f(input), expected)
25862586
assert len(f.maker.fgraph.apply_nodes) == 1
25872587
for d, dd in dims[:6]:
25882588
f = function([a], a.sum(d).sum(dd).sum(0), mode=self.mode)
2589-
utt.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
2589+
np.testing.assert_allclose(f(input), input.sum(d).sum(dd).sum(0))
25902590
assert len(f.maker.fgraph.apply_nodes) == 1
25912591
for d in [0, 1, 2]:
25922592
f = function([a], a.sum(d).sum(None), mode=self.mode)
2593-
utt.assert_allclose(f(input), input.sum(d).sum())
2593+
np.testing.assert_allclose(f(input), input.sum(d).sum())
25942594
assert len(f.maker.fgraph.apply_nodes) == 1
25952595
f = function([a], a.sum(None).sum(), mode=self.mode)
2596-
utt.assert_allclose(f(input), input.sum())
2596+
np.testing.assert_allclose(f(input), input.sum())
25972597
assert len(f.maker.fgraph.apply_nodes) == 1
25982598

25992599
# test prod
26002600
for d, dd in dims:
26012601
expected = my_prod(input, d, dd)
26022602
f = function([a], a.prod(d).prod(dd), mode=self.mode)
2603-
utt.assert_allclose(f(input), expected)
2603+
np.testing.assert_allclose(f(input), expected)
26042604
assert len(f.maker.fgraph.apply_nodes) == 1
26052605
for d, dd in dims[:6]:
26062606
f = function([a], a.prod(d).prod(dd).prod(0), mode=self.mode)
2607-
utt.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
2607+
np.testing.assert_allclose(f(input), input.prod(d).prod(dd).prod(0))
26082608
assert len(f.maker.fgraph.apply_nodes) == 1
26092609
for d in [0, 1, 2]:
26102610
f = function([a], a.prod(d).prod(None), mode=self.mode)
2611-
utt.assert_allclose(f(input), input.prod(d).prod())
2611+
np.testing.assert_allclose(f(input), input.prod(d).prod())
26122612
assert len(f.maker.fgraph.apply_nodes) == 1
26132613
f = function([a], a.prod(None).prod(), mode=self.mode)
2614-
utt.assert_allclose(f(input), input.prod())
2614+
np.testing.assert_allclose(f(input), input.prod())
26152615
assert len(f.maker.fgraph.apply_nodes) == 1
26162616

26172617
# Test that sum prod didn't get rewritten.
26182618
for d, dd in dims:
26192619
expected = my_sum_prod(input, d, dd)
26202620
f = function([a], a.sum(d).prod(dd), mode=self.mode)
2621-
utt.assert_allclose(f(input), expected)
2621+
np.testing.assert_allclose(f(input), expected)
26222622
assert len(f.maker.fgraph.apply_nodes) == 2
26232623
for d, dd in dims[:6]:
26242624
f = function([a], a.sum(d).prod(dd).prod(0), mode=self.mode)
2625-
utt.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
2625+
np.testing.assert_allclose(f(input), input.sum(d).prod(dd).prod(0))
26262626
assert len(f.maker.fgraph.apply_nodes) == 2
26272627
for d in [0, 1, 2]:
26282628
f = function([a], a.sum(d).prod(None), mode=self.mode)
2629-
utt.assert_allclose(f(input), input.sum(d).prod())
2629+
np.testing.assert_allclose(f(input), input.sum(d).prod())
26302630
assert len(f.maker.fgraph.apply_nodes) == 2
26312631
f = function([a], a.sum(None).prod(), mode=self.mode)
2632-
utt.assert_allclose(f(input), input.sum())
2632+
np.testing.assert_allclose(f(input), input.sum())
26332633
assert len(f.maker.fgraph.apply_nodes) == 1
26342634

26352635
def test_local_sum_sum_int8(self):

0 commit comments

Comments
 (0)