Skip to content

Commit 06dbb28

Browse files
committed
[mlir][vector] Remove usage of shapecast to remove unit dim
Instead of using shape_cast op in the pattern removing leading unit dimensions we use extract/broadcast ops. This is part of the effort to restrict ShapeCastOp fuirther in the future and only allow them to convert to or from 1D vector. This also adds extra canonicalization to fill the gaps in simplifying broadcast/extract ops. Differential Revision: https://reviews.llvm.org/D114205
1 parent ffdace4 commit 06dbb28

File tree

5 files changed

+161
-192
lines changed

5 files changed

+161
-192
lines changed

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 76 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,9 +1125,6 @@ static Value foldExtractFromBroadcast(ExtractOp extractOp) {
11251125
b.getI64ArrayAttr(extractPos));
11261126
return extractOp.getResult();
11271127
}
1128-
// TODO: In case the rank of the broadcast source is greater than the rank of
1129-
// the extract result this can be combined into a new broadcast op. This needs
1130-
// to be added a canonicalization pattern if needed.
11311128
return Value();
11321129
}
11331130

@@ -1208,12 +1205,63 @@ OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
12081205

12091206
namespace {
12101207

1208+
// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
1209+
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1210+
public:
1211+
using OpRewritePattern<ExtractOp>::OpRewritePattern;
1212+
1213+
LogicalResult matchAndRewrite(ExtractOp extractOp,
1214+
PatternRewriter &rewriter) const override {
1215+
Operation *defOp = extractOp.vector().getDefiningOp();
1216+
if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1217+
return failure();
1218+
Value source = defOp->getOperand(0);
1219+
if (extractOp.getType() == source.getType())
1220+
return failure();
1221+
auto getRank = [](Type type) {
1222+
return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1223+
};
1224+
unsigned broadcasrSrcRank = getRank(source.getType());
1225+
unsigned extractResultRank = getRank(extractOp.getType());
1226+
// We only consider the case where the rank of the source is smaller than
1227+
// the rank of the extract dst. The other cases are handled in the folding
1228+
// patterns.
1229+
if (extractResultRank <= broadcasrSrcRank)
1230+
return failure();
1231+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1232+
extractOp, extractOp.getType(), source);
1233+
return success();
1234+
}
1235+
};
1236+
1237+
// Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp.
1238+
class ExtractOpConstantFolder final : public OpRewritePattern<ExtractOp> {
1239+
public:
1240+
using OpRewritePattern<ExtractOp>::OpRewritePattern;
1241+
1242+
LogicalResult matchAndRewrite(ExtractOp extractOp,
1243+
PatternRewriter &rewriter) const override {
1244+
// Return if 'extractStridedSliceOp' operand is not defined by a
1245+
// ConstantOp.
1246+
auto constantOp = extractOp.vector().getDefiningOp<arith::ConstantOp>();
1247+
if (!constantOp)
1248+
return failure();
1249+
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
1250+
if (!dense)
1251+
return failure();
1252+
Attribute newAttr = dense.getSplatValue<Attribute>();
1253+
if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1254+
newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1255+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1256+
return success();
1257+
}
1258+
};
1259+
12111260
} // namespace
12121261

12131262
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
12141263
MLIRContext *context) {
1215-
// ExtractToShapeCast is not a default canonicalization, it is opt-in by
1216-
// calling `populateCastAwayVectorLeadingOneDimPatterns`
1264+
results.add<ExtractOpConstantFolder, ExtractOpFromBroadcast>(context);
12171265
}
12181266

12191267
static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
@@ -1555,10 +1603,31 @@ static LogicalResult verify(InsertOp op) {
15551603
return success();
15561604
}
15571605

1606+
namespace {
1607+
1608+
// If insertOp is only inserting unit dimensions it can be transformed to a
1609+
// broadcast.
1610+
class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
1611+
public:
1612+
using OpRewritePattern<InsertOp>::OpRewritePattern;
1613+
1614+
LogicalResult matchAndRewrite(InsertOp insertOp,
1615+
PatternRewriter &rewriter) const override {
1616+
auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
1617+
if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
1618+
srcVecType.getNumElements())
1619+
return failure();
1620+
rewriter.replaceOpWithNewOp<BroadcastOp>(
1621+
insertOp, insertOp.getDestVectorType(), insertOp.source());
1622+
return success();
1623+
}
1624+
};
1625+
1626+
} // namespace
1627+
15581628
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
15591629
MLIRContext *context) {
1560-
// InsertToShapeCast is not a default canonicalization, it is opt-in by
1561-
// calling `populateCastAwayVectorLeadingOneDimPatterns`
1630+
results.add<InsertToBroadcast, BroadcastFolder>(context);
15621631
}
15631632

15641633
// Eliminates insert operations that produce values identical to their source

mlir/lib/Dialect/Vector/VectorTransforms.cpp

Lines changed: 29 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -2943,6 +2943,11 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
29432943
return VectorType::get(newShape, oldType.getElementType());
29442944
}
29452945

2946+
/// Return a smallVector of size `rank` containing all zeros.
2947+
static SmallVector<int64_t> splatZero(int64_t rank) {
2948+
return SmallVector<int64_t>(rank, 0);
2949+
}
2950+
29462951
// Casts away leading one dimensions in vector.extract_strided_slice's vector
29472952
// input by inserting vector.shape_cast.
29482953
struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
29692974

29702975
Location loc = extractOp.getLoc();
29712976

2972-
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
2973-
loc, newSrcType, extractOp.vector());
2977+
Value newSrcVector = rewriter.create<vector::ExtractOp>(
2978+
loc, extractOp.vector(), splatZero(dropCount));
29742979

29752980
// The offsets/sizes/strides attribute can have a less number of elements
29762981
// than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
29842989
auto newExtractOp = rewriter.create<vector::ExtractStridedSliceOp>(
29852990
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
29862991

2987-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(extractOp, oldDstType,
2992+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(extractOp, oldDstType,
29882993
newExtractOp);
29892994

29902995
return success();
@@ -3004,17 +3009,18 @@ struct CastAwayInsertStridedSliceLeadingOneDim
30043009
VectorType oldDstType = insertOp.getDestVectorType();
30053010
VectorType newDstType = trimLeadingOneDims(oldDstType);
30063011

3007-
if (newSrcType.getRank() == oldSrcType.getRank() &&
3008-
newDstType.getRank() == oldDstType.getRank())
3012+
int64_t srcDropCount = oldSrcType.getRank() - newSrcType.getRank();
3013+
int64_t dstDropCount = oldDstType.getRank() - newDstType.getRank();
3014+
if (srcDropCount == 0 && dstDropCount == 0)
30093015
return failure();
30103016

30113017
// Trim leading one dimensions from both operands.
30123018
Location loc = insertOp.getLoc();
30133019

3014-
Value newSrcVector = rewriter.create<vector::ShapeCastOp>(
3015-
loc, newSrcType, insertOp.source());
3016-
Value newDstVector =
3017-
rewriter.create<vector::ShapeCastOp>(loc, newDstType, insertOp.dest());
3020+
Value newSrcVector = rewriter.create<vector::ExtractOp>(
3021+
loc, insertOp.source(), splatZero(srcDropCount));
3022+
Value newDstVector = rewriter.create<vector::ExtractOp>(
3023+
loc, insertOp.dest(), splatZero(dstDropCount));
30183024

30193025
auto newOffsets = rewriter.getArrayAttr(
30203026
insertOp.offsets().getValue().take_back(newDstType.getRank()));
@@ -3024,7 +3030,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
30243030
auto newInsertOp = rewriter.create<vector::InsertStridedSliceOp>(
30253031
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
30263032

3027-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(insertOp, oldDstType,
3033+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(insertOp, oldDstType,
30283034
newInsertOp);
30293035

30303036
return success();
@@ -3068,7 +3074,7 @@ struct CastAwayTransferReadLeadingOneDim
30683074
auto newRead = rewriter.create<vector::TransferReadOp>(
30693075
read.getLoc(), newType, read.source(), read.indices(), newMap,
30703076
read.padding(), inBounds);
3071-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(read, oldType, newRead);
3077+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(read, oldType, newRead);
30723078

30733079
return success();
30743080
}
@@ -3092,9 +3098,9 @@ struct CastAwayTransferWriteLeadingOneDim
30923098

30933099
VectorType oldType = write.getVectorType();
30943100
VectorType newType = trimLeadingOneDims(oldType);
3095-
30963101
if (newType == oldType)
30973102
return failure();
3103+
int64_t dropDim = oldType.getRank() - newType.getRank();
30983104

30993105
AffineMap oldMap = write.permutation_map();
31003106
ArrayRef<AffineExpr> newResults =
@@ -3108,44 +3114,15 @@ struct CastAwayTransferWriteLeadingOneDim
31083114
inBounds = rewriter.getArrayAttr(
31093115
write.in_boundsAttr().getValue().take_back(newType.getRank()));
31103116

3111-
auto newVector = rewriter.create<vector::ShapeCastOp>(
3112-
write.getLoc(), newType, write.vector());
3117+
auto newVector = rewriter.create<vector::ExtractOp>(
3118+
write.getLoc(), write.vector(), splatZero(dropDim));
31133119
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
31143120
write, newVector, write.source(), write.indices(), newMap, inBounds);
31153121

31163122
return success();
31173123
}
31183124
};
31193125

3120-
template <typename BroadCastType>
3121-
struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern<BroadCastType> {
3122-
using OpRewritePattern<BroadCastType>::OpRewritePattern;
3123-
3124-
LogicalResult matchAndRewrite(BroadCastType broadcastOp,
3125-
PatternRewriter &rewriter) const override {
3126-
VectorType dstType =
3127-
broadcastOp.getResult().getType().template dyn_cast<VectorType>();
3128-
if (!dstType)
3129-
return failure();
3130-
VectorType newDstType = trimLeadingOneDims(dstType);
3131-
if (newDstType == dstType)
3132-
return failure();
3133-
Location loc = broadcastOp.getLoc();
3134-
Value source = broadcastOp->getOperand(0);
3135-
VectorType srcVecType = source.getType().template dyn_cast<VectorType>();
3136-
if (srcVecType)
3137-
srcVecType = trimLeadingOneDims(srcVecType);
3138-
if (srcVecType && srcVecType != source.getType()) {
3139-
source = rewriter.create<vector::ShapeCastOp>(loc, srcVecType, source);
3140-
}
3141-
Value newBroadcastOp =
3142-
rewriter.create<BroadCastType>(loc, newDstType, source);
3143-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(broadcastOp, dstType,
3144-
newBroadcastOp);
3145-
return success();
3146-
}
3147-
};
3148-
31493126
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31503127
public:
31513128
CastAwayElementwiseLeadingOneDim(MLIRContext *context)
@@ -3161,14 +3138,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31613138
VectorType newVecType = trimLeadingOneDims(vecType);
31623139
if (newVecType == vecType)
31633140
return failure();
3164-
3141+
int64_t dropDim = vecType.getRank() - newVecType.getRank();
31653142
SmallVector<Value, 4> newOperands;
31663143
for (Value operand : op->getOperands()) {
31673144
if (auto opVecType = operand.getType().dyn_cast<VectorType>()) {
3168-
auto newType =
3169-
VectorType::get(newVecType.getShape(), opVecType.getElementType());
3170-
newOperands.push_back(rewriter.create<vector::ShapeCastOp>(
3171-
op->getLoc(), newType, operand));
3145+
newOperands.push_back(rewriter.create<vector::ExtractOp>(
3146+
op->getLoc(), operand, splatZero(dropDim)));
31723147
} else {
31733148
newOperands.push_back(operand);
31743149
}
@@ -3178,69 +3153,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
31783153
state.addOperands(newOperands);
31793154
state.addTypes(newVecType);
31803155
Operation *newOp = rewriter.createOperation(state);
3181-
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, vecType,
3156+
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecType,
31823157
newOp->getResult(0));
31833158
return success();
31843159
}
31853160
};
31863161

3187-
// If extractOp is only removing unit dimensions it can be transformed to a
3188-
// shapecast.
3189-
class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
3190-
public:
3191-
using OpRewritePattern<ExtractOp>::OpRewritePattern;
3192-
3193-
LogicalResult matchAndRewrite(ExtractOp extractOp,
3194-
PatternRewriter &rewriter) const override {
3195-
auto dstVecType = extractOp.getResult().getType().dyn_cast<VectorType>();
3196-
if (!dstVecType || extractOp.getVectorType().getNumElements() !=
3197-
dstVecType.getNumElements())
3198-
return failure();
3199-
rewriter.replaceOpWithNewOp<ShapeCastOp>(extractOp, dstVecType,
3200-
extractOp.vector());
3201-
return success();
3202-
}
3203-
};
3204-
3205-
// If insertOp is only inserting unit dimensions it can be transformed to a
3206-
// shapecast.
3207-
class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
3208-
public:
3209-
using OpRewritePattern<InsertOp>::OpRewritePattern;
3210-
3211-
LogicalResult matchAndRewrite(InsertOp insertOp,
3212-
PatternRewriter &rewriter) const override {
3213-
auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
3214-
if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
3215-
srcVecType.getNumElements())
3216-
return failure();
3217-
rewriter.replaceOpWithNewOp<ShapeCastOp>(
3218-
insertOp, insertOp.getDestVectorType(), insertOp.source());
3219-
return success();
3220-
}
3221-
};
3222-
3223-
// BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
3224-
// the degenerated case where the broadcast only adds dimensions of size 1 it
3225-
// can be replaced by a ShapeCastOp. This canonicalization checks if the total
3226-
// number of elements is the same before and after the broadcast to detect if
3227-
// the only change in the vector type are new dimensions of size 1.
3228-
class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
3229-
public:
3230-
using OpRewritePattern<BroadcastOp>::OpRewritePattern;
3231-
3232-
LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
3233-
PatternRewriter &rewriter) const override {
3234-
auto srcVecType = broadcastOp.getSourceType().dyn_cast<VectorType>();
3235-
if (!srcVecType || broadcastOp.getVectorType().getNumElements() !=
3236-
srcVecType.getNumElements())
3237-
return failure();
3238-
rewriter.replaceOpWithNewOp<ShapeCastOp>(
3239-
broadcastOp, broadcastOp.getVectorType(), broadcastOp.source());
3240-
return success();
3241-
}
3242-
};
3243-
32443162
// Returns the values in `arrayAttr` as an integer vector.
32453163
static SmallVector<int64_t, 4> getIntValueVector(ArrayAttr arrayAttr) {
32463164
return llvm::to_vector<4>(
@@ -3722,13 +3640,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(
37223640

37233641
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns(
37243642
RewritePatternSet &patterns) {
3725-
patterns.add<
3726-
BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
3727-
CastAwayInsertStridedSliceLeadingOneDim,
3728-
CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
3729-
CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
3730-
CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
3731-
ExtractToShapeCast, InsertToShapeCast>(patterns.getContext());
3643+
patterns.add<CastAwayExtractStridedSliceLeadingOneDim,
3644+
CastAwayInsertStridedSliceLeadingOneDim,
3645+
CastAwayTransferReadLeadingOneDim,
3646+
CastAwayTransferWriteLeadingOneDim,
3647+
CastAwayElementwiseLeadingOneDim>(patterns.getContext());
37323648
populateShapeCastFoldingPatterns(patterns);
37333649
}
37343650

0 commit comments

Comments
 (0)