Skip to content

Commit 2f925d7

Browse files
authored
[mlir][Vector] Move insert/extractelement distribution patterns to insert/extract (llvm#116425)
This is an NFC-ish change that moves the vector.extractelement/vector.insertelement vector distribution patterns to vector.insert/vector.extract.

Before:
- 0-d/1-d vector.extract -> vector.extractelement -> distributed vector.extractelement
- 2-d+ vector.extract -> distributed vector.extract

After:
- scalar input vector.extract -> distributed vector.extract
- vector.extractelement -> distributed vector.extract
- 2d+ vector.extract -> distributed vector.extract

The same changes are made for insertelement/insert. The change allows us to remove the reliance on vector.extractelement/vector.insertelement, which are soon to be deprecated: https://discourse.llvm.org/t/rfc-psa-remove-vector-extractelement-and-vector-insertelement-ops-in-favor-of-vector-extract-and-vector-insert-ops/71116/8

No extra tests are included because this patch doesn't introduce or remove any functionality; it only changes the chain of lowerings. This change could be made completely NFC if the distributed operation were vector.extractelement/vector.insertelement, but that would be slightly weird, because you would be going from extractelement -> extract -> extractelement.
1 parent a6385a3 commit 2f925d7

File tree

2 files changed

+145
-121
lines changed

2 files changed

+145
-121
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 133 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,28 +1229,9 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
12291229
VectorType extractSrcType = extractOp.getSourceVectorType();
12301230
Location loc = extractOp.getLoc();
12311231

1232-
// "vector.extract %v[] : vector<f32> from vector<f32>" is an invalid op.
1233-
assert(extractSrcType.getRank() > 0 &&
1234-
"vector.extract does not support rank 0 sources");
1235-
1236-
// "vector.extract %v[] : vector<...xf32> from vector<...xf32>" can be
1237-
// canonicalized to %v.
1238-
if (extractOp.getNumIndices() == 0)
1232+
// For 1-d or 0-d source cases, we rely on WarpOpExtractScalar pattern.
1233+
if (extractSrcType.getRank() <= 1) {
12391234
return failure();
1240-
1241-
// Rewrite vector.extract with 1d source to vector.extractelement.
1242-
if (extractSrcType.getRank() == 1) {
1243-
if (extractOp.hasDynamicPosition())
1244-
// TODO: Dinamic position not supported yet.
1245-
return failure();
1246-
1247-
assert(extractOp.getNumIndices() == 1 && "expected 1 index");
1248-
int64_t pos = extractOp.getStaticPosition()[0];
1249-
rewriter.setInsertionPoint(extractOp);
1250-
rewriter.replaceOpWithNewOp<vector::ExtractElementOp>(
1251-
extractOp, extractOp.getVector(),
1252-
rewriter.create<arith::ConstantIndexOp>(loc, pos));
1253-
return success();
12541235
}
12551236

12561237
// All following cases are 2d or higher dimensional source vectors.
@@ -1313,22 +1294,27 @@ struct WarpOpExtract : public OpRewritePattern<WarpExecuteOnLane0Op> {
13131294
}
13141295
};
13151296

1316-
/// Pattern to move out vector.extractelement of 0-D tensors. Those don't
1317-
/// need to be distributed and can just be propagated outside of the region.
1318-
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1319-
WarpOpExtractElement(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1320-
PatternBenefit b = 1)
1297+
/// Pattern to move out vector.extract with a scalar result.
1298+
/// Only supports 1-D and 0-D sources for now.
1299+
struct WarpOpExtractScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
1300+
WarpOpExtractScalar(MLIRContext *ctx, WarpShuffleFromIdxFn fn,
1301+
PatternBenefit b = 1)
13211302
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b),
13221303
warpShuffleFromIdxFn(std::move(fn)) {}
13231304
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
13241305
PatternRewriter &rewriter) const override {
13251306
OpOperand *operand =
1326-
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1307+
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractOp>);
13271308
if (!operand)
13281309
return failure();
13291310
unsigned int operandNumber = operand->getOperandNumber();
1330-
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1311+
auto extractOp = operand->get().getDefiningOp<vector::ExtractOp>();
13311312
VectorType extractSrcType = extractOp.getSourceVectorType();
1313+
// Only supports 1-D or 0-D sources for now.
1314+
if (extractSrcType.getRank() > 1) {
1315+
return rewriter.notifyMatchFailure(
1316+
extractOp, "only 0-D or 1-D source supported for now");
1317+
}
13321318
// TODO: Supported shuffle types should be parameterizable, similar to
13331319
// `WarpShuffleFromIdxFn`.
13341320
if (!extractSrcType.getElementType().isF32() &&
@@ -1340,7 +1326,7 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13401326
VectorType distributedVecType;
13411327
if (!is0dOrVec1Extract) {
13421328
assert(extractSrcType.getRank() == 1 &&
1343-
"expected that extractelement src rank is 0 or 1");
1329+
"expected that extract src rank is 0 or 1");
13441330
if (extractSrcType.getShape()[0] % warpOp.getWarpSize() != 0)
13451331
return failure();
13461332
int64_t elementsPerLane =
@@ -1352,10 +1338,11 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13521338
// Yield source vector and position (if present) from warp op.
13531339
SmallVector<Value> additionalResults{extractOp.getVector()};
13541340
SmallVector<Type> additionalResultTypes{distributedVecType};
1355-
if (static_cast<bool>(extractOp.getPosition())) {
1356-
additionalResults.push_back(extractOp.getPosition());
1357-
additionalResultTypes.push_back(extractOp.getPosition().getType());
1358-
}
1341+
additionalResults.append(
1342+
SmallVector<Value>(extractOp.getDynamicPosition()));
1343+
additionalResultTypes.append(
1344+
SmallVector<Type>(extractOp.getDynamicPosition().getTypes()));
1345+
13591346
Location loc = extractOp.getLoc();
13601347
SmallVector<size_t> newRetIndices;
13611348
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1368,39 +1355,33 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
13681355
// All lanes extract the scalar.
13691356
if (is0dOrVec1Extract) {
13701357
Value newExtract;
1371-
if (extractSrcType.getRank() == 1) {
1372-
newExtract = rewriter.create<vector::ExtractElementOp>(
1373-
loc, distributedVec,
1374-
rewriter.create<arith::ConstantIndexOp>(loc, 0));
1375-
1376-
} else {
1377-
newExtract =
1378-
rewriter.create<vector::ExtractElementOp>(loc, distributedVec);
1379-
}
1358+
SmallVector<int64_t> indices(extractSrcType.getRank(), 0);
1359+
newExtract =
1360+
rewriter.create<vector::ExtractOp>(loc, distributedVec, indices);
13801361
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
13811362
newExtract);
13821363
return success();
13831364
}
13841365

1366+
int64_t staticPos = extractOp.getStaticPosition()[0];
1367+
OpFoldResult pos = ShapedType::isDynamic(staticPos)
1368+
? (newWarpOp->getResult(newRetIndices[1]))
1369+
: OpFoldResult(rewriter.getIndexAttr(staticPos));
13851370
// 1d extract: Distribute the source vector. One lane extracts and shuffles
13861371
// the value to all other lanes.
13871372
int64_t elementsPerLane = distributedVecType.getShape()[0];
13881373
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
13891374
// tid of extracting thread: pos / elementsPerLane
1390-
Value broadcastFromTid = rewriter.create<affine::AffineApplyOp>(
1391-
loc, sym0.ceilDiv(elementsPerLane),
1392-
newWarpOp->getResult(newRetIndices[1]));
1375+
Value broadcastFromTid = affine::makeComposedAffineApply(
1376+
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
13931377
// Extract at position: pos % elementsPerLane
1394-
Value pos =
1378+
Value newPos =
13951379
elementsPerLane == 1
13961380
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1397-
: rewriter
1398-
.create<affine::AffineApplyOp>(
1399-
loc, sym0 % elementsPerLane,
1400-
newWarpOp->getResult(newRetIndices[1]))
1401-
.getResult();
1381+
: affine::makeComposedAffineApply(rewriter, loc,
1382+
sym0 % elementsPerLane, pos);
14021383
Value extracted =
1403-
rewriter.create<vector::ExtractElementOp>(loc, distributedVec, pos);
1384+
rewriter.create<vector::ExtractOp>(loc, distributedVec, newPos);
14041385

14051386
// Shuffle the extracted value to all lanes.
14061387
Value shuffled = warpShuffleFromIdxFn(
@@ -1413,31 +1394,59 @@ struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14131394
WarpShuffleFromIdxFn warpShuffleFromIdxFn;
14141395
};
14151396

1416-
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1397+
/// Pattern to convert vector.extractelement to vector.extract.
///
/// This pattern performs no distribution itself: it only rewrites a
/// vector.extractelement found through the warp op into an equivalent
/// vector.extract, in place, and leaves the newly created op for the
/// vector.extract patterns to handle.
1398+
struct WarpOpExtractElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1399+
WarpOpExtractElement(MLIRContext *ctx, PatternBenefit b = 1)
1400+
: OpRewritePattern<WarpExecuteOnLane0Op>(ctx, b) {}
1401+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1402+
PatternRewriter &rewriter) const override {
1403+
// Ask getWarpResult for an operand whose value is defined by a
// vector.extractelement; there is nothing to rewrite if none is found.
OpOperand *operand =
1404+
getWarpResult(warpOp, llvm::IsaPred<vector::ExtractElementOp>);
1405+
if (!operand)
1406+
return failure();
1407+
auto extractOp = operand->get().getDefiningOp<vector::ExtractElementOp>();
1408+
SmallVector<OpFoldResult> indices;
1409+
// The position operand is optional: when it is absent the index list stays
// empty (presumably the 0-D source case -- confirm against the op
// definition).
if (auto pos = extractOp.getPosition()) {
1410+
indices.push_back(pos);
1411+
}
1412+
// Build the replacement at the location of the original op.
rewriter.setInsertionPoint(extractOp);
1413+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
1414+
extractOp, extractOp.getVector(), indices);
1415+
return success();
1416+
}
1417+
};
1418+
1419+
/// Pattern to move out vector.insert with a scalar input.
1420+
/// Only supports 1-D and 0-D destinations for now.
1421+
struct WarpOpInsertScalar : public OpRewritePattern<WarpExecuteOnLane0Op> {
14171422
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
14181423

14191424
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
14201425
PatternRewriter &rewriter) const override {
1421-
OpOperand *operand =
1422-
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1426+
OpOperand *operand = getWarpResult(warpOp, llvm::IsaPred<vector::InsertOp>);
14231427
if (!operand)
14241428
return failure();
14251429
unsigned int operandNumber = operand->getOperandNumber();
1426-
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1430+
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
14271431
VectorType vecType = insertOp.getDestVectorType();
14281432
VectorType distrType =
14291433
cast<VectorType>(warpOp.getResult(operandNumber).getType());
1430-
bool hasPos = static_cast<bool>(insertOp.getPosition());
1434+
1435+
// Only supports 1-D or 0-D destinations for now.
1436+
if (vecType.getRank() > 1) {
1437+
return rewriter.notifyMatchFailure(
1438+
insertOp, "only 0-D or 1-D source supported for now");
1439+
}
14311440

14321441
// Yield destination vector, source scalar and position from warp op.
14331442
SmallVector<Value> additionalResults{insertOp.getDest(),
14341443
insertOp.getSource()};
14351444
SmallVector<Type> additionalResultTypes{distrType,
14361445
insertOp.getSource().getType()};
1437-
if (hasPos) {
1438-
additionalResults.push_back(insertOp.getPosition());
1439-
additionalResultTypes.push_back(insertOp.getPosition().getType());
1440-
}
1446+
additionalResults.append(SmallVector<Value>(insertOp.getDynamicPosition()));
1447+
additionalResultTypes.append(
1448+
SmallVector<Type>(insertOp.getDynamicPosition().getTypes()));
1449+
14411450
Location loc = insertOp.getLoc();
14421451
SmallVector<size_t> newRetIndices;
14431452
WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndAppendReturns(
@@ -1446,13 +1455,26 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14461455
rewriter.setInsertionPointAfter(newWarpOp);
14471456
Value distributedVec = newWarpOp->getResult(newRetIndices[0]);
14481457
Value newSource = newWarpOp->getResult(newRetIndices[1]);
1449-
Value newPos = hasPos ? newWarpOp->getResult(newRetIndices[2]) : Value();
14501458
rewriter.setInsertionPointAfter(newWarpOp);
14511459

1460+
OpFoldResult pos;
1461+
if (vecType.getRank() != 0) {
1462+
int64_t staticPos = insertOp.getStaticPosition()[0];
1463+
pos = ShapedType::isDynamic(staticPos)
1464+
? (newWarpOp->getResult(newRetIndices[2]))
1465+
: OpFoldResult(rewriter.getIndexAttr(staticPos));
1466+
}
1467+
1468+
// This condition is always true for 0-d vectors.
14521469
if (vecType == distrType) {
1453-
// Broadcast: Simply move the vector.inserelement op out.
1454-
Value newInsert = rewriter.create<vector::InsertElementOp>(
1455-
loc, newSource, distributedVec, newPos);
1470+
Value newInsert;
1471+
SmallVector<OpFoldResult> indices;
1472+
if (pos) {
1473+
indices.push_back(pos);
1474+
}
1475+
newInsert = rewriter.create<vector::InsertOp>(loc, newSource,
1476+
distributedVec, indices);
1477+
// Broadcast: Simply move the vector.insert op out.
14561478
rewriter.replaceAllUsesWith(newWarpOp->getResult(operandNumber),
14571479
newInsert);
14581480
return success();
@@ -1462,16 +1484,11 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14621484
int64_t elementsPerLane = distrType.getShape()[0];
14631485
AffineExpr sym0 = getAffineSymbolExpr(0, rewriter.getContext());
14641486
// tid of extracting thread: pos / elementsPerLane
1465-
Value insertingLane = rewriter.create<affine::AffineApplyOp>(
1466-
loc, sym0.ceilDiv(elementsPerLane), newPos);
1487+
Value insertingLane = affine::makeComposedAffineApply(
1488+
rewriter, loc, sym0.ceilDiv(elementsPerLane), pos);
14671489
// Insert position: pos % elementsPerLane
1468-
Value pos =
1469-
elementsPerLane == 1
1470-
? rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult()
1471-
: rewriter
1472-
.create<affine::AffineApplyOp>(loc, sym0 % elementsPerLane,
1473-
newPos)
1474-
.getResult();
1490+
OpFoldResult newPos = affine::makeComposedFoldedAffineApply(
1491+
rewriter, loc, sym0 % elementsPerLane, pos);
14751492
Value isInsertingLane = rewriter.create<arith::CmpIOp>(
14761493
loc, arith::CmpIPredicate::eq, newWarpOp.getLaneid(), insertingLane);
14771494
Value newResult =
@@ -1480,8 +1497,8 @@ struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
14801497
loc, isInsertingLane,
14811498
/*thenBuilder=*/
14821499
[&](OpBuilder &builder, Location loc) {
1483-
Value newInsert = builder.create<vector::InsertElementOp>(
1484-
loc, newSource, distributedVec, pos);
1500+
Value newInsert = builder.create<vector::InsertOp>(
1501+
loc, newSource, distributedVec, newPos);
14851502
builder.create<scf::YieldOp>(loc, newInsert);
14861503
},
14871504
/*elseBuilder=*/
@@ -1506,25 +1523,13 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
15061523
auto insertOp = operand->get().getDefiningOp<vector::InsertOp>();
15071524
Location loc = insertOp.getLoc();
15081525

1509-
// "vector.insert %v, %v[] : ..." can be canonicalized to %v.
1510-
if (insertOp.getNumIndices() == 0)
1526+
// For 1-d or 0-d destination cases, we rely on WarpOpInsertScalar pattern.
1527+
if (insertOp.getDestVectorType().getRank() <= 1) {
15111528
return failure();
1512-
1513-
// Rewrite vector.insert with 1d dest to vector.insertelement.
1514-
if (insertOp.getDestVectorType().getRank() == 1) {
1515-
if (insertOp.hasDynamicPosition())
1516-
// TODO: Dinamic position not supported yet.
1517-
return failure();
1518-
1519-
assert(insertOp.getNumIndices() == 1 && "expected 1 index");
1520-
int64_t pos = insertOp.getStaticPosition()[0];
1521-
rewriter.setInsertionPoint(insertOp);
1522-
rewriter.replaceOpWithNewOp<vector::InsertElementOp>(
1523-
insertOp, insertOp.getSource(), insertOp.getDest(),
1524-
rewriter.create<arith::ConstantIndexOp>(loc, pos));
1525-
return success();
15261529
}
15271530

1531+
// All following cases are 2d or higher dimensional source vectors.
1532+
15281533
if (warpOp.getResult(operandNumber).getType() == operand->get().getType()) {
15291534
// There is no distribution, this is a broadcast. Simply move the insert
15301535
// out of the warp op.
@@ -1620,9 +1625,30 @@ struct WarpOpInsert : public OpRewritePattern<WarpExecuteOnLane0Op> {
16201625
}
16211626
};
16221627

1628+
/// Pattern to convert vector.insertelement to vector.insert.
///
/// This pattern performs no distribution itself: it only rewrites a
/// vector.insertelement found through the warp op into an equivalent
/// vector.insert, in place, and leaves the newly created op for the
/// vector.insert patterns to handle.
struct WarpOpInsertElement : public OpRewritePattern<WarpExecuteOnLane0Op> {
1629+
using OpRewritePattern<WarpExecuteOnLane0Op>::OpRewritePattern;
1630+
1631+
LogicalResult matchAndRewrite(WarpExecuteOnLane0Op warpOp,
1632+
PatternRewriter &rewriter) const override {
1633+
// Ask getWarpResult for an operand whose value is defined by a
// vector.insertelement; there is nothing to rewrite if none is found.
OpOperand *operand =
1634+
getWarpResult(warpOp, llvm::IsaPred<vector::InsertElementOp>);
1635+
if (!operand)
1636+
return failure();
1637+
auto insertOp = operand->get().getDefiningOp<vector::InsertElementOp>();
1638+
SmallVector<OpFoldResult> indices;
1639+
// The position operand is optional: when it is absent the index list stays
// empty (presumably the 0-D destination case -- confirm against the op
// definition).
if (auto pos = insertOp.getPosition()) {
1640+
indices.push_back(pos);
1641+
}
1642+
// Build the replacement at the location of the original op.
rewriter.setInsertionPoint(insertOp);
1643+
rewriter.replaceOpWithNewOp<vector::InsertOp>(
1644+
insertOp, insertOp.getSource(), insertOp.getDest(), indices);
1645+
return success();
1646+
}
1647+
};
1648+
16231649
/// Sink scf.for region out of WarpExecuteOnLane0Op. This can be done only if
1624-
/// the scf.ForOp is the last operation in the region so that it doesn't change
1625-
/// the order of execution. This creates a new scf.for region after the
1650+
/// the scf.ForOp is the last operation in the region so that it doesn't
1651+
/// change the order of execution. This creates a new scf.for region after the
16261652
/// WarpExecuteOnLane0Op. The new scf.for region will contain a new
16271653
/// WarpExecuteOnLane0Op region. Example:
16281654
/// ```
@@ -1668,8 +1694,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
16681694
if (!forOp)
16691695
return failure();
16701696
// Collect Values that come from the warp op but are outside the forOp.
1671-
// Those Value needs to be returned by the original warpOp and passed to the
1672-
// new op.
1697+
// Those Value needs to be returned by the original warpOp and passed to
1698+
// the new op.
16731699
llvm::SmallSetVector<Value, 32> escapingValues;
16741700
SmallVector<Type> inputTypes;
16751701
SmallVector<Type> distTypes;
@@ -1715,8 +1741,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17151741
OpBuilder::InsertionGuard g(rewriter);
17161742
rewriter.setInsertionPointAfter(newWarpOp);
17171743

1718-
// Create a new for op outside the region with a WarpExecuteOnLane0Op region
1719-
// inside.
1744+
// Create a new for op outside the region with a WarpExecuteOnLane0Op
1745+
// region inside.
17201746
auto newForOp = rewriter.create<scf::ForOp>(
17211747
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
17221748
forOp.getStep(), newOperands);
@@ -1778,8 +1804,8 @@ struct WarpOpScfForOp : public OpRewritePattern<WarpExecuteOnLane0Op> {
17781804
};
17791805

17801806
/// A pattern that extracts vector.reduction ops from a WarpExecuteOnLane0Op.
1781-
/// The vector is reduced in parallel. Currently limited to vector size matching
1782-
/// the warpOp size. E.g.:
1807+
/// The vector is reduced in parallel. Currently limited to vector size
1808+
/// matching the warpOp size. E.g.:
17831809
/// ```
17841810
/// %r = vector_ext.warp_execute_on_lane_0(%laneid)[32] -> (f32) {
17851811
/// %0 = "some_def"() : () -> (vector<32xf32>)
@@ -1880,13 +1906,13 @@ void mlir::vector::populatePropagateWarpVectorDistributionPatterns(
18801906
const WarpShuffleFromIdxFn &warpShuffleFromIdxFn, PatternBenefit benefit,
18811907
PatternBenefit readBenefit) {
18821908
patterns.add<WarpOpTransferRead>(patterns.getContext(), readBenefit);
1883-
patterns
1884-
.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1885-
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand, WarpOpConstant,
1886-
WarpOpInsertElement, WarpOpInsert, WarpOpCreateMask>(
1887-
patterns.getContext(), benefit);
1888-
patterns.add<WarpOpExtractElement>(patterns.getContext(),
1889-
warpShuffleFromIdxFn, benefit);
1909+
patterns.add<WarpOpElementwise, WarpOpDeadResult, WarpOpBroadcast,
1910+
WarpOpShapeCast, WarpOpExtract, WarpOpForwardOperand,
1911+
WarpOpConstant, WarpOpExtractElement, WarpOpInsertElement,
1912+
WarpOpInsertScalar, WarpOpInsert, WarpOpCreateMask>(
1913+
patterns.getContext(), benefit);
1914+
patterns.add<WarpOpExtractScalar>(patterns.getContext(), warpShuffleFromIdxFn,
1915+
benefit);
18901916
patterns.add<WarpOpScfForOp>(patterns.getContext(), distributionMapFn,
18911917
benefit);
18921918
}

0 commit comments

Comments
 (0)