@@ -242,6 +242,101 @@ func.func @dim_of_alloca_with_dynamic_size(%arg0: memref<*xf32>) -> index {
242
242
243
243
// -----
244
244
245
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>
//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   memref.store
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[DIM]] : index
func.func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // Update the shape to test that the load ends up in the right place.
  memref.store %c3, %arg1[%c3] : memref<?xindex>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}
264
+
265
+ // -----
266
+
267
// Test case: Folding of memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_i32(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xi32>
//  CHECK-NEXT:   %[[IDX:.*]] = arith.constant 3
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//  CHECK-NEXT:   %[[CAST:.*]] = arith.index_cast %[[DIM]]
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[CAST]] : index
func.func @dim_of_memref_reshape_i32(%arg0: memref<*xf32>, %arg1: memref<?xi32>)
    -> index {
  %c3 = arith.constant 3 : index
  %0 = memref.reshape %arg0(%arg1)
      : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
  %1 = memref.dim %0, %c3 : memref<*xf32>
  return %1 : index
}
284
+
285
+ // -----
286
+
287
// Test case: memref.dim(memref.reshape %v %shp, %idx) -> memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_block_arg_index(
//  CHECK-SAME:     %[[MEM:[0-9a-z]+]]: memref<*xf32>,
//  CHECK-SAME:     %[[SHP:[0-9a-z]+]]: memref<?xindex>,
//  CHECK-SAME:     %[[IDX:[0-9a-z]+]]: index
//  CHECK-NEXT:   %[[DIM:.*]] = memref.load %[[SHP]][%[[IDX]]]
//   CHECK-NOT:   memref.dim
//       CHECK:   return %[[DIM]] : index
func.func @dim_of_memref_reshape_block_arg_index(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  %dim = memref.dim %reshape, %arg2 : memref<*xf32>
  return %dim : index
}
300
+
301
+ // -----
302
+
303
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_for(
//       CHECK:   memref.reshape
//       CHECK:   memref.dim
//   CHECK-NOT:   memref.load
func.func @dim_of_memref_reshape_for(%arg0: memref<*xf32>, %arg1: memref<?xindex>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index

  %0 = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>

  %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %c1) -> (index) {
    %2 = memref.dim %0, %arg2 : memref<*xf32>
    %3 = arith.muli %arg3, %2 : index
    scf.yield %3 : index
  }
  return %1 : index
}
322
+
323
+ // -----
324
+
325
// Test case: memref.dim(memref.reshape %v %shp, %idx) is not folded into memref.load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape_undominated(
//       CHECK:   memref.reshape
//       CHECK:   memref.dim
//   CHECK-NOT:   memref.load
func.func @dim_of_memref_reshape_undominated(%arg0: memref<*xf32>, %arg1: memref<?xindex>, %arg2: index) -> index {
  %c4 = arith.constant 4 : index
  %reshape = memref.reshape %arg0(%arg1) : (memref<*xf32>, memref<?xindex>) -> memref<*xf32>
  // The index is computed after the reshape, so the load cannot be placed
  // where it would dominate the dim's users.
  %0 = arith.muli %arg2, %c4 : index
  %dim = memref.dim %reshape, %0 : memref<*xf32>
  return %dim : index
}
337
+
338
+ // -----
339
+
245
340
// CHECK-LABEL: func @alloc_const_fold
246
341
func.func @alloc_const_fold () -> memref <?xf32 > {
247
342
// CHECK-NEXT: memref.alloc() : memref<4xf32>
0 commit comments