Skip to content

Commit e18c236

Browse files
Commented code resolved
1 parent 1a8333f commit e18c236

File tree

2 files changed

+0
-78
lines changed

2 files changed

+0
-78
lines changed

pytensor/tensor/rewriting/uncanonicalize.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,33 +34,12 @@
3434
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
3535
from pytensor.tensor.basic import Alloc, alloc, constant
3636
from pytensor.tensor.elemwise import DimShuffle
37-
38-
# from pytensor.tensor.math import Argmax, Max, MaxAndArgmax, Min, neg
3937
from pytensor.tensor.math import Min, TensorMax, neg
4038
from pytensor.tensor.rewriting.basic import register_uncanonicalize
4139
from pytensor.tensor.shape import Reshape, reshape
4240
from pytensor.tensor.subtensor import Subtensor
4341

4442

45-
# @register_uncanonicalize
46-
# @node_rewriter([MaxAndArgmax])
47-
# def local_max_and_argmax(fgraph, node):
48-
# """
49-
# If we don't use the argmax, change it to a max only.
50-
# """
51-
# if isinstance(node.op, MaxAndArgmax):
52-
# axis = node.op.axis
53-
# if len(fgraph.clients[node.outputs[1]]) == 0:
54-
# new = Max(axis)(node.inputs[0])
55-
# copy_stack_trace(node.outputs[0], new)
56-
# return [new, None]
57-
58-
# if len(fgraph.clients[node.outputs[0]]) == 0:
59-
# new = Argmax(axis)(node.inputs[0])
60-
# copy_stack_trace(node.outputs[0], new)
61-
# return [None, new]
62-
63-
6443
@register_uncanonicalize
6544
@node_rewriter([neg])
6645
def local_max_to_min(fgraph, node):

tests/tensor/rewriting/test_uncanonicalize.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -22,69 +22,12 @@
2222
from tests.link.test_link import make_function
2323

2424

25-
# class TestMaxAndArgmax:
26-
# def test_optimization(self):
27-
# # If we use only the max output, we should replace this op with
28-
# # a faster one.
29-
# mode = pytensor.compile.mode.get_default_mode().including(
30-
# "canonicalize", "fast_run"
31-
# )
32-
33-
# for axis in [0, 1, -1]:
34-
# n = matrix()
35-
36-
# f = function([n], max_and_argmax(n, axis)[0], mode=mode)
37-
# topo = f.maker.fgraph.toposort()
38-
# assert len(topo) == 1
39-
# assert isinstance(topo[0].op, CAReduce)
40-
41-
# f = function([n], max_and_argmax(n, axis), mode=mode)
42-
# topo = f.maker.fgraph.toposort()
43-
# assert len(topo) == 1
44-
# assert isinstance(topo[0].op, MaxAndArgmax)
45-
46-
4725
class TestMinMax:
4826
def setup_method(self):
4927
self.mode = pytensor.compile.mode.get_default_mode().including(
5028
"canonicalize", "fast_run"
5129
)
5230

53-
# def test_optimization_max(self):
54-
# data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
55-
# n = matrix()
56-
57-
# for axis in [0, 1, -1]:
58-
# f = function([n], pt_max(n, axis), mode=self.mode)
59-
# topo = f.maker.fgraph.toposort()
60-
# assert len(topo) == 1
61-
# # assert isinstance(topo[0].op, CAReduce)
62-
# f(data)
63-
64-
# f = function([n], pt_max(-n, axis), mode=self.mode)
65-
# topo = f.maker.fgraph.toposort()
66-
# import pytensor
67-
# pytensor.dprint(topo)
68-
# assert len(topo) == 2
69-
# assert isinstance(topo[0].op, Elemwise)
70-
# assert isinstance(topo[0].op.scalar_op, ps.Neg)
71-
# assert isinstance(topo[1].op, CAReduce)
72-
# f(data)
73-
74-
# f = function([n], -pt_max(n, axis), mode=self.mode)
75-
# topo = f.maker.fgraph.toposort()
76-
# assert len(topo) == 2
77-
# assert isinstance(topo[0].op, CAReduce)
78-
# assert isinstance(topo[1].op, Elemwise)
79-
# assert isinstance(topo[1].op.scalar_op, ps.Neg)
80-
# f(data)
81-
82-
# f = function([n], -pt_max(-n, axis), mode=self.mode)
83-
# topo = f.maker.fgraph.toposort()
84-
# assert len(topo) == 1
85-
# assert isinstance(topo[0].op, CAReduce) # min
86-
# f(data)
87-
8831
def test_optimization_min(self):
8932
data = np.asarray(np.random.random((2, 3)), dtype=config.floatX)
9033
n = matrix()

0 commit comments

Comments
 (0)