@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
1229
1229
VectorType extractSrcType = extractOp.getSourceVectorType ();
1230
1230
Location loc = extractOp.getLoc ();
1231
1231
1232
- // "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233
- assert (extractSrcType.getRank () > 0 &&
1234
- " vector.extract does not support rank 0 sources" );
1235
-
1236
- // "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237
- // canonicalized to %v.
1238
- if (extractOp.getNumIndices () == 0 )
1232
+ // For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1233
+ if (extractSrcType.getRank () <= 1 ) {
1239
1234
return failure ();
1240
-
1241
- // Rewrite vector.extract with 1d source to vector.extractelement.
1242
- if (extractSrcType.getRank () == 1 ) {
1243
- if (extractOp.hasDynamicPosition ())
1244
- // TODO: Dinamic position not supported yet.
1245
- return failure ();
1246
-
1247
- assert (extractOp.getNumIndices () == 1 && " expected 1 index" );
1248
- int64_t pos = extractOp.getStaticPosition ()[0 ];
1249
- rewriter.setInsertionPoint (extractOp);
1250
- rewriter.replaceOpWithNewOp <vector::ExtractElementOp>(
1251
- extractOp, extractOp.getVector (),
1252
- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1253
- return success ();
1254
1235
}
1255
1236
1256
1237
// All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
1313
1294
}
1314
1295
};
1315
1296
1316
- // / Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317
- // / need to be distributed and can just be propagated outside of the region .
1318
- struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1319
- WarpOpExtractElement (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320
- PatternBenefit b = 1 )
1297
+ // / Pattern to move out vector.extract with a scalar result.
1298
+ // / Only supports 1-D and 0-D sources for now .
1299
+ struct WarpOpExtractScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
1300
+ WarpOpExtractScalar (MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1301
+ PatternBenefit b = 1 )
1321
1302
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
1322
1303
warpShuffleFromIdxFn (std::move(fn)) {}
1323
1304
LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1324
1305
PatternRewriter &rewriter) const override {
1325
1306
OpOperand *operand =
1326
- getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp >);
1307
+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractOp >);
1327
1308
if (!operand)
1328
1309
return failure ();
1329
1310
unsigned int operandNumber = operand->getOperandNumber ();
1330
- auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp >();
1311
+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractOp >();
1331
1312
VectorType extractSrcType = extractOp.getSourceVectorType ();
1313
+ // Only supports 1-D or 0-D sources for now.
1314
+ if (extractSrcType.getRank () > 1 ) {
1315
+ return rewriter.notifyMatchFailure (
1316
+ extractOp, " only 0-D or 1-D source supported for now" );
1317
+ }
1332
1318
// TODO: Supported shuffle types should be parameterizable, similar to
1333
1319
// `WarpShuffleFromIdxFn`.
1334
1320
if (!extractSrcType.getElementType ().isF32 () &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1340
1326
VectorType distributedVecType;
1341
1327
if (!is0dOrVec1Extract) {
1342
1328
assert (extractSrcType.getRank () == 1 &&
1343
- " expected that extractelement src rank is 0 or 1" );
1329
+ " expected that extract src rank is 0 or 1" );
1344
1330
if (extractSrcType.getShape ()[0 ] % warpOp.getWarpSize () != 0 )
1345
1331
return failure ();
1346
1332
int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1352
1338
// Yield source vector and position (if present) from warp op.
1353
1339
SmallVector<Value> additionalResults{extractOp.getVector ()};
1354
1340
SmallVector<Type> additionalResultTypes{distributedVecType};
1355
- if (static_cast <bool >(extractOp.getPosition ())) {
1356
- additionalResults.push_back (extractOp.getPosition ());
1357
- additionalResultTypes.push_back (extractOp.getPosition ().getType ());
1358
- }
1341
+ additionalResults.append (
1342
+ SmallVector<Value>(extractOp.getDynamicPosition ()));
1343
+ additionalResultTypes.append (
1344
+ SmallVector<Type>(extractOp.getDynamicPosition ().getTypes ()));
1345
+
1359
1346
Location loc = extractOp.getLoc ();
1360
1347
SmallVector<size_t > newRetIndices;
1361
1348
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1368,39 +1355,33 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1368
1355
// All lanes extract the scalar.
1369
1356
if (is0dOrVec1Extract) {
1370
1357
Value newExtract;
1371
- if (extractSrcType.getRank () == 1 ) {
1372
- newExtract = rewriter.create <vector::ExtractElementOp>(
1373
- loc, distributedVec,
1374
- rewriter.create <arith::ConstantIndexOp>(loc, 0 ));
1375
-
1376
- } else {
1377
- newExtract =
1378
- rewriter.create <vector::ExtractElementOp>(loc, distributedVec);
1379
- }
1358
+ SmallVector<int64_t > indices (extractSrcType.getRank (), 0 );
1359
+ newExtract =
1360
+ rewriter.create <vector::ExtractOp>(loc, distributedVec, indices);
1380
1361
rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
1381
1362
newExtract);
1382
1363
return success ();
1383
1364
}
1384
1365
1366
+ int64_t staticPos = extractOp.getStaticPosition ()[0 ];
1367
+ OpFoldResult pos = ShapedType::isDynamic (staticPos)
1368
+ ? (newWarpOp->getResult (newRetIndices[1 ]))
1369
+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
1385
1370
// 1d extract: Distribute the source vector. One lane extracts and shuffles
1386
1371
// the value to all other lanes.
1387
1372
int64_t elementsPerLane = distributedVecType.getShape ()[0 ];
1388
1373
AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
1389
1374
// tid of extracting thread: pos / elementsPerLane
1390
- Value broadcastFromTid = rewriter.create <affine::AffineApplyOp>(
1391
- loc, sym0.ceilDiv (elementsPerLane),
1392
- newWarpOp->getResult (newRetIndices[1 ]));
1375
+ Value broadcastFromTid = affine::makeComposedAffineApply (
1376
+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos);
1393
1377
// Extract at position: pos % elementsPerLane
1394
- Value pos =
1378
+ Value newPos =
1395
1379
elementsPerLane == 1
1396
1380
? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1397
- : rewriter
1398
- .create <affine::AffineApplyOp>(
1399
- loc, sym0 % elementsPerLane,
1400
- newWarpOp->getResult (newRetIndices[1 ]))
1401
- .getResult ();
1381
+ : affine::makeComposedAffineApply (rewriter, loc,
1382
+ sym0 % elementsPerLane, pos);
1402
1383
Value extracted =
1403
- rewriter.create <vector::ExtractElementOp >(loc, distributedVec, pos );
1384
+ rewriter.create <vector::ExtractOp >(loc, distributedVec, newPos );
1404
1385
1405
1386
// Shuffle the extracted value to all lanes.
1406
1387
Value shuffled = warpShuffleFromIdxFn (
@@ -1413,31 +1394,59 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1413
1394
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
1414
1395
};
1415
1396
1416
- struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1397
+ // / Pattern to convert vector.extractelement to vector.extract. It matches a warp op operand (located via getWarpResult) whose defining op is a vector.extractelement and replaces that op in place with a vector.extract on the same vector; no distribution is performed here.
1398
+ struct WarpOpExtractElement : public OpRewritePattern <WarpExecuteOnLane0Op> {
1399
+ WarpOpExtractElement (MLIRContext *ctx, PatternBenefit b = 1 )
1400
+ : OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
1401
+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1402
+ PatternRewriter &rewriter) const override {
1403
+ OpOperand *operand =
1404
+ getWarpResult (warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1405
+ if (!operand) // No operand defined by a vector.extractelement: nothing to rewrite.
1406
+ return failure ();
1407
+ auto extractOp = operand->get ().getDefiningOp <vector::ExtractElementOp>();
1408
+ SmallVector<OpFoldResult> indices; // Stays empty when the op carries no position.
1409
+ if (auto pos = extractOp.getPosition ()) { // The position may be null; forward it as the single index only when present.
1410
+ indices.push_back (pos);
1411
+ }
1412
+ rewriter.setInsertionPoint (extractOp); // Create the replacement at the position of the original op.
1413
+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
1414
+ extractOp, extractOp.getVector (), indices);
1415
+ return success ();
1416
+ }
1417
+ };
1418
+
1419
+ // / Pattern to move out vector.insert with a scalar input.
1420
+ // / Only supports 1-D and 0-D destinations for now.
1421
+ struct WarpOpInsertScalar : public OpRewritePattern <WarpExecuteOnLane0Op> {
1417
1422
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1418
1423
1419
1424
LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1420
1425
PatternRewriter &rewriter) const override {
1421
- OpOperand *operand =
1422
- getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1426
+ OpOperand *operand = getWarpResult (warpOp, llvm::IsaPred<vector::InsertOp>);
1423
1427
if (!operand)
1424
1428
return failure ();
1425
1429
unsigned int operandNumber = operand->getOperandNumber ();
1426
- auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp >();
1430
+ auto insertOp = operand->get ().getDefiningOp <vector::InsertOp >();
1427
1431
VectorType vecType = insertOp.getDestVectorType ();
1428
1432
VectorType distrType =
1429
1433
cast<VectorType>(warpOp.getResult (operandNumber).getType ());
1430
- bool hasPos = static_cast <bool >(insertOp.getPosition ());
1434
+
1435
+ // Only supports 1-D or 0-D destinations for now.
1436
+ if (vecType.getRank () > 1 ) {
1437
+ return rewriter.notifyMatchFailure (
1438
+ insertOp, " only 0-D or 1-D source supported for now" );
1439
+ }
1431
1440
1432
1441
// Yield destination vector, source scalar and position from warp op.
1433
1442
SmallVector<Value> additionalResults{insertOp.getDest (),
1434
1443
insertOp.getSource ()};
1435
1444
SmallVector<Type> additionalResultTypes{distrType,
1436
1445
insertOp.getSource ().getType ()};
1437
- if (hasPos) {
1438
- additionalResults. push_back (insertOp. getPosition ());
1439
- additionalResultTypes. push_back (insertOp.getPosition ().getType ( ));
1440
- }
1446
+ additionalResults. append (SmallVector<Value>(insertOp. getDynamicPosition ()));
1447
+ additionalResultTypes. append (
1448
+ SmallVector<Type> (insertOp.getDynamicPosition ().getTypes () ));
1449
+
1441
1450
Location loc = insertOp.getLoc ();
1442
1451
SmallVector<size_t > newRetIndices;
1443
1452
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns (
@@ -1446,13 +1455,26 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1446
1455
rewriter.setInsertionPointAfter (newWarpOp);
1447
1456
Value distributedVec = newWarpOp->getResult (newRetIndices[0 ]);
1448
1457
Value newSource = newWarpOp->getResult (newRetIndices[1 ]);
1449
- Value newPos = hasPos ? newWarpOp->getResult (newRetIndices[2 ]) : Value ();
1450
1458
rewriter.setInsertionPointAfter (newWarpOp);
1451
1459
1460
+ OpFoldResult pos;
1461
+ if (vecType.getRank () != 0 ) {
1462
+ int64_t staticPos = insertOp.getStaticPosition ()[0 ];
1463
+ pos = ShapedType::isDynamic (staticPos)
1464
+ ? (newWarpOp->getResult (newRetIndices[2 ]))
1465
+ : OpFoldResult (rewriter.getIndexAttr (staticPos));
1466
+ }
1467
+
1468
+ // This condition is always true for 0-d vectors.
1452
1469
if (vecType == distrType) {
1453
- // Broadcast: Simply move the vector.inserelement op out.
1454
- Value newInsert = rewriter.create <vector::InsertElementOp>(
1455
- loc, newSource, distributedVec, newPos);
1470
+ Value newInsert;
1471
+ SmallVector<OpFoldResult> indices;
1472
+ if (pos) {
1473
+ indices.push_back (pos);
1474
+ }
1475
+ newInsert = rewriter.create <vector::InsertOp>(loc, newSource,
1476
+ distributedVec, indices);
1477
+ // Broadcast: Simply move the vector.insert op out.
1456
1478
rewriter.replaceAllUsesWith (newWarpOp->getResult (operandNumber),
1457
1479
newInsert);
1458
1480
return success ();
@@ -1462,16 +1484,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1462
1484
int64_t elementsPerLane = distrType.getShape ()[0 ];
1463
1485
AffineExpr sym0 = getAffineSymbolExpr (0 , rewriter.getContext ());
1464
1486
// tid of inserting thread: pos / elementsPerLane
1465
- Value insertingLane = rewriter. create < affine::AffineApplyOp> (
1466
- loc, sym0.ceilDiv (elementsPerLane), newPos );
1487
+ Value insertingLane = affine::makeComposedAffineApply (
1488
+ rewriter, loc, sym0.ceilDiv (elementsPerLane), pos );
1467
1489
// Insert position: pos % elementsPerLane
1468
- Value pos =
1469
- elementsPerLane == 1
1470
- ? rewriter.create <arith::ConstantIndexOp>(loc, 0 ).getResult ()
1471
- : rewriter
1472
- .create <affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473
- newPos)
1474
- .getResult ();
1490
+ OpFoldResult newPos = affine::makeComposedFoldedAffineApply (
1491
+ rewriter, loc, sym0 % elementsPerLane, pos);
1475
1492
Value isInsertingLane = rewriter.create <arith::CmpIOp>(
1476
1493
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid (), insertingLane);
1477
1494
Value newResult =
@@ -1480,8 +1497,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1480
1497
loc, isInsertingLane,
1481
1498
/* thenBuilder=*/
1482
1499
[&](OpBuilder &builder, Location loc) {
1483
- Value newInsert = builder.create <vector::InsertElementOp >(
1484
- loc, newSource, distributedVec, pos );
1500
+ Value newInsert = builder.create <vector::InsertOp >(
1501
+ loc, newSource, distributedVec, newPos );
1485
1502
builder.create <scf::YieldOp>(loc, newInsert);
1486
1503
},
1487
1504
/* elseBuilder=*/
@@ -1506,25 +1523,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1506
1523
auto insertOp = operand->get ().getDefiningOp <vector::InsertOp>();
1507
1524
Location loc = insertOp.getLoc ();
1508
1525
1509
- // "vector.insert %v, %v[] : ..." can be canonicalized to %v .
1510
- if (insertOp.getNumIndices () == 0 )
1526
+ // For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern .
1527
+ if (insertOp.getDestVectorType (). getRank () <= 1 ) {
1511
1528
return failure ();
1512
-
1513
- // Rewrite vector.insert with 1d dest to vector.insertelement.
1514
- if (insertOp.getDestVectorType ().getRank () == 1 ) {
1515
- if (insertOp.hasDynamicPosition ())
1516
- // TODO: Dinamic position not supported yet.
1517
- return failure ();
1518
-
1519
- assert (insertOp.getNumIndices () == 1 && " expected 1 index" );
1520
- int64_t pos = insertOp.getStaticPosition ()[0 ];
1521
- rewriter.setInsertionPoint (insertOp);
1522
- rewriter.replaceOpWithNewOp <vector::InsertElementOp>(
1523
- insertOp, insertOp.getSource (), insertOp.getDest (),
1524
- rewriter.create <arith::ConstantIndexOp>(loc, pos));
1525
- return success ();
1526
1529
}
1527
1530
1531
+ // All following cases are 2d or higher dimensional destination vectors.
1532
+
1528
1533
if (warpOp.getResult (operandNumber).getType () == operand->get ().getType ()) {
1529
1534
// There is no distribution, this is a broadcast. Simply move the insert
1530
1535
// out of the warp op.
@@ -1620,9 +1625,30 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
1620
1625
}
1621
1626
};
1622
1627
1628
+ struct WarpOpInsertElement : public OpRewritePattern <WarpExecuteOnLane0Op> { // Pattern to convert vector.insertelement to vector.insert: the op is replaced in place with a vector.insert using the same source and destination; no distribution is performed here.
1629
+ using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1630
+
1631
+ LogicalResult matchAndRewrite (WarpExecuteOnLane0Op warpOp,
1632
+ PatternRewriter &rewriter) const override {
1633
+ OpOperand *operand =
1634
+ getWarpResult (warpOp, llvm::IsaPred<vector::InsertElementOp>);
1635
+ if (!operand) // No operand defined by a vector.insertelement: nothing to rewrite.
1636
+ return failure ();
1637
+ auto insertOp = operand->get ().getDefiningOp <vector::InsertElementOp>();
1638
+ SmallVector<OpFoldResult> indices; // Stays empty when the op carries no position.
1639
+ if (auto pos = insertOp.getPosition ()) { // The position may be null; forward it as the single index only when present.
1640
+ indices.push_back (pos);
1641
+ }
1642
+ rewriter.setInsertionPoint (insertOp); // Create the replacement at the position of the original op.
1643
+ rewriter.replaceOpWithNewOp <vector::InsertOp>(
1644
+ insertOp, insertOp.getSource (), insertOp.getDest (), indices);
1645
+ return success ();
1646
+ }
1647
+ };
1648
+
1623
1649
// / Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624
- // / the scf.ForOp is the last operation in the region so that it doesn't change
1625
- // / the order of execution. This creates a new scf.for region after the
1650
+ // / the scf.ForOp is the last operation in the region so that it doesn't
1651
+ // / change the order of execution. This creates a new scf.for region after the
1626
1652
// / WarpExecuteOnLane0Op. The new scf.for region will contain a new
1627
1653
// / WarpExecuteOnLane0Op region. Example:
1628
1654
// / ```
@@ -1668,8 +1694,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1668
1694
if (!forOp)
1669
1695
return failure ();
1670
1696
// Collect Values that come from the warp op but are outside the forOp.
1671
- // Those Value needs to be returned by the original warpOp and passed to the
1672
- // new op.
1697
+ // Those Values need to be returned by the original warpOp and passed to
1698
+ // the new op.
1673
1699
llvm::SmallSetVector<Value, 32 > escapingValues;
1674
1700
SmallVector<Type> inputTypes;
1675
1701
SmallVector<Type> distTypes;
@@ -1715,8 +1741,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1715
1741
OpBuilder::InsertionGuard g (rewriter);
1716
1742
rewriter.setInsertionPointAfter (newWarpOp);
1717
1743
1718
- // Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719
- // inside.
1744
+ // Create a new for op outside the region with a WarpExecuteOnLane0Op
1745
+ // region inside.
1720
1746
auto newForOp = rewriter.create <scf::ForOp>(
1721
1747
forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
1722
1748
forOp.getStep (), newOperands);
@@ -1778,8 +1804,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
1778
1804
};
1779
1805
1780
1806
// / A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781
- // / The vector is reduced in parallel. Currently limited to vector size matching
1782
- // / the warpOp size. E.g.:
1807
+ // / The vector is reduced in parallel. Currently limited to vector size
1808
+ // / matching the warpOp size. E.g.:
1783
1809
// / ```
1784
1810
// / %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
1785
1811
// / %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1906,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
1880
1906
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
1881
1907
PatternBenefit readBenefit) {
1882
1908
patterns.add <WarpOpTransferRead>(patterns.getContext (), readBenefit);
1883
- patterns
1884
- . add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast ,
1885
- WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant ,
1886
- WarpOpInsertElement , WarpOpInsert, WarpOpCreateMask>(
1887
- patterns.getContext (), benefit);
1888
- patterns.add <WarpOpExtractElement >(patterns.getContext (),
1889
- warpShuffleFromIdxFn, benefit);
1909
+ patterns. add <WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1910
+ WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand ,
1911
+ WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement ,
1912
+ WarpOpInsertScalar , WarpOpInsert, WarpOpCreateMask>(
1913
+ patterns.getContext (), benefit);
1914
+ patterns.add <WarpOpExtractScalar >(patterns.getContext (), warpShuffleFromIdxFn ,
1915
+ benefit);
1890
1916
patterns.add <WarpOpScfForOp>(patterns.getContext (), distributionMapFn,
1891
1917
benefit);
1892
1918
}
0 commit comments