@@ -161,6 +161,25 @@ func.func @shape_cast_of_transpose(%arg : vector<1x4x4x1x1xi8>) -> vector<4x4xi8
161
161
162
162
// -----
163
163
164
+ // In this test, the permutation maps the non-unit dimensions (1 and 2) as follows:
165
+ // 1 -> 0
166
+ // 2 -> 4
167
+ // Because 0 < 4, this permutation is order preserving and effectively a shape_cast.
168
+ // (same as the example above, but one of the dims is scalable)
169
+ // CHECK-LABEL: @shape_cast_of_transpose_scalable
170
+ // CHECK-SAME: %[[ARG:.*]]: vector<1x[4]x4x1x1xi8>) -> vector<[4]x4xi8> {
171
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
172
+ // CHECK-SAME: vector<1x[4]x4x1x1xi8> to vector<[4]x4xi8>
173
+ // CHECK: return %[[SHAPE_CAST]] : vector<[4]x4xi8>
174
+ func.func @shape_cast_of_transpose_scalable (%arg : vector <1 x[4 ]x4 x1 x1 xi8 >) -> vector <[4 ]x4 xi8 > {
175
+ %0 = vector.transpose %arg , [1 , 0 , 3 , 4 , 2 ]
176
+ : vector <1 x[4 ]x4 x1 x1 xi8 > to vector <[4 ]x1 x1 x1 x4 xi8 >
177
+ %1 = vector.shape_cast %0 : vector <[4 ]x1 x1 x1 x4 xi8 > to vector <[4 ]x4 xi8 >
178
+ return %1 : vector <[4 ]x4 xi8 >
179
+ }
180
+
181
+ // -----
182
+
164
183
// In this test, the mapping of non-unit dimensions (1 and 2) is as follows:
165
184
// 1 -> 2
166
185
// 2 -> 1
@@ -180,36 +199,10 @@ func.func @negative_shape_cast_of_transpose(%arg : vector<1x4x4x1xi8>) -> vector
180
199
181
200
// -----
182
201
183
- // Currently the conversion shape_cast(transpose) -> shape_cast is disabled for
184
- // scalable vectors because of bad interaction with ConvertIllegalShapeCastOpsToTransposes
185
- // CHECK-LABEL: @negative_shape_cast_of_transpose_scalable
186
- // CHECK: vector.transpose
187
- // CHECK: vector.shape_cast
188
- func.func @negative_shape_cast_of_transpose_scalable (%arg : vector <[4 ]x1 xi8 >) -> vector <[4 ]xi8 > {
189
- %0 = vector.transpose %arg , [1 , 0 ] : vector <[4 ]x1 xi8 > to vector <1 x[4 ]xi8 >
190
- %1 = vector.shape_cast %0 : vector <1 x[4 ]xi8 > to vector <[4 ]xi8 >
191
- return %1 : vector <[4 ]xi8 >
192
- }
193
-
194
- // -----
195
-
196
202
/// +--------------------------------------------------------------------------
197
203
/// Tests of FoldTransposeShapeCast: transpose(shape_cast) -> shape_cast
198
204
/// +--------------------------------------------------------------------------
199
205
200
- // The conversion transpose(shape_cast) -> shape_cast is not disabled for scalable
201
- // vectors.
202
- // CHECK-LABEL: @transpose_of_shape_cast_scalable
203
- // CHECK: vector.shape_cast
204
- // CHECK-SAME: vector<[4]xi8> to vector<[4]x1xi8>
205
- func.func @transpose_of_shape_cast_scalable (%arg : vector <[4 ]xi8 >) -> vector <[4 ]x1 xi8 > {
206
- %0 = vector.shape_cast %arg : vector <[4 ]xi8 > to vector <1 x[4 ]xi8 >
207
- %1 = vector.transpose %0 , [1 , 0 ] : vector <1 x[4 ]xi8 > to vector <[4 ]x1 xi8 >
208
- return %1 : vector <[4 ]x1 xi8 >
209
- }
210
-
211
- // -----
212
-
213
206
// A transpose that is 'order preserving' can be treated like a shape_cast.
214
207
// CHECK-LABEL: @transpose_of_shape_cast
215
208
// CHECK-SAME: %[[ARG:.*]]: vector<2x3x1x1xi8>) -> vector<6x1x1xi8> {
@@ -225,11 +218,26 @@ func.func @transpose_of_shape_cast(%arg : vector<2x3x1x1xi8>) -> vector<6x1x1xi
225
218
226
219
// -----
227
220
228
- // Scalable dimensions should be treated as non-unit dimensions.
229
221
// CHECK-LABEL: @transpose_of_shape_cast_scalable
222
+ // CHECK-SAME: %[[ARG:.*]]: vector<[2]x3x1x1xi8>) -> vector<[6]x1x1xi8> {
223
+ // CHECK: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[ARG]] :
224
+ // CHECK-SAME: vector<[2]x3x1x1xi8> to vector<[6]x1x1xi8>
225
+ // CHECK: return %[[SHAPE_CAST]] : vector<[6]x1x1xi8>
226
+ func.func @transpose_of_shape_cast_scalable (%arg : vector <[2 ]x3 x1 x1 xi8 >) -> vector <[6 ]x1 x1 xi8 > {
227
+ %0 = vector.shape_cast %arg : vector <[2 ]x3 x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
228
+ %1 = vector.transpose %0 , [0 , 2 , 1 ]
229
+ : vector <[6 ]x1 x1 xi8 > to vector <[6 ]x1 x1 xi8 >
230
+ return %1 : vector <[6 ]x1 x1 xi8 >
231
+ }
232
+
233
+ // -----
234
+
235
+ // Scalable 1 dimensions (i.e. [1]) should be treated as non-unit dimensions
236
+ // (hence no folding).
237
+ // CHECK-LABEL: @negative_transpose_of_shape_cast_scalable_unit
230
238
// CHECK: vector.shape_cast
231
239
// CHECK: vector.transpose
232
- func.func @transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
240
+ func.func @negative_transpose_of_shape_cast_scalable_unit (%arg : vector <[1 ]x4 x1 xi8 >) -> vector <4 x[1 ]xi8 > {
233
241
%0 = vector.shape_cast %arg : vector <[1 ]x4 x1 xi8 > to vector <[1 ]x4 xi8 >
234
242
%1 = vector.transpose %0 , [1 , 0 ] : vector <[1 ]x4 xi8 > to vector <4 x[1 ]xi8 >
235
243
return %1 : vector <4 x[1 ]xi8 >
0 commit comments