
Commit 53cad9b

covertg authored and ricardoV94 committed

Propagate static shape in MaxAndArgmax

1 parent e37497f · commit 53cad9b

2 files changed: +75 −52 lines

pytensor/tensor/math.py (+2 −11)
```diff
@@ -142,15 +142,10 @@ def get_params(self, node):
     def make_node(self, x):
         x = as_tensor_variable(x)
 
-        # We keep the original broadcastable flags for dimensions on which
-        # we do not perform the max / argmax.
+        # Keep the original shapes for axes on which we do not perform the max/argmax.
         all_axes = set(self.axis)
         inputs = [x]
-        out_shape = tuple(
-            1 if s == 1 else None
-            for i, s in enumerate(x.type.shape)
-            if i not in all_axes
-        )
+        out_shape = tuple(s for i, s in enumerate(x.type.shape) if i not in all_axes)
         outputs = [
             tensor(dtype=x.type.dtype, shape=out_shape, name="max"),
             tensor(dtype="int64", shape=out_shape, name="argmax"),
```
```diff
@@ -1521,7 +1516,6 @@ def perform(self, node, inp, out):
         output[0] = np.asarray(np.mean(input, dtype="float64", axis=axis))
 
     def c_code(self, node, name, inames, onames, sub):
-
         ret = super().c_code(node, name, inames, onames, sub)
 
         if self.axis is not None:
@@ -1940,7 +1934,6 @@ def perform(self, node, inp, out):
         z[0] = np.asarray(np.dot(x, y))
 
     def grad(self, inp, grads):
-
         x, y = inp
         (gz,) = grads
         xdim, ydim, gdim = x.type.ndim, y.type.ndim, gz.type.ndim
@@ -2631,7 +2624,6 @@ def L_op(self, inp, out, grads):
             # this handles inputs with zeros, but only certain input shapes
             return [grad_case_without_zeros]
         else:
-
             where_zeros = eq(prod_in, 0.0)
             sum_where_zeros = sum(where_zeros, axis=self.axis)
             groups_with_single_zero = eq(sum_where_zeros, 1).dimshuffle(new_dims)
@@ -2924,7 +2916,6 @@ def _get_output_shape(cls, x1, x2, shapes, validate=False):
             )
             return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
         else:
-
             if validate:
                 from pytensor.tensor.random.basic import broadcast_shapes
 
```
tests/tensor/test_math.py (+73 −41)
```diff
@@ -771,10 +771,9 @@ def test_basic_1(self):
         v = eval_outputs(max_and_argmax(n)[0].shape)
         assert len(v) == 0
 
-    def test_basic_2(self):
-        data = random(2, 3)
-        n = as_tensor_variable(data)
-        for (axis, np_axis) in [
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
```
```diff
@@ -783,19 +782,28 @@ def test_basic_2(self):
             ([1, 0], None),
             (NoneConst.clone(), None),
             (constant(0), 0),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v_shape) == np.max(data, np_axis).shape
-
-    def test_basic_2_float16(self):
-        # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
-        data = (random(20, 30).astype("float16") - 0.5) * 20
-        n = shared(data)
-        for (axis, np_axis) in [
+        ],
+    )
+    def test_basic_2(self, axis, np_axis):
+        data = random(2, 3)
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
+
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
```
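The refactor replaces in-test loops with `pytest.mark.parametrize`, so each axis case runs as a separate test and an early failure cannot mask later cases. A generic sketch of the pattern (plain NumPy, not the helpers used in this file):

```python
import numpy as np
import pytest

@pytest.mark.parametrize("axis", [-1, 0, 1, None])
def test_argmax_matches_max_shape(axis):
    # Each axis value becomes its own test case, reported by name,
    # so one failing axis no longer hides failures in the others.
    data = np.arange(6).reshape(2, 3)
    assert np.argmax(data, axis=axis).shape == np.max(data, axis=axis).shape
```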
```diff
@@ -804,13 +812,25 @@ def test_basic_2_float16(self):
             ([1, 0], None),
             (NoneConst.clone(), None),
             (constant(0), 0),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis), (MaxAndArgmax,))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v_shape = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v_shape) == np.max(data, np_axis).shape
+        ],
+    )
+    def test_basic_2_float16(self, axis, np_axis):
+        # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
+        data = (random(20, 30).astype("float16") - 0.5) * 20
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
 
     def test_basic_2_invalid(self):
         n = as_tensor_variable(random(2, 3))
```
```diff
@@ -840,23 +860,33 @@ def test_basic_2_valid_neg(self):
         v = eval_outputs(max_and_argmax(n, -2)[0].shape)
         assert v == (3)
 
-    def test_basic_3(self):
-        data = random(2, 3, 4)
-        n = as_tensor_variable(data)
-        for (axis, np_axis) in [
+    @pytest.mark.parametrize(
+        "axis,np_axis",
+        [
             (-1, -1),
             (0, 0),
             (1, 1),
             (None, None),
             ([0, 1, 2], None),
             ([1, 2, 0], None),
-        ]:
-            v, i = eval_outputs(max_and_argmax(n, axis))
-            assert i.dtype == "int64"
-            assert np.all(v == np.max(data, np_axis))
-            assert np.all(i == np.argmax(data, np_axis))
-            v = eval_outputs(max_and_argmax(n, axis)[0].shape)
-            assert tuple(v) == np.max(data, np_axis).shape
+        ],
+    )
+    def test_basic_3(self, axis, np_axis):
+        data = random(2, 3, 4)
+        n = as_tensor_variable(data)
+        # Test shape propagates (static & eval)
+        vt, it = max_and_argmax(n, axis)
+        np_max, np_argm = np.max(data, np_axis), np.argmax(data, np_axis)
+        assert vt.type.shape == np_max.shape
+        assert it.type.shape == np_argm.shape
+        v_shape, i_shape = eval_outputs([vt.shape, it.shape])
+        assert tuple(v_shape) == vt.type.shape
+        assert tuple(i_shape) == it.type.shape
+        # Test values
+        v, i = eval_outputs([vt, it])
+        assert i.dtype == "int64"
+        assert np.all(v == np_max)
+        assert np.all(i == np_argm)
 
     def test_arg_grad(self):
         # The test checks that the gradient of argmax(x).sum() is 0
```
```diff
@@ -948,17 +978,19 @@ def test_preserve_broadcastable(self):
         # Ensure the original broadcastable flags are preserved by Max/Argmax.
         x = matrix().dimshuffle("x", 0, "x", 1, "x")
         y = x.max(axis=1)
+        assert y.type.shape == (1, 1, None, 1)
         assert y.type.broadcastable == (True, True, False, True)
 
     def test_multiple_axes(self):
         data = np.arange(24).reshape(3, 2, 4)
         x = as_tensor_variable(data)
-        v, i = eval_outputs(max_and_argmax(x, [1, -1]))
+        vt, it = max_and_argmax(x, [1, -1])
+        assert vt.type.shape == it.type.shape == (3,)
+        v, i = eval_outputs([vt, it])
         assert np.all(v == np.array([7, 15, 23]))
         assert np.all(i == np.array([7, 7, 7]))
-
-        v = eval_outputs(max_and_argmax(x, [1, -1])[0].shape)
-        assert tuple(v) == np.max(data, (1, -1)).shape
+        v = eval_outputs(vt.shape)
+        assert tuple(v) == vt.type.shape
 
     def test_zero_shape(self):
         x = matrix()
```
9721004
def test_numpy_input(self):
9731005
ar = np.array([1, 2, 3])
9741006
max_at, argmax_at = max_and_argmax(ar, axis=None)
975-
assert max_at.eval(), 3
976-
assert argmax_at.eval(), 2
1007+
assert max_at.eval() == 3
1008+
assert argmax_at.eval() == 2
9771009

9781010

9791011
class TestArgminArgmax:
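The `test_numpy_input` change also fixes a classic pitfall: `assert expr, msg` treats the second operand as the assertion message, so the old `assert max_at.eval(), 3` only checked truthiness and could never fail on a wrong value. Illustrated with a stand-in value:

```python
value = 5  # stand-in for max_at.eval()

# Old form: parses as `assert value` with 3 as the (unused) message,
# so it passes for any truthy result.
assert value, 3

# Corrected form: actually compares against the expected value.
# Here it would raise AssertionError, catching the wrong result.
# assert value == 3
```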
