5
5
6
6
from typing import Tuple
7
7
8
- import pytest
9
8
import torch
10
9
from executorch .backends .arm .test import common
11
10
16
15
TosaPipelineMI ,
17
16
)
18
17
19
- aten_op = "torch.ops.aten.ge.Tensor"
20
- exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
21
-
22
18
input_t = Tuple [torch .Tensor ]
23
19
24
20
25
21
class GreaterEqual (torch .nn .Module ):
22
+ aten_op_tensor = "torch.ops.aten.ge.Tensor"
23
+ aten_op_scalar = "torch.ops.aten.ge.Scalar"
24
+ exir_op = "executorch_exir_dialects_edge__ops_aten_ge_Tensor"
25
+
26
26
def __init__ (self , input , other ):
27
27
super ().__init__ ()
28
28
self .input_ = input
@@ -31,106 +31,151 @@ def __init__(self, input, other):
31
31
def forward (
32
32
self ,
33
33
input_ : torch .Tensor ,
34
- other_ : torch .Tensor ,
34
+ other_ : torch .Tensor | int | float ,
35
35
):
36
36
return input_ >= other_
37
37
38
38
def get_inputs (self ):
39
39
return (self .input_ , self .other_ )
40
40
41
41
42
- op_ge_rank1_ones = GreaterEqual (
42
+ op_ge_tensor_rank1_ones = GreaterEqual (
43
43
torch .ones (5 ),
44
44
torch .ones (5 ),
45
45
)
46
- op_ge_rank2_rand = GreaterEqual (
46
+ op_ge_tensor_rank2_rand = GreaterEqual (
47
47
torch .rand (4 , 5 ),
48
48
torch .rand (1 , 5 ),
49
49
)
50
- op_ge_rank3_randn = GreaterEqual (
50
+ op_ge_tensor_rank3_randn = GreaterEqual (
51
51
torch .randn (10 , 5 , 2 ),
52
52
torch .randn (10 , 5 , 2 ),
53
53
)
54
- op_ge_rank4_randn = GreaterEqual (
54
+ op_ge_tensor_rank4_randn = GreaterEqual (
55
55
torch .randn (3 , 2 , 2 , 2 ),
56
56
torch .randn (3 , 2 , 2 , 2 ),
57
57
)
58
58
59
- test_data_common = {
60
- "ge_rank1_ones" : op_ge_rank1_ones ,
61
- "ge_rank2_rand" : op_ge_rank2_rand ,
62
- "ge_rank3_randn" : op_ge_rank3_randn ,
63
- "ge_rank4_randn" : op_ge_rank4_randn ,
59
+ op_ge_scalar_rank1_ones = GreaterEqual (torch .ones (5 ), 1.0 )
60
+ op_ge_scalar_rank2_rand = GreaterEqual (torch .rand (4 , 5 ), 0.2 )
61
+ op_ge_scalar_rank3_randn = GreaterEqual (torch .randn (10 , 5 , 2 ), - 0.1 )
62
+ op_ge_scalar_rank4_randn = GreaterEqual (torch .randn (3 , 2 , 2 , 2 ), 0.3 )
63
+
64
+ test_data_tensor = {
65
+ "ge_tensor_rank1_ones" : op_ge_tensor_rank1_ones ,
66
+ "ge_tensor_rank2_rand" : op_ge_tensor_rank2_rand ,
67
+ "ge_tensor_rank3_randn" : op_ge_tensor_rank3_randn ,
68
+ "ge_tensor_rank4_randn" : op_ge_tensor_rank4_randn ,
69
+ }
70
+
71
+ test_data_scalar = {
72
+ "ge_scalar_rank1_ones" : op_ge_scalar_rank1_ones ,
73
+ "ge_scalar_rank2_rand" : op_ge_scalar_rank2_rand ,
74
+ "ge_scalar_rank3_randn" : op_ge_scalar_rank3_randn ,
75
+ "ge_scalar_rank4_randn" : op_ge_scalar_rank4_randn ,
64
76
}
65
77
66
78
67
- @common .parametrize ("test_module" , test_data_common )
68
- def test_ge_tosa_MI (test_module ):
79
+ @common .parametrize ("test_module" , test_data_tensor )
80
+ def test_ge_tensor_tosa_MI (test_module ):
81
+ pipeline = TosaPipelineMI [input_t ](
82
+ test_module ,
83
+ test_module .get_inputs (),
84
+ GreaterEqual .aten_op_tensor ,
85
+ GreaterEqual .exir_op ,
86
+ )
87
+ pipeline .run ()
88
+
89
+
90
+ @common .parametrize ("test_module" , test_data_scalar )
91
+ def test_ge_scalar_tosa_MI (test_module ):
69
92
pipeline = TosaPipelineMI [input_t ](
70
- test_module , test_module .get_inputs (), aten_op , exir_op
93
+ test_module ,
94
+ test_module .get_inputs (),
95
+ GreaterEqual .aten_op_scalar ,
96
+ GreaterEqual .exir_op ,
71
97
)
72
98
pipeline .run ()
73
99
74
100
75
- @common .parametrize ("test_module" , test_data_common )
76
- def test_ge_tosa_BI (test_module ):
101
+ @common .parametrize ("test_module" , test_data_tensor )
102
+ def test_ge_tensor_tosa_BI (test_module ):
77
103
pipeline = TosaPipelineBI [input_t ](
78
- test_module , test_module .get_inputs (), aten_op , exir_op
104
+ test_module ,
105
+ test_module .get_inputs (),
106
+ GreaterEqual .aten_op_tensor ,
107
+ GreaterEqual .exir_op ,
79
108
)
80
109
pipeline .run ()
81
110
82
111
83
- @common .parametrize ("test_module" , test_data_common )
84
- def test_ge_u55_BI (test_module ):
85
- # GREATER_EQUAL is not supported on U55.
86
- pipeline = OpNotSupportedPipeline [input_t ](
112
+ @common .parametrize ("test_module" , test_data_scalar )
113
+ def test_ge_scalar_tosa_BI (test_module ):
114
+ pipeline = TosaPipelineBI [input_t ](
87
115
test_module ,
88
116
test_module .get_inputs (),
89
- "TOSA-0.80+BI+u55" ,
90
- { exir_op : 1 } ,
117
+ GreaterEqual . aten_op_tensor ,
118
+ GreaterEqual . exir_op ,
91
119
)
92
120
pipeline .run ()
93
121
94
122
95
- @common .parametrize ("test_module" , test_data_common )
96
- def test_ge_u85_BI (test_module ):
97
- pipeline = EthosU85PipelineBI [input_t ](
123
+ @common .parametrize ("test_module" , test_data_tensor )
124
+ @common .XfailIfNoCorstone300
125
+ def test_ge_tensor_u55_BI (test_module ):
126
+ # GREATER_EQUAL is not supported on U55.
127
+ pipeline = OpNotSupportedPipeline [input_t ](
98
128
test_module ,
99
129
test_module .get_inputs (),
100
- aten_op ,
101
- exir_op ,
102
- run_on_fvp = False ,
103
- use_to_edge_transform_and_lower = True ,
130
+ "TOSA-0.80+BI+u55" ,
131
+ {GreaterEqual .exir_op : 1 },
104
132
)
105
133
pipeline .run ()
106
134
107
135
108
- @common .parametrize ("test_module" , test_data_common )
109
- @pytest . mark . skip ( reason = "The same as test_ge_u55_BI" )
110
- def test_ge_u55_BI_on_fvp (test_module ):
136
+ @common .parametrize ("test_module" , test_data_scalar )
137
+ @common . XfailIfNoCorstone300
138
+ def test_ge_scalar_u55_BI (test_module ):
111
139
# GREATER_EQUAL is not supported on U55.
112
140
pipeline = OpNotSupportedPipeline [input_t ](
113
141
test_module ,
114
142
test_module .get_inputs (),
115
143
"TOSA-0.80+BI+u55" ,
116
- {exir_op : 1 },
144
+ {GreaterEqual .exir_op : 1 },
145
+ n_expected_delegates = 1 ,
146
+ )
147
+ pipeline .run ()
148
+
149
+
150
+ @common .parametrize (
151
+ "test_module" ,
152
+ test_data_tensor ,
153
+ xfails = {"ge_tensor_rank4_randn" : "MLETORCH-847: Boolean eq result unstable on U85" },
154
+ )
155
+ @common .XfailIfNoCorstone320
156
+ def test_ge_tensor_u85_BI (test_module ):
157
+ pipeline = EthosU85PipelineBI [input_t ](
158
+ test_module ,
159
+ test_module .get_inputs (),
160
+ GreaterEqual .aten_op_tensor ,
161
+ GreaterEqual .exir_op ,
162
+ run_on_fvp = True ,
117
163
)
118
164
pipeline .run ()
119
165
120
166
121
167
@common .parametrize (
122
168
"test_module" ,
123
- test_data_common ,
124
- xfails = {"ge_rank4_randn " : "4D fails because boolean Tensors can't be subtracted " },
169
+ test_data_scalar ,
170
+ xfails = {"ge_scalar_rank4_randn " : "MLETORCH-847: Boolean eq result unstable on U85 " },
125
171
)
126
- @common .SkipIfNoCorstone320
127
- def test_ge_u85_BI_on_fvp (test_module ):
172
+ @common .XfailIfNoCorstone320
173
+ def test_ge_scalar_u85_BI (test_module ):
128
174
pipeline = EthosU85PipelineBI [input_t ](
129
175
test_module ,
130
176
test_module .get_inputs (),
131
- aten_op ,
132
- exir_op ,
177
+ GreaterEqual . aten_op_tensor ,
178
+ GreaterEqual . exir_op ,
133
179
run_on_fvp = True ,
134
- use_to_edge_transform_and_lower = True ,
135
180
)
136
181
pipeline .run ()
0 commit comments