@@ -3045,6 +3045,119 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
3045
3045
return success ();
3046
3046
}
3047
3047
};
3048
+
3049
+ static llvm::SmallVector<int64_t > getIndicesVector (int start, int end) {
3050
+ return llvm::to_vector<2 >(llvm::seq<int64_t >(start, end));
3051
+ }
3052
+
3053
+ LogicalResult matchAndReplaceDepthwiseConv (Operation *operation, Value input,
3054
+ Value kernel, Value iZp, Value kZp ,
3055
+ Value init, Attribute stride,
3056
+ Attribute dilation,
3057
+ PatternRewriter &rewriter) {
3058
+ Location loc = operation->getLoc ();
3059
+ auto linalgOp = dyn_cast<LinalgOp>(operation);
3060
+ // Exit out on the memref version of this operation.
3061
+ if (!linalgOp || !linalgOp.hasTensorSemantics ())
3062
+ return failure ();
3063
+
3064
+ auto result = operation->getResult (0 );
3065
+
3066
+ auto kernelTy = kernel.getType ().dyn_cast <RankedTensorType>();
3067
+ auto initTy = init.getType ().dyn_cast <RankedTensorType>();
3068
+ auto resultTy = result.getType ().template dyn_cast <RankedTensorType>();
3069
+ if (!kernelTy || !initTy || !resultTy)
3070
+ return failure ();
3071
+
3072
+ if (kernelTy.getDimSize (3 ) != 1 )
3073
+ return failure ();
3074
+
3075
+ // Collapse kernel dims.
3076
+ SmallVector<ReassociationIndices, 4 > collapsedKernelDims = {
3077
+ getIndicesVector (0 , 1 ), getIndicesVector (1 , 2 ), getIndicesVector (2 , 4 )};
3078
+ auto newKernelTy = RankedTensorType::get (
3079
+ {kernelTy.getDimSize (0 ), kernelTy.getDimSize (1 ), kernelTy.getDimSize (2 )},
3080
+ kernelTy.getElementType ());
3081
+ auto collapsedKernel = rewriter.create <linalg::TensorCollapseShapeOp>(
3082
+ loc, newKernelTy, kernel, collapsedKernelDims);
3083
+
3084
+ // Collapse init dims.
3085
+ SmallVector<ReassociationIndices, 4 > collapsedInitDims = {
3086
+ getIndicesVector (0 , 1 ), getIndicesVector (1 , 2 ), getIndicesVector (2 , 3 ),
3087
+ getIndicesVector (3 , 5 )};
3088
+ auto newInitTy =
3089
+ RankedTensorType::get ({initTy.getDimSize (0 ), initTy.getDimSize (1 ),
3090
+ initTy.getDimSize (2 ), initTy.getDimSize (3 )},
3091
+ initTy.getElementType ());
3092
+ auto collapsedInit = rewriter.create <linalg::TensorCollapseShapeOp>(
3093
+ loc, newInitTy, init, collapsedInitDims);
3094
+
3095
+ Value newConv;
3096
+ if (isa<DepthwiseConv2DNhwcOp>(operation)) {
3097
+ newConv = rewriter
3098
+ .create <DepthwiseConv2DNhwOp>(
3099
+ loc, newInitTy, ValueRange{input, collapsedKernel},
3100
+ ValueRange{collapsedInit}, stride, dilation)
3101
+ .getResult (0 );
3102
+ } else if (isa<DepthwiseConv2DNhwcQOp>(operation)) {
3103
+ newConv =
3104
+ rewriter
3105
+ .create <DepthwiseConv2DNhwQOp>(
3106
+ loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp },
3107
+ ValueRange{collapsedInit}, stride, dilation)
3108
+ .getResult (0 );
3109
+ }
3110
+
3111
+ if (!newConv)
3112
+ return failure ();
3113
+
3114
+ // Expand dimensions back out to
3115
+ rewriter.replaceOpWithNewOp <linalg::TensorExpandShapeOp>(
3116
+ operation, resultTy, newConv, collapsedInitDims);
3117
+ return success ();
3118
+ }
3119
+
3120
+ struct SimplifyDepthwiseConvOp
3121
+ : public OpRewritePattern<DepthwiseConv2DNhwcOp> {
3122
+ using OpRewritePattern<DepthwiseConv2DNhwcOp>::OpRewritePattern;
3123
+
3124
+ LogicalResult matchAndRewrite (DepthwiseConv2DNhwcOp op,
3125
+ PatternRewriter &rewriter) const override {
3126
+ Operation *operation = op.getOperation ();
3127
+ Value input = op.getInputOperand (0 )->get ();
3128
+ Value kernel = op.getInputOperand (1 )->get ();
3129
+ Value init = op.getOutputOperand (0 )->get ();
3130
+
3131
+ auto stride = op.strides ();
3132
+ auto dilation = op.dilations ();
3133
+
3134
+ return matchAndReplaceDepthwiseConv (operation, input, kernel, nullptr ,
3135
+ nullptr , init, stride, dilation,
3136
+ rewriter);
3137
+ }
3138
+ };
3139
+
3140
+ struct SimplifyDepthwiseConvQOp
3141
+ : public OpRewritePattern<DepthwiseConv2DNhwcQOp> {
3142
+ using OpRewritePattern<DepthwiseConv2DNhwcQOp>::OpRewritePattern;
3143
+
3144
+ LogicalResult matchAndRewrite (DepthwiseConv2DNhwcQOp op,
3145
+ PatternRewriter &rewriter) const override {
3146
+ Operation *operation = op.getOperation ();
3147
+ Value input = op.getInputOperand (0 )->get ();
3148
+ Value kernel = op.getInputOperand (1 )->get ();
3149
+ Value iZp = op.getInputOperand (2 )->get ();
3150
+ Value kZp = op.getInputOperand (3 )->get ();
3151
+ Value init = op.getOutputOperand (0 )->get ();
3152
+
3153
+ auto stride = op.strides ();
3154
+ auto dilation = op.dilations ();
3155
+
3156
+ return matchAndReplaceDepthwiseConv (operation, input, kernel, iZp, kZp ,
3157
+ init, stride, dilation, rewriter);
3158
+ }
3159
+ };
3160
+
3048
3161
} // namespace
3049
3162
3050
3163
#define LINALGOP_FOLDERS (XXX ) \
@@ -3070,5 +3183,6 @@ LINALGOP_FOLDERS(GenericOp)
3070
3183
3071
3184
void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  // Dialect-wide canonicalizations: erase dead linalg ops, fold surrounding
  // tensor casts, and collapse unit-multiplier depthwise convolutions.
  results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
              SimplifyDepthwiseConvQOp>(getContext());
}
0 commit comments