@@ -41,3 +41,61 @@ def test_pytorch_CumOp(axis, dtype):
41
41
out = pt .cumprod (a , axis = axis )
42
42
fgraph = FunctionGraph ([a ], [out ])
43
43
compare_pytorch_and_py (fgraph , [test_value ])
44
+
45
+
46
+ @pytest .mark .parametrize (
47
+ "axis, repeats" ,
48
+ [
49
+ (0 , (1 , 2 , 3 )),
50
+ (1 , (3 , 3 )),
51
+ pytest .param (
52
+ None ,
53
+ 3 ,
54
+ marks = pytest .mark .xfail (reason = "Reshape not implemented" ),
55
+ ),
56
+ ],
57
+ )
58
+ def test_pytorch_Repeat (axis , repeats ):
59
+ a = pt .matrix ("a" , dtype = "float64" )
60
+
61
+ test_value = np .arange (6 , dtype = "float64" ).reshape ((3 , 2 ))
62
+
63
+ out = pt .repeat (a , repeats , axis = axis )
64
+ fgraph = FunctionGraph ([a ], [out ])
65
+ compare_pytorch_and_py (fgraph , [test_value ])
66
+
67
+
68
+ @pytest .mark .parametrize ("axis" , [None , 0 , 1 ])
69
+ def test_pytorch_Unique_axis (axis ):
70
+ a = pt .matrix ("a" , dtype = "float64" )
71
+
72
+ test_value = np .array (
73
+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
74
+ )
75
+
76
+ out = pt .unique (a , axis = axis )
77
+ fgraph = FunctionGraph ([a ], [out ])
78
+ compare_pytorch_and_py (fgraph , [test_value ])
79
+
80
+
81
+ @pytest .mark .parametrize ("return_inverse" , [False , True ])
82
+ @pytest .mark .parametrize ("return_counts" , [False , True ])
83
+ @pytest .mark .parametrize (
84
+ "return_index" ,
85
+ (False , pytest .param (True , marks = pytest .mark .xfail (raises = NotImplementedError ))),
86
+ )
87
+ def test_pytorch_Unique_params (return_index , return_inverse , return_counts ):
88
+ a = pt .matrix ("a" , dtype = "float64" )
89
+ test_value = np .array (
90
+ [[1.0 , 1.0 , 2.0 ], [1.0 , 1.0 , 2.0 ], [3.0 , 3.0 , 0.0 ]], dtype = "float64"
91
+ )
92
+
93
+ out = pt .unique (
94
+ a ,
95
+ return_index = return_index ,
96
+ return_inverse = return_inverse ,
97
+ return_counts = return_counts ,
98
+ axis = 0 ,
99
+ )
100
+ fgraph = FunctionGraph ([a ], [out [0 ] if isinstance (out , list ) else out ])
101
+ compare_pytorch_and_py (fgraph , [test_value ])
0 commit comments