@@ -1369,3 +1369,128 @@ func.func @extract_strided_metadata_of_get_global_with_offset()
1369
1369
return %base , %offset , %sizes#0 , %sizes#1 , %strides#0 , %strides#1 :
1370
1370
memref <i32 >, index , index , index , index , index
1371
1371
}
1372
+
1373
+ // -----
1374
+
1375
+ // Check that we simplify extract_strided_metadata of cast
1376
+ // when the source of the cast is compatible with what
1377
+ // `extract_strided_metadata`s accept.
1378
+ //
1379
+ // When we apply the transformation the resulting offset, sizes and strides
1380
+ // should come straight from the inputs of the cast.
1381
+ // Additionally the folder on extract_strided_metadata should propagate the
1382
+ // static information.
1383
+ //
1384
+ // CHECK-LABEL: func @extract_strided_metadata_of_cast
1385
+ // CHECK-SAME: %[[ARG:.*]]: memref<3x?xi32, strided<[4, ?], offset: ?>>)
1386
+ //
1387
+ // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index
1388
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1389
+ // CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1390
+ //
1391
+ // CHECK: return %[[BASE]], %[[DYN_OFFSET]], %[[C3]], %[[DYN_SIZES]]#1, %[[C4]], %[[DYN_STRIDES]]#1
1392
+ func.func @extract_strided_metadata_of_cast (
1393
+ %arg : memref <3 x?xi32 , strided <[4 , ?], offset :?>>)
1394
+ -> (memref <i32 >, index ,
1395
+ index , index ,
1396
+ index , index ) {
1397
+
1398
+ %cast =
1399
+ memref.cast %arg :
1400
+ memref <3 x?xi32 , strided <[4 , ?], offset : ?>> to
1401
+ memref <?x?xi32 , strided <[?, ?], offset : ?>>
1402
+
1403
+ %base , %base_offset , %sizes:2 , %strides:2 =
1404
+ memref.extract_strided_metadata %cast:memref <?x ?xi32 , strided <[?, ?], offset : ?>>
1405
+ -> memref <i32 >, index ,
1406
+ index , index ,
1407
+ index , index
1408
+
1409
+ return %base , %base_offset ,
1410
+ %sizes#0 , %sizes#1 ,
1411
+ %strides#0 , %strides#1 :
1412
+ memref <i32 >, index ,
1413
+ index , index ,
1414
+ index , index
1415
+ }
1416
+
1417
+ // -----
1418
+
1419
+ // Check that we simplify extract_strided_metadata of cast
1420
+ // when the source of the cast is compatible with what
1421
+ // `extract_strided_metadata`s accept.
1422
+ //
1423
+ // Same as extract_strided_metadata_of_cast but with constant sizes and strides
1424
+ // in the destination type.
1425
+ //
1426
+ // CHECK-LABEL: func @extract_strided_metadata_of_cast_w_csts
1427
+ // CHECK-SAME: %[[ARG:.*]]: memref<?x?xi32, strided<[?, ?], offset: ?>>)
1428
+ //
1429
+ // CHECK-DAG: %[[C4:.*]] = arith.constant 4 : index
1430
+ // CHECK-DAG: %[[C18:.*]] = arith.constant 18 : index
1431
+ // CHECK-DAG: %[[C25:.*]] = arith.constant 25 : index
1432
+ // CHECK: %[[BASE:.*]], %[[DYN_OFFSET:.*]], %[[DYN_SIZES:.*]]:2, %[[DYN_STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG]]
1433
+ //
1434
+ // CHECK: return %[[BASE]], %[[C25]], %[[C4]], %[[DYN_SIZES]]#1, %[[DYN_STRIDES]]#0, %[[C18]]
1435
+ func.func @extract_strided_metadata_of_cast_w_csts (
1436
+ %arg : memref <?x?xi32 , strided <[?, ?], offset :?>>)
1437
+ -> (memref <i32 >, index ,
1438
+ index , index ,
1439
+ index , index ) {
1440
+
1441
+ %cast =
1442
+ memref.cast %arg :
1443
+ memref <?x?xi32 , strided <[?, ?], offset : ?>> to
1444
+ memref <4 x?xi32 , strided <[?, 18 ], offset : 25 >>
1445
+
1446
+ %base , %base_offset , %sizes:2 , %strides:2 =
1447
+ memref.extract_strided_metadata %cast:memref <4 x?xi32 , strided <[?, 18 ], offset : 25 >>
1448
+ -> memref <i32 >, index ,
1449
+ index , index ,
1450
+ index , index
1451
+
1452
+ return %base , %base_offset ,
1453
+ %sizes#0 , %sizes#1 ,
1454
+ %strides#0 , %strides#1 :
1455
+ memref <i32 >, index ,
1456
+ index , index ,
1457
+ index , index
1458
+ }
1459
+ // -----
1460
+
1461
+ // Check that we don't simplify extract_strided_metadata of
1462
+ // cast when the source of the cast is unranked.
1463
+ // Unranked memrefs cannot feed into extract_strided_metadata operations.
1464
+ // Note: Technically we could still fold the sizes and strides.
1465
+ //
1466
+ // CHECK-LABEL: func @extract_strided_metadata_of_cast_unranked
1467
+ // CHECK-SAME: %[[ARG:.*]]: memref<*xi32>)
1468
+ //
1469
+ // CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] :
1470
+ // CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[CAST]]
1471
+ //
1472
+ // CHECK: return %[[BASE]], %[[OFFSET]], %[[SIZES]]#0, %[[SIZES]]#1, %[[STRIDES]]#0, %[[STRIDES]]#1
1473
+ func.func @extract_strided_metadata_of_cast_unranked (
1474
+ %arg : memref <*xi32 >)
1475
+ -> (memref <i32 >, index ,
1476
+ index , index ,
1477
+ index , index ) {
1478
+
1479
+ %cast =
1480
+ memref.cast %arg :
1481
+ memref <*xi32 > to
1482
+ memref <?x?xi32 , strided <[?, ?], offset : ?>>
1483
+
1484
+ %base , %base_offset , %sizes:2 , %strides:2 =
1485
+ memref.extract_strided_metadata %cast:memref <?x ?xi32 , strided <[?, ?], offset : ?>>
1486
+ -> memref <i32 >, index ,
1487
+ index , index ,
1488
+ index , index
1489
+
1490
+ return %base , %base_offset ,
1491
+ %sizes#0 , %sizes#1 ,
1492
+ %strides#0 , %strides#1 :
1493
+ memref <i32 >, index ,
1494
+ index , index ,
1495
+ index , index
1496
+ }
0 commit comments