@@ -2943,6 +2943,11 @@ static VectorType trimLeadingOneDims(VectorType oldType) {
2943
2943
return VectorType::get (newShape, oldType.getElementType ());
2944
2944
}
2945
2945
2946
+ // / Return a smallVector of size `rank` containing all zeros.
2947
+ static SmallVector<int64_t > splatZero (int64_t rank) {
2948
+ return SmallVector<int64_t >(rank, 0 );
2949
+ }
2950
+
2946
2951
// Casts away leading one dimensions in vector.extract_strided_slice's vector
2947
2952
// input by inserting vector.shape_cast.
2948
2953
struct CastAwayExtractStridedSliceLeadingOneDim
@@ -2969,8 +2974,8 @@ struct CastAwayExtractStridedSliceLeadingOneDim
2969
2974
2970
2975
Location loc = extractOp.getLoc ();
2971
2976
2972
- Value newSrcVector = rewriter.create <vector::ShapeCastOp >(
2973
- loc, newSrcType, extractOp.vector ());
2977
+ Value newSrcVector = rewriter.create <vector::ExtractOp >(
2978
+ loc, extractOp.vector (), splatZero (dropCount ));
2974
2979
2975
2980
// The offsets/sizes/strides attribute can have a less number of elements
2976
2981
// than the input vector's rank: it is meant for the leading dimensions.
@@ -2984,7 +2989,7 @@ struct CastAwayExtractStridedSliceLeadingOneDim
2984
2989
auto newExtractOp = rewriter.create <vector::ExtractStridedSliceOp>(
2985
2990
loc, newDstType, newSrcVector, newOffsets, newSizes, newStrides);
2986
2991
2987
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(extractOp, oldDstType,
2992
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(extractOp, oldDstType,
2988
2993
newExtractOp);
2989
2994
2990
2995
return success ();
@@ -3004,17 +3009,18 @@ struct CastAwayInsertStridedSliceLeadingOneDim
3004
3009
VectorType oldDstType = insertOp.getDestVectorType ();
3005
3010
VectorType newDstType = trimLeadingOneDims (oldDstType);
3006
3011
3007
- if (newSrcType.getRank () == oldSrcType.getRank () &&
3008
- newDstType.getRank () == oldDstType.getRank ())
3012
+ int64_t srcDropCount = oldSrcType.getRank () - newSrcType.getRank ();
3013
+ int64_t dstDropCount = oldDstType.getRank () - newDstType.getRank ();
3014
+ if (srcDropCount == 0 && dstDropCount == 0 )
3009
3015
return failure ();
3010
3016
3011
3017
// Trim leading one dimensions from both operands.
3012
3018
Location loc = insertOp.getLoc ();
3013
3019
3014
- Value newSrcVector = rewriter.create <vector::ShapeCastOp >(
3015
- loc, newSrcType, insertOp.source ());
3016
- Value newDstVector =
3017
- rewriter. create <vector::ShapeCastOp>( loc, newDstType, insertOp.dest ());
3020
+ Value newSrcVector = rewriter.create <vector::ExtractOp >(
3021
+ loc, insertOp.source (), splatZero (srcDropCount ));
3022
+ Value newDstVector = rewriter. create <vector::ExtractOp>(
3023
+ loc, insertOp.dest (), splatZero (dstDropCount ));
3018
3024
3019
3025
auto newOffsets = rewriter.getArrayAttr (
3020
3026
insertOp.offsets ().getValue ().take_back (newDstType.getRank ()));
@@ -3024,7 +3030,7 @@ struct CastAwayInsertStridedSliceLeadingOneDim
3024
3030
auto newInsertOp = rewriter.create <vector::InsertStridedSliceOp>(
3025
3031
loc, newDstType, newSrcVector, newDstVector, newOffsets, newStrides);
3026
3032
3027
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(insertOp, oldDstType,
3033
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(insertOp, oldDstType,
3028
3034
newInsertOp);
3029
3035
3030
3036
return success ();
@@ -3068,7 +3074,7 @@ struct CastAwayTransferReadLeadingOneDim
3068
3074
auto newRead = rewriter.create <vector::TransferReadOp>(
3069
3075
read .getLoc (), newType, read .source (), read .indices (), newMap,
3070
3076
read .padding (), inBounds);
3071
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(read , oldType, newRead);
3077
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(read , oldType, newRead);
3072
3078
3073
3079
return success ();
3074
3080
}
@@ -3092,9 +3098,9 @@ struct CastAwayTransferWriteLeadingOneDim
3092
3098
3093
3099
VectorType oldType = write .getVectorType ();
3094
3100
VectorType newType = trimLeadingOneDims (oldType);
3095
-
3096
3101
if (newType == oldType)
3097
3102
return failure ();
3103
+ int64_t dropDim = oldType.getRank () - newType.getRank ();
3098
3104
3099
3105
AffineMap oldMap = write .permutation_map ();
3100
3106
ArrayRef<AffineExpr> newResults =
@@ -3108,44 +3114,15 @@ struct CastAwayTransferWriteLeadingOneDim
3108
3114
inBounds = rewriter.getArrayAttr (
3109
3115
write .in_boundsAttr ().getValue ().take_back (newType.getRank ()));
3110
3116
3111
- auto newVector = rewriter.create <vector::ShapeCastOp >(
3112
- write .getLoc (), newType, write .vector ());
3117
+ auto newVector = rewriter.create <vector::ExtractOp >(
3118
+ write .getLoc (), write .vector (), splatZero (dropDim ));
3113
3119
rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
3114
3120
write , newVector, write .source (), write .indices (), newMap, inBounds);
3115
3121
3116
3122
return success ();
3117
3123
}
3118
3124
};
3119
3125
3120
- template <typename BroadCastType>
3121
- struct CastAwayBroadcastLeadingOneDim : public OpRewritePattern <BroadCastType> {
3122
- using OpRewritePattern<BroadCastType>::OpRewritePattern;
3123
-
3124
- LogicalResult matchAndRewrite (BroadCastType broadcastOp,
3125
- PatternRewriter &rewriter) const override {
3126
- VectorType dstType =
3127
- broadcastOp.getResult ().getType ().template dyn_cast <VectorType>();
3128
- if (!dstType)
3129
- return failure ();
3130
- VectorType newDstType = trimLeadingOneDims (dstType);
3131
- if (newDstType == dstType)
3132
- return failure ();
3133
- Location loc = broadcastOp.getLoc ();
3134
- Value source = broadcastOp->getOperand (0 );
3135
- VectorType srcVecType = source.getType ().template dyn_cast <VectorType>();
3136
- if (srcVecType)
3137
- srcVecType = trimLeadingOneDims (srcVecType);
3138
- if (srcVecType && srcVecType != source.getType ()) {
3139
- source = rewriter.create <vector::ShapeCastOp>(loc, srcVecType, source);
3140
- }
3141
- Value newBroadcastOp =
3142
- rewriter.create <BroadCastType>(loc, newDstType, source);
3143
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp>(broadcastOp, dstType,
3144
- newBroadcastOp);
3145
- return success ();
3146
- }
3147
- };
3148
-
3149
3126
class CastAwayElementwiseLeadingOneDim : public RewritePattern {
3150
3127
public:
3151
3128
CastAwayElementwiseLeadingOneDim (MLIRContext *context)
@@ -3161,14 +3138,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
3161
3138
VectorType newVecType = trimLeadingOneDims (vecType);
3162
3139
if (newVecType == vecType)
3163
3140
return failure ();
3164
-
3141
+ int64_t dropDim = vecType. getRank () - newVecType. getRank ();
3165
3142
SmallVector<Value, 4 > newOperands;
3166
3143
for (Value operand : op->getOperands ()) {
3167
3144
if (auto opVecType = operand.getType ().dyn_cast <VectorType>()) {
3168
- auto newType =
3169
- VectorType::get (newVecType.getShape (), opVecType.getElementType ());
3170
- newOperands.push_back (rewriter.create <vector::ShapeCastOp>(
3171
- op->getLoc (), newType, operand));
3145
+ newOperands.push_back (rewriter.create <vector::ExtractOp>(
3146
+ op->getLoc (), operand, splatZero (dropDim)));
3172
3147
} else {
3173
3148
newOperands.push_back (operand);
3174
3149
}
@@ -3178,69 +3153,12 @@ class CastAwayElementwiseLeadingOneDim : public RewritePattern {
3178
3153
state.addOperands (newOperands);
3179
3154
state.addTypes (newVecType);
3180
3155
Operation *newOp = rewriter.createOperation (state);
3181
- rewriter.replaceOpWithNewOp <vector::ShapeCastOp >(op, vecType,
3156
+ rewriter.replaceOpWithNewOp <vector::BroadcastOp >(op, vecType,
3182
3157
newOp->getResult (0 ));
3183
3158
return success ();
3184
3159
}
3185
3160
};
3186
3161
3187
- // If extractOp is only removing unit dimensions it can be transformed to a
3188
- // shapecast.
3189
- class ExtractToShapeCast final : public OpRewritePattern<ExtractOp> {
3190
- public:
3191
- using OpRewritePattern<ExtractOp>::OpRewritePattern;
3192
-
3193
- LogicalResult matchAndRewrite (ExtractOp extractOp,
3194
- PatternRewriter &rewriter) const override {
3195
- auto dstVecType = extractOp.getResult ().getType ().dyn_cast <VectorType>();
3196
- if (!dstVecType || extractOp.getVectorType ().getNumElements () !=
3197
- dstVecType.getNumElements ())
3198
- return failure ();
3199
- rewriter.replaceOpWithNewOp <ShapeCastOp>(extractOp, dstVecType,
3200
- extractOp.vector ());
3201
- return success ();
3202
- }
3203
- };
3204
-
3205
- // If insertOp is only inserting unit dimensions it can be transformed to a
3206
- // shapecast.
3207
- class InsertToShapeCast final : public OpRewritePattern<InsertOp> {
3208
- public:
3209
- using OpRewritePattern<InsertOp>::OpRewritePattern;
3210
-
3211
- LogicalResult matchAndRewrite (InsertOp insertOp,
3212
- PatternRewriter &rewriter) const override {
3213
- auto srcVecType = insertOp.getSourceType ().dyn_cast <VectorType>();
3214
- if (!srcVecType || insertOp.getDestVectorType ().getNumElements () !=
3215
- srcVecType.getNumElements ())
3216
- return failure ();
3217
- rewriter.replaceOpWithNewOp <ShapeCastOp>(
3218
- insertOp, insertOp.getDestVectorType (), insertOp.source ());
3219
- return success ();
3220
- }
3221
- };
3222
-
3223
- // BroadcastOp can only add dimensions or broadcast a dimension from 1 to N. In
3224
- // the degenerated case where the broadcast only adds dimensions of size 1 it
3225
- // can be replaced by a ShapeCastOp. This canonicalization checks if the total
3226
- // number of elements is the same before and after the broadcast to detect if
3227
- // the only change in the vector type are new dimensions of size 1.
3228
- class BroadcastToShapeCast final : public OpRewritePattern<BroadcastOp> {
3229
- public:
3230
- using OpRewritePattern<BroadcastOp>::OpRewritePattern;
3231
-
3232
- LogicalResult matchAndRewrite (BroadcastOp broadcastOp,
3233
- PatternRewriter &rewriter) const override {
3234
- auto srcVecType = broadcastOp.getSourceType ().dyn_cast <VectorType>();
3235
- if (!srcVecType || broadcastOp.getVectorType ().getNumElements () !=
3236
- srcVecType.getNumElements ())
3237
- return failure ();
3238
- rewriter.replaceOpWithNewOp <ShapeCastOp>(
3239
- broadcastOp, broadcastOp.getVectorType (), broadcastOp.source ());
3240
- return success ();
3241
- }
3242
- };
3243
-
3244
3162
// Returns the values in `arrayAttr` as an integer vector.
3245
3163
static SmallVector<int64_t , 4 > getIntValueVector (ArrayAttr arrayAttr) {
3246
3164
return llvm::to_vector<4 >(
@@ -3722,13 +3640,11 @@ void mlir::vector::populateShapeCastFoldingPatterns(
3722
3640
3723
3641
void mlir::vector::populateCastAwayVectorLeadingOneDimPatterns (
3724
3642
RewritePatternSet &patterns) {
3725
- patterns.add <
3726
- BroadcastToShapeCast, CastAwayExtractStridedSliceLeadingOneDim,
3727
- CastAwayInsertStridedSliceLeadingOneDim,
3728
- CastAwayTransferReadLeadingOneDim, CastAwayTransferWriteLeadingOneDim,
3729
- CastAwayBroadcastLeadingOneDim<vector::BroadcastOp>,
3730
- CastAwayBroadcastLeadingOneDim<SplatOp>, CastAwayElementwiseLeadingOneDim,
3731
- ExtractToShapeCast, InsertToShapeCast>(patterns.getContext ());
3643
+ patterns.add <CastAwayExtractStridedSliceLeadingOneDim,
3644
+ CastAwayInsertStridedSliceLeadingOneDim,
3645
+ CastAwayTransferReadLeadingOneDim,
3646
+ CastAwayTransferWriteLeadingOneDim,
3647
+ CastAwayElementwiseLeadingOneDim>(patterns.getContext ());
3732
3648
populateShapeCastFoldingPatterns (patterns);
3733
3649
}
3734
3650
0 commit comments