@@ -771,10 +771,9 @@ def test_basic_1(self):
771
771
v = eval_outputs (max_and_argmax (n )[0 ].shape )
772
772
assert len (v ) == 0
773
773
774
- def test_basic_2 (self ):
775
- data = random (2 , 3 )
776
- n = as_tensor_variable (data )
777
- for (axis , np_axis ) in [
774
+ @pytest .mark .parametrize (
775
+ "axis,np_axis" ,
776
+ [
778
777
(- 1 , - 1 ),
779
778
(0 , 0 ),
780
779
(1 , 1 ),
@@ -783,19 +782,28 @@ def test_basic_2(self):
783
782
([1 , 0 ], None ),
784
783
(NoneConst .clone (), None ),
785
784
(constant (0 ), 0 ),
786
- ]:
787
- v , i = eval_outputs (max_and_argmax (n , axis ))
788
- assert i .dtype == "int64"
789
- assert np .all (v == np .max (data , np_axis ))
790
- assert np .all (i == np .argmax (data , np_axis ))
791
- v_shape = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
792
- assert tuple (v_shape ) == np .max (data , np_axis ).shape
793
-
794
- def test_basic_2_float16 (self ):
795
- # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
796
- data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
797
- n = shared (data )
798
- for (axis , np_axis ) in [
785
+ ],
786
+ )
787
+ def test_basic_2 (self , axis , np_axis ):
788
+ data = random (2 , 3 )
789
+ n = as_tensor_variable (data )
790
+ # Test shape propagates (static & eval)
791
+ vt , it = max_and_argmax (n , axis )
792
+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
793
+ assert vt .type .shape == np_max .shape
794
+ assert it .type .shape == np_argm .shape
795
+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
796
+ assert tuple (v_shape ) == vt .type .shape
797
+ assert tuple (i_shape ) == it .type .shape
798
+ # Test values
799
+ v , i = eval_outputs ([vt , it ])
800
+ assert i .dtype == "int64"
801
+ assert np .all (v == np_max )
802
+ assert np .all (i == np_argm )
803
+
804
+ @pytest .mark .parametrize (
805
+ "axis,np_axis" ,
806
+ [
799
807
(- 1 , - 1 ),
800
808
(0 , 0 ),
801
809
(1 , 1 ),
@@ -804,13 +812,25 @@ def test_basic_2_float16(self):
804
812
([1 , 0 ], None ),
805
813
(NoneConst .clone (), None ),
806
814
(constant (0 ), 0 ),
807
- ]:
808
- v , i = eval_outputs (max_and_argmax (n , axis ), (MaxAndArgmax ,))
809
- assert i .dtype == "int64"
810
- assert np .all (v == np .max (data , np_axis ))
811
- assert np .all (i == np .argmax (data , np_axis ))
812
- v_shape = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
813
- assert tuple (v_shape ) == np .max (data , np_axis ).shape
815
+ ],
816
+ )
817
+ def test_basic_2_float16 (self , axis , np_axis ):
818
+ # Test negative values and bigger range to make sure numpy don't do the argmax as on uint16
819
+ data = (random (20 , 30 ).astype ("float16" ) - 0.5 ) * 20
820
+ n = as_tensor_variable (data )
821
+ # Test shape propagates (static & eval)
822
+ vt , it = max_and_argmax (n , axis )
823
+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
824
+ assert vt .type .shape == np_max .shape
825
+ assert it .type .shape == np_argm .shape
826
+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
827
+ assert tuple (v_shape ) == vt .type .shape
828
+ assert tuple (i_shape ) == it .type .shape
829
+ # Test values
830
+ v , i = eval_outputs ([vt , it ])
831
+ assert i .dtype == "int64"
832
+ assert np .all (v == np_max )
833
+ assert np .all (i == np_argm )
814
834
815
835
def test_basic_2_invalid (self ):
816
836
n = as_tensor_variable (random (2 , 3 ))
@@ -840,23 +860,33 @@ def test_basic_2_valid_neg(self):
840
860
v = eval_outputs (max_and_argmax (n , - 2 )[0 ].shape )
841
861
assert v == (3 )
842
862
843
- def test_basic_3 (self ):
844
- data = random (2 , 3 , 4 )
845
- n = as_tensor_variable (data )
846
- for (axis , np_axis ) in [
863
+ @pytest .mark .parametrize (
864
+ "axis,np_axis" ,
865
+ [
847
866
(- 1 , - 1 ),
848
867
(0 , 0 ),
849
868
(1 , 1 ),
850
869
(None , None ),
851
870
([0 , 1 , 2 ], None ),
852
871
([1 , 2 , 0 ], None ),
853
- ]:
854
- v , i = eval_outputs (max_and_argmax (n , axis ))
855
- assert i .dtype == "int64"
856
- assert np .all (v == np .max (data , np_axis ))
857
- assert np .all (i == np .argmax (data , np_axis ))
858
- v = eval_outputs (max_and_argmax (n , axis )[0 ].shape )
859
- assert tuple (v ) == np .max (data , np_axis ).shape
872
+ ],
873
+ )
874
+ def test_basic_3 (self , axis , np_axis ):
875
+ data = random (2 , 3 , 4 )
876
+ n = as_tensor_variable (data )
877
+ # Test shape propagates (static & eval)
878
+ vt , it = max_and_argmax (n , axis )
879
+ np_max , np_argm = np .max (data , np_axis ), np .argmax (data , np_axis )
880
+ assert vt .type .shape == np_max .shape
881
+ assert it .type .shape == np_argm .shape
882
+ v_shape , i_shape = eval_outputs ([vt .shape , it .shape ])
883
+ assert tuple (v_shape ) == vt .type .shape
884
+ assert tuple (i_shape ) == it .type .shape
885
+ # Test values
886
+ v , i = eval_outputs ([vt , it ])
887
+ assert i .dtype == "int64"
888
+ assert np .all (v == np_max )
889
+ assert np .all (i == np_argm )
860
890
861
891
def test_arg_grad (self ):
862
892
# The test checks that the gradient of argmax(x).sum() is 0
@@ -948,17 +978,19 @@ def test_preserve_broadcastable(self):
948
978
# Ensure the original broadcastable flags are preserved by Max/Argmax.
949
979
x = matrix ().dimshuffle ("x" , 0 , "x" , 1 , "x" )
950
980
y = x .max (axis = 1 )
981
+ assert y .type .shape == (1 , 1 , None , 1 )
951
982
assert y .type .broadcastable == (True , True , False , True )
952
983
953
984
def test_multiple_axes (self ):
954
985
data = np .arange (24 ).reshape (3 , 2 , 4 )
955
986
x = as_tensor_variable (data )
956
- v , i = eval_outputs (max_and_argmax (x , [1 , - 1 ]))
987
+ vt , it = max_and_argmax (x , [1 , - 1 ])
988
+ assert vt .type .shape == it .type .shape == (3 ,)
989
+ v , i = eval_outputs ([vt , it ])
957
990
assert np .all (v == np .array ([7 , 15 , 23 ]))
958
991
assert np .all (i == np .array ([7 , 7 , 7 ]))
959
-
960
- v = eval_outputs (max_and_argmax (x , [1 , - 1 ])[0 ].shape )
961
- assert tuple (v ) == np .max (data , (1 , - 1 )).shape
992
+ v = eval_outputs (vt .shape )
993
+ assert tuple (v ) == vt .type .shape
962
994
963
995
def test_zero_shape (self ):
964
996
x = matrix ()
@@ -972,8 +1004,8 @@ def test_zero_shape(self):
972
1004
def test_numpy_input (self ):
973
1005
ar = np .array ([1 , 2 , 3 ])
974
1006
max_at , argmax_at = max_and_argmax (ar , axis = None )
975
- assert max_at .eval (), 3
976
- assert argmax_at .eval (), 2
1007
+ assert max_at .eval () == 3
1008
+ assert argmax_at .eval () == 2
977
1009
978
1010
979
1011
class TestArgminArgmax :
0 commit comments