|
22 | 22 | from tests.link.test_link import make_function
|
23 | 23 |
|
24 | 24 |
|
25 |
| -# class TestMaxAndArgmax: |
26 |
| -# def test_optimization(self): |
27 |
| -# # If we use only the max output, we should replace this op with |
28 |
| -# # a faster one. |
29 |
| -# mode = pytensor.compile.mode.get_default_mode().including( |
30 |
| -# "canonicalize", "fast_run" |
31 |
| -# ) |
32 |
| - |
33 |
| -# for axis in [0, 1, -1]: |
34 |
| -# n = matrix() |
35 |
| - |
36 |
| -# f = function([n], max_and_argmax(n, axis)[0], mode=mode) |
37 |
| -# topo = f.maker.fgraph.toposort() |
38 |
| -# assert len(topo) == 1 |
39 |
| -# assert isinstance(topo[0].op, CAReduce) |
40 |
| - |
41 |
| -# f = function([n], max_and_argmax(n, axis), mode=mode) |
42 |
| -# topo = f.maker.fgraph.toposort() |
43 |
| -# assert len(topo) == 1 |
44 |
| -# assert isinstance(topo[0].op, MaxAndArgmax) |
45 |
| - |
46 |
| - |
47 | 25 | class TestMinMax:
|
48 | 26 | def setup_method(self):
|
49 | 27 | self.mode = pytensor.compile.mode.get_default_mode().including(
|
50 | 28 | "canonicalize", "fast_run"
|
51 | 29 | )
|
52 | 30 |
|
53 |
| - # def test_optimization_max(self): |
54 |
| - # data = np.asarray(np.random.random((2, 3)), dtype=config.floatX) |
55 |
| - # n = matrix() |
56 |
| - |
57 |
| - # for axis in [0, 1, -1]: |
58 |
| - # f = function([n], pt_max(n, axis), mode=self.mode) |
59 |
| - # topo = f.maker.fgraph.toposort() |
60 |
| - # assert len(topo) == 1 |
61 |
| - # # assert isinstance(topo[0].op, CAReduce) |
62 |
| - # f(data) |
63 |
| - |
64 |
| - # f = function([n], pt_max(-n, axis), mode=self.mode) |
65 |
| - # topo = f.maker.fgraph.toposort() |
66 |
| - # import pytensor |
67 |
| - # pytensor.dprint(topo) |
68 |
| - # assert len(topo) == 2 |
69 |
| - # assert isinstance(topo[0].op, Elemwise) |
70 |
| - # assert isinstance(topo[0].op.scalar_op, ps.Neg) |
71 |
| - # assert isinstance(topo[1].op, CAReduce) |
72 |
| - # f(data) |
73 |
| - |
74 |
| - # f = function([n], -pt_max(n, axis), mode=self.mode) |
75 |
| - # topo = f.maker.fgraph.toposort() |
76 |
| - # assert len(topo) == 2 |
77 |
| - # assert isinstance(topo[0].op, CAReduce) |
78 |
| - # assert isinstance(topo[1].op, Elemwise) |
79 |
| - # assert isinstance(topo[1].op.scalar_op, ps.Neg) |
80 |
| - # f(data) |
81 |
| - |
82 |
| - # f = function([n], -pt_max(-n, axis), mode=self.mode) |
83 |
| - # topo = f.maker.fgraph.toposort() |
84 |
| - # assert len(topo) == 1 |
85 |
| - # assert isinstance(topo[0].op, CAReduce) # min |
86 |
| - # f(data) |
87 |
| - |
88 | 31 | def test_optimization_min(self):
|
89 | 32 | data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
|
90 | 33 | n = matrix()
|
|
0 commit comments