Commit 013b66a
Parent: e18c236

6 files changed: +132 -17 lines changed

pytensor/graph/op.py
Lines changed: 1 addition & 0 deletions

```diff
@@ -291,6 +291,7 @@ def __call__(
 
         """
         node = self.make_node(*inputs, **kwargs)
+
        if name is not None:
            if len(node.outputs) == 1:
                node.outputs[0].name = name
```
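The hunk only adds a blank line, but the surrounding logic is what matters here: `Op.__call__` builds an Apply node via `make_node` and, when that node has a single output, attaches the `name` keyword to it. A minimal sketch of that behavior, not part of the commit, using the `Max` op this commit introduces (the axis and names are illustrative; `Max(axis=...)` follows the `Max(axis)(a)` and `clone(axis=...)` usage shown later in this diff):

```python
import pytensor.tensor as pt
from pytensor.tensor.math import Max

x = pt.matrix("x")
# __call__ -> make_node -> Apply node; with exactly one output,
# the name keyword is copied onto that output variable.
y = Max(axis=(0,))(x, name="col_max")
assert y.name == "col_max"
```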

pytensor/ifelse.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -477,7 +477,7 @@ def cond_make_inplace(fgraph, node):
             Reshape,
             Unbroadcast,
             pt.math.Dot,
-            pt.math.TensorMax,
+            pt.math.Max,
             pt.math.Argmax,
             pt.subtensor.Subtensor,
             pt.subtensor.IncSubtensor,
```

pytensor/tensor/math.py
Lines changed: 107 additions & 1 deletion

```diff
@@ -808,7 +808,8 @@ def max_and_argmax(a, axis=None, keepdims=False):
     axis = check_and_normalize_axes(a, axis)
     if len(axis) == 0:
         axis = list(range(a.type.ndim))
-    out = TensorMax(axis)(a)
+    # out = TensorMax(axis)(a)
+    out = Max(axis)(a)
     argout = Argmax(axis)(a)
     # out, argout = MaxAndArgmax(axis)(a)
 
```
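With this hunk, `max_and_argmax` builds two independent reduction nodes (`Max` and `Argmax`) rather than one fused op. A quick usage sketch, not part of the commit, with illustrative values:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# One Max node and one Argmax node now back these two outputs.
mx, amx = pt.max_and_argmax(x, axis=0)
f = pytensor.function([x], [mx, amx])

vals, idxs = f(np.array([[1.0, 5.0], [3.0, 2.0]]))
# vals -> [3., 5.], idxs -> [1, 0]
```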

```diff
@@ -864,6 +865,105 @@ def clone(self, **kwargs):
         axis = kwargs.get("axis", self.axis)
         return type(self)(axis=axis)
 
+    def grad(self, inp, grads):
+        # The strict-sense mathematical gradient of the maximum function is
+        # not calculated here, for it is not defined at every point where some
+        # coordinates are identical. However, since the latter set has null
+        # Lebesgue measure, the result may be interpreted as a weak gradient.
+
+        # @note: This function should work correctly for L{vector}s.
+        # (x, y), (gz, gw)
+        # gz*dz/dx + gw*dw/dx, gz*dz/dy + gw*dw/dy
+        # gMax * dMax/dx + gArgMax * dArgMax/dx,
+        # gMax * dMax/daxis + gArgMax * dArgMax/daxis
+        # g_max has one less dimension than x, so you need to complete
+        # g_max to x's shape when axis=0; the broadcasting mechanism
+        # does it automatically.
+        x = inp[0]
+        axis = as_tensor_variable(self.axis)
+        (g_max,) = grads
+
+        g_max_disconnected = isinstance(g_max.type, DisconnectedType)
+
+        # If the op is totally disconnected, so are its inputs.
+        if g_max_disconnected:
+            return [DisconnectedType()()]
+
+        if NoneConst.equals(axis):
+            axis_ = list(range(x.ndim))
+        else:
+            axis_ = axis
+        xmax = max(x, axis_)
+
+        # Raise g_max and xmax to the same number of dims as the input.
+        pattern = []
+        out_dim = 0
+        if NoneConst.equals(axis):
+            # We are taking the max/argmax over all dimensions.
+            axis = None
+        for i in range(x.ndim):
+            if axis is None or i in axis.data:
+                pattern.append("x")
+            else:
+                pattern.append(out_dim)
+                out_dim += 1
+        g_max_pad = DimShuffle(g_max.broadcastable, pattern)(g_max)
+        xmax_pad = DimShuffle(xmax.broadcastable, pattern)(xmax)
+
+        # Set the grad to the correct position.
+        g_x = eq(xmax_pad, x) * g_max_pad
+        return (g_x,)
+
+    def c_code(self, node, name, inp, out, sub):
+        if len(self.axis) != 1 and len(self.axis) != node.inputs[0].ndim:
+            raise NotImplementedError(
+                "NumPy C-API can compute max only for 1 axis or for all axes."
+            )
+        x = inp[0]
+        axis = sub["params"]
+        # max, argmax = out
+        (max,) = out
+        fail = sub["fail"]
+        ret = """
+        #if PY_MAJOR_VERSION >= 3
+            #ifndef PyInt_AS_LONG
+                #define PyInt_AS_LONG PyLong_AS_LONG
+            #endif
+        #endif
+
+        int axis;
+
+        if (PyTuple_GET_SIZE(%(axis)s) == PyArray_NDIM(%(x)s)) {
+            axis = NPY_MAXDIMS;
+        } else if (PyTuple_GET_SIZE(%(axis)s) == 1) {
+            PyObject* axis_object = PyTuple_GET_ITEM(%(axis)s, 0);
+            axis = (int)PyInt_AS_LONG(axis_object);
+            if (axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)) {
+                PyErr_SetString(PyExc_ValueError,
+                                "TensorMax: bad axis argument");
+                %(fail)s
+            }
+        } else {
+            PyErr_SetString(PyExc_NotImplementedError,
+                            "TensorMax: NumPy C-API can compute max only for 1 axis or for all axes.");
+            %(fail)s
+        }
+
+        Py_CLEAR(%(max)s);
+
+        %(max)s = (PyArrayObject*)PyArray_Max(%(x)s, axis, NULL);
+        if (%(max)s == NULL) {
+            %(fail)s;
+        }
+        if (!PyArray_CheckExact(%(max)s)) {
+            %(max)s = (PyArrayObject*)PyArray_FromAny((PyObject*)%(max)s, NULL, 0, 0, NPY_ARRAY_ENSUREARRAY, NULL);
+            if (%(max)s == NULL) {
+                %(fail)s;
+            }
+        }
+        """
+        return ret % locals()
+
 
 class Min(NonZeroDimsCAReduce):
     nfunc_spec = ("min", 1, 1)
```
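The new `grad` routes the incoming gradient to every coordinate that attains the maximum (`g_x = eq(xmax_pad, x) * g_max_pad`), the weak-gradient convention described in its comments. A NumPy mirror of that rule makes the behavior concrete; the helper below is ours, not part of the commit:

```python
import numpy as np

def max_weak_grad(x, g_max, axis):
    # Pad the reduced max and gradient back to x's shape (the role of
    # the DimShuffle pattern in the Op), then route the gradient to
    # every position equal to the max. Ties each receive the full
    # upstream gradient, since eq(xmax, x) is 1 at every tie.
    xmax = x.max(axis=axis, keepdims=True)
    g_max_pad = np.expand_dims(g_max, axis)
    return (x == xmax) * g_max_pad

x = np.array([[1.0, 3.0, 3.0], [2.0, 0.0, 5.0]])
g = np.ones(2)  # upstream gradient of the two row maxima
print(max_weak_grad(x, g, axis=1))
# [[0. 1. 1.]
#  [0. 0. 1.]]
```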
```diff
@@ -956,6 +1056,12 @@ def min(x, axis=None, keepdims=False):
     elif str_x_type in uint_dtypes:
         itype = np.iinfo(x.dtype)
         max_val = np.array(itype.max, dtype=itype.dtype)
+        # print('a')
+        # for c in (max_val - x):
+        #     print(c.eval())
+        # print()
+        # print(max(max_val - x, axis=axis, keepdims=keepdims).eval())
+        # print()
         return max_val - max(max_val - x, axis=axis, keepdims=keepdims)
     elif str_x_type == "bool":
         return ~max(~x, axis=axis, keepdims=keepdims)
```
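The uint branch exists because negation is not closed under unsigned dtypes, so `min` cannot be written as `-max(-x)` there; reflecting around the dtype maximum gives the overflow-free identity `min(x) = MAX - max(MAX - x)`. The bool branch is the same idea via De Morgan. A small NumPy check of both identities (not part of the commit):

```python
import numpy as np

# Unsigned ints: reflect around the dtype maximum instead of negating.
x = np.array([7, 2, 9], dtype=np.uint8)
MAX = np.iinfo(np.uint8).max               # 255
assert MAX - np.max(MAX - x) == np.min(x)  # 255 - 253 == 2

# Booleans: min ("all") via max ("any"), by De Morgan.
b = np.array([True, False, True])
assert ~np.max(~b) == np.min(b)            # not any(~b) == all(b)
```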

pytensor/tensor/rewriting/uncanonicalize.py
Lines changed: 8 additions & 6 deletions

```diff
@@ -31,10 +31,11 @@
 
 """
 
+from pytensor import scalar as ps
 from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
 from pytensor.tensor.basic import Alloc, alloc, constant
-from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import Min, TensorMax, neg
+from pytensor.tensor.elemwise import CAReduce, DimShuffle
+from pytensor.tensor.math import Min, neg
 from pytensor.tensor.rewriting.basic import register_uncanonicalize
 from pytensor.tensor.shape import Reshape, reshape
 from pytensor.tensor.subtensor import Subtensor
@@ -54,13 +55,14 @@ def local_max_to_min(fgraph, node):
     the interface put only MaxAndArgmax into the graph.
 
     """
-    # pytensor.dprint(node)
-    # print()
-    # print(node.op == neg)
     if node.op == neg and node.inputs[0].owner:
         max = node.inputs[0]
         # print(max.owner.op.scalar_op)
-        if max.owner and isinstance(max.owner.op, TensorMax):
+        if (
+            max.owner
+            and isinstance(max.owner.op, CAReduce)
+            and max.owner.op.scalar_op == ps.scalar_maximum
+        ):
             neg_node = max.owner.inputs[0]
             if neg_node.owner and neg_node.owner.op == neg:
                 new = Min(max.owner.op.axis)(neg_node.owner.inputs[0])
```
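With the dedicated `TensorMax` op gone, the rewrite now recognizes a max as any `CAReduce` whose `scalar_op` is `ps.scalar_maximum`. The rewrite itself undoes the canonical form `min(x) = -max(-x)`; a minimal sketch of the graph it targets (not part of the commit):

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
# pt.min is canonicalized as -max(-x); local_max_to_min restores a
# single Min node when the uncanonicalize rewrites are enabled.
y = -pt.max(-x)
f = pytensor.function([x], y)
print(f(np.array([3.0, -1.0, 2.0])))  # -1.0
```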

tests/tensor/rewriting/test_uncanonicalize.py
Lines changed: 3 additions & 4 deletions

```diff
@@ -9,7 +9,6 @@
 from pytensor.graph.rewriting.basic import out2in
 from pytensor.link.basic import PerformLinker
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import TensorMax
 from pytensor.tensor.math import min as pt_min
 from pytensor.tensor.rewriting.uncanonicalize import (
     local_alloc_dimshuffle,
@@ -43,7 +42,7 @@ def test_optimization_min(self):
         f = function([n], pt_min(-n, axis), mode=self.mode)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 2
-        assert isinstance(topo[0].op, TensorMax)  # max
+        assert isinstance(topo[0].op, CAReduce)  # max
         assert isinstance(topo[1].op, Elemwise)
         assert isinstance(topo[1].op.scalar_op, ps.Neg)
         f(data)
@@ -53,13 +52,13 @@ def test_optimization_min(self):
         assert len(topo) == 2
         assert isinstance(topo[0].op, Elemwise)
         assert isinstance(topo[0].op.scalar_op, ps.Neg)
-        assert isinstance(topo[1].op, TensorMax)  # max
+        assert isinstance(topo[1].op, CAReduce)  # max
         f(data)
 
         f = function([n], -pt_min(-n, axis), mode=self.mode)
         topo = f.maker.fgraph.toposort()
         assert len(topo) == 1
-        assert isinstance(topo[0].op, TensorMax)  # max
+        assert isinstance(topo[0].op, CAReduce)  # max
         f(data)
 
```

tests/tensor/test_max_argmax.py
Lines changed: 12 additions & 5 deletions

```diff
@@ -16,7 +16,7 @@
 )
 from pytensor.tensor.math import (
     Argmax,
-    TensorMax,
+    Max,
     argmax,
     argmin,
     max,
@@ -43,7 +43,7 @@
 
 class TestMaxAndArgmax:
     def setup_method(self):
-        TensorMax.debug = 0
+        Max.debug = 0
 
     def test_basic(self):
         # dbt: for some reason, Argmax does not work when I pass: n = as_tensor_variable(5.0)
@@ -317,7 +317,8 @@ def test_vectorize(self, core_axis, batch_axis):
         # Test MaxAndArgmax
         max_x, argmax_x = max_and_argmax(x, axis=core_axis)
         node = max_x.owner
-        assert isinstance(node.op, TensorMax)
+        # assert isinstance(node.op, TensorMax)
+        assert isinstance(node.op, Max)
 
         # dbt: how to make Argmax user facing?
         # new_node = vectorize_node(node, batch_x)
@@ -339,7 +340,7 @@ def test_vectorize(self, core_axis, batch_axis):
 
 class TestArgminArgmax:
     def setup_method(self):
-        TensorMax.debug = 0
+        Argmax.debug = 0
 
     def test_scalar(self):
         for fct in [argmin, argmax]:
@@ -501,7 +502,7 @@ def test_bool(self):
 
 class TestMinMax:
     def setup_method(self):
-        TensorMax.debug = 0
+        Max.debug = 0
 
     def test_scalar(self):
         for fct in [max, min]:
@@ -675,6 +676,12 @@ def test_uint(self):
         n = as_tensor_variable(data)
         assert min(n).dtype == dtype
         i = eval_outputs(min(n))
+        # pytensor.dprint(n)
+        for x in n:
+            print(x.eval())
+        print(i)
+        print(itype.min)
+        print()
         assert i == itype.min
         assert max(n).dtype == dtype
         i = eval_outputs(max(n))
```
