Commit e6e6d69
Break MaxAndArgmax Op into separate TensorMax Op and Argmax Op (#731)
* Break MaxAndArgmax into TensorMax and Argmax separately
* XFAIL pytensor tests for uint64 data type
* Deprecate and raise AttributeError for MaxAndArgmax
1 parent dbe0e09 commit e6e6d69

File tree

10 files changed (+254, −371 lines)

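Before the per-file diffs, a minimal sketch of the user-facing effect, assuming the usual pt.max / pt.argmax helpers (which lower to these Ops):

import pytensor.tensor as pt

x = pt.matrix("x")
max_val = pt.max(x, axis=0)        # now lowers to the standalone Max (TensorMax) Op
argmax_val = pt.argmax(x, axis=0)  # now lowers to the standalone Argmax Op

# Per the commit message, the fused Op is deprecated:
# pt.math.MaxAndArgmax  ->  AttributeError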

pytensor/ifelse.py

Lines changed: 2 additions & 1 deletion

@@ -477,7 +477,8 @@ def cond_make_inplace(fgraph, node):
     Reshape,
     Unbroadcast,
     pt.math.Dot,
-    pt.math.MaxAndArgmax,
+    pt.math.Max,
+    pt.math.Argmax,
     pt.subtensor.Subtensor,
     pt.subtensor.IncSubtensor,
     pt.basic.Alloc,
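This list enumerates the Ops the ifelse inplace rewrite may handle; the fused entry is simply replaced by the two new Ops. A minimal sketch of a graph that would exercise these entries (hypothetical variable names, assuming float64 inputs):

import pytensor.tensor as pt
from pytensor.ifelse import ifelse

x = pt.vector("x")
cond = pt.gt(x.sum(), 0)
# Both branches must have the same type, so cast the argmax branch to float64
out = ifelse(cond, pt.max(x), pt.argmax(x).astype("float64"))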

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 18 additions & 8 deletions

@@ -2,7 +2,7 @@

 from pytensor.link.jax.dispatch import jax_funcify
 from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Dot, MaxAndArgmax
+from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -104,18 +104,28 @@ def batched_dot(a, b):
     return batched_dot


-@jax_funcify.register(MaxAndArgmax)
-def jax_funcify_MaxAndArgmax(op, **kwargs):
+@jax_funcify.register(Max)
+def jax_funcify_Max(op, **kwargs):
     axis = op.axis

-    def maxandargmax(x, axis=axis):
+    def max(x):
+        max_res = jnp.max(x, axis)
+
+        return max_res
+
+    return max
+
+
+@jax_funcify.register(Argmax)
+def jax_funcify_Argmax(op, **kwargs):
+    axis = op.axis
+
+    def argmax(x):
         if axis is None:
             axes = tuple(range(x.ndim))
         else:
             axes = tuple(int(ax) for ax in axis)

-        max_res = jnp.max(x, axis)
-
         # NumPy does not support multiple axes for argmax; this is a
         # work-around
         keep_axes = jnp.array(
@@ -138,6 +148,6 @@ def maxandargmax(x, axis=axis):

         max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

-        return max_res, max_idx_res
+        return max_idx_res

-    return maxandargmax
+    return argmax
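The fused kernel previously returned a (max, argmax) pair; each registration now returns a single result. A minimal usage sketch under JAX mode (assuming jax and jaxlib are installed):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
# Max and Argmax are now dispatched independently to jnp.max / jnp.argmax
f = pytensor.function([x], [pt.max(x, axis=1), pt.argmax(x, axis=1)], mode="JAX")

max_res, argmax_res = f(np.arange(6.0).reshape(2, 3))
# max_res -> [2., 5.], argmax_res -> [2, 2]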

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 8 additions & 24 deletions

@@ -44,7 +44,7 @@
 )
 from pytensor.scalar.basic import add as add_as
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
-from pytensor.tensor.math import MaxAndArgmax, MulWithoutZeros, Sum
+from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
 from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
 from pytensor.tensor.type import scalar

@@ -827,8 +827,8 @@ def log_softmax_py_fn(x):
     return log_softmax


-@numba_funcify.register(MaxAndArgmax)
-def numba_funcify_MaxAndArgmax(op, node, **kwargs):
+@numba_funcify.register(Argmax)
+def numba_funcify_Argmax(op, node, **kwargs):
     axis = op.axis
     x_at = node.inputs[0]
     x_dtype = x_at.type.numpy_dtype
@@ -838,8 +838,8 @@ def numba_funcify_MaxAndArgmax(op, node, **kwargs):
     if x_ndim == 0:

         @numba_basic.numba_njit(inline="always")
-        def maxandargmax(x):
-            return x, 0
+        def argmax(x):
+            return 0

     else:
         axes = tuple(int(ax) for ax in axis)
@@ -848,20 +848,6 @@ def maxandargmax(x):
         # work-around
         keep_axes = tuple(i for i in range(x_ndim) if i not in axes)

-        reduce_max_py_fn = create_multiaxis_reducer(
-            scalar_maximum,
-            -np.inf,
-            axes,
-            x_ndim,
-            x_dtype,
-            return_scalar=False,
-        )
-        reduce_max = jit_compile_reducer(
-            Apply(node.op, node.inputs, [node.outputs[0].clone()]),
-            reduce_max_py_fn,
-            reduce_to_scalar=False,
-        )
-
         reduced_x_ndim = x_ndim - len(axes) + 1
         argmax_axis = create_axis_apply_fn(
             np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
@@ -872,9 +858,7 @@ def maxandargmax(x):
         sl2 = slice(len(keep_axes), None)

         @numba_basic.numba_njit
-        def maxandargmax(x):
-            max_res = reduce_max(x)
-
+        def argmax(x):
             # Not-reduced axes in front
             transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
             kept_shape = transposed_x.shape[sl1]
@@ -890,6 +874,6 @@ def maxandargmax(x):

             max_idx_res = argmax_axis(reshaped_x)

-            return max_res, max_idx_res
+            return max_idx_res

-    return maxandargmax
+    return argmax
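Both backends share the same multi-axis workaround: move the kept axes to the front, collapse the reduced axes into one, and argmax along the trailing axis. A pure-NumPy sketch of that logic (multiaxis_argmax is an illustrative name, not part of the codebase):

import numpy as np

def multiaxis_argmax(x, axes):
    axes = tuple(sorted(int(ax) for ax in axes))
    keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
    # Not-reduced axes in front, then flatten the reduced axes into one
    transposed = np.ascontiguousarray(np.transpose(x, keep_axes + axes))
    kept_shape = transposed.shape[: len(keep_axes)]
    reduced_size = int(np.prod(transposed.shape[len(keep_axes) :]))
    reshaped = transposed.reshape(kept_shape + (reduced_size,))
    # Indices are flattened over the reduced axes, matching the Op's output
    return np.argmax(reshaped, axis=-1).astype("int64")

multiaxis_argmax(np.arange(24).reshape(2, 3, 4), axes=(1, 2))  # -> array([11, 11])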
