Skip to content

Commit 1ac2d19

Browse files
committed
[mlir][linalg] Add canonicalizers for depthwise conv
There are two main versions of depthwise conv, depending on whether the multiplier is 1 or not. In cases where m == 1 we should use the version without the multiplier channel, as it can perform greater optimization. Add lowering for the quantized/float versions to have a multiplier of one.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D108959
1 parent 3273430 commit 1ac2d19

File tree

2 files changed

+139
-1
lines changed

2 files changed

+139
-1
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 115 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3045,6 +3045,119 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
30453045
return success();
30463046
}
30473047
};
3048+
3049+
/// Returns the half-open integer range [start, end) as a vector, used to
/// spell out reassociation indices for collapse/expand shape ops.
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
  llvm::SmallVector<int64_t> indices;
  indices.reserve(end - start);
  for (int64_t i = start, e = end; i < e; ++i)
    indices.push_back(i);
  return indices;
}
3052+
3053+
/// Rewrites a depthwise convolution whose channel multiplier (the trailing
/// kernel dimension) is statically 1 into the multiplier-free variant.
///
/// The kernel (H, W, C, M=1) and init (N, H, W, C, M=1) operands are
/// collapsed to drop the unit multiplier dimension, the corresponding
/// multiplier-free conv op is created, and the result is expanded back to
/// the original 5-D result type.
///
/// \p iZp / \p kZp are the input/kernel zero points; they are null for the
/// non-quantized pattern and only used for the quantized op.
/// Returns failure on the memref form, on unranked operands, or when the
/// multiplier dimension is not statically 1.
LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
                                           Value kernel, Value iZp, Value kZp,
                                           Value init, Attribute stride,
                                           Attribute dilation,
                                           PatternRewriter &rewriter) {
  Location loc = operation->getLoc();
  auto linalgOp = dyn_cast<LinalgOp>(operation);
  // Exit out on the memref version of this operation.
  if (!linalgOp || !linalgOp.hasTensorSemantics())
    return failure();

  auto result = operation->getResult(0);

  auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
  auto initTy = init.getType().dyn_cast<RankedTensorType>();
  // NOTE: the `.template` keyword on the original dyn_cast was unnecessary
  // (non-dependent context) and inconsistent with the calls above.
  auto resultTy = result.getType().dyn_cast<RankedTensorType>();
  if (!kernelTy || !initTy || !resultTy)
    return failure();

  // Only rewrite when the multiplier dimension is statically 1; a dynamic
  // size also fails this check, which is the conservative choice.
  if (kernelTy.getDimSize(3) != 1)
    return failure();

  // Collapse kernel dims: (H, W, C, M=1) -> (H, W, C).
  SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
  auto newKernelTy = RankedTensorType::get(
      {kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
      kernelTy.getElementType());
  auto collapsedKernel = rewriter.create<linalg::TensorCollapseShapeOp>(
      loc, newKernelTy, kernel, collapsedKernelDims);

  // Collapse init dims: (N, H, W, C, M=1) -> (N, H, W, C).
  SmallVector<ReassociationIndices, 4> collapsedInitDims = {
      getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
      getIndicesVector(3, 5)};
  auto newInitTy =
      RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
                             initTy.getDimSize(2), initTy.getDimSize(3)},
                            initTy.getElementType());
  auto collapsedInit = rewriter.create<linalg::TensorCollapseShapeOp>(
      loc, newInitTy, init, collapsedInitDims);

  // Create the multiplier-free conv matching the original op's flavor.
  Value newConv;
  if (isa<DepthwiseConv2DNhwcOp>(operation)) {
    newConv = rewriter
                  .create<DepthwiseConv2DNhwOp>(
                      loc, newInitTy, ValueRange{input, collapsedKernel},
                      ValueRange{collapsedInit}, stride, dilation)
                  .getResult(0);
  } else if (isa<DepthwiseConv2DNhwcQOp>(operation)) {
    newConv =
        rewriter
            .create<DepthwiseConv2DNhwQOp>(
                loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
                ValueRange{collapsedInit}, stride, dilation)
            .getResult(0);
  }

  if (!newConv)
    return failure();

  // Expand dimensions back out to the original 5-D result shape.
  rewriter.replaceOpWithNewOp<linalg::TensorExpandShapeOp>(
      operation, resultTy, newConv, collapsedInitDims);
  return success();
}
3119+
3120+
struct SimplifyDepthwiseConvOp
3121+
: public OpRewritePattern<DepthwiseConv2DNhwcOp> {
3122+
using OpRewritePattern<DepthwiseConv2DNhwcOp>::OpRewritePattern;
3123+
3124+
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcOp op,
3125+
PatternRewriter &rewriter) const override {
3126+
Operation *operation = op.getOperation();
3127+
Value input = op.getInputOperand(0)->get();
3128+
Value kernel = op.getInputOperand(1)->get();
3129+
Value init = op.getOutputOperand(0)->get();
3130+
3131+
auto stride = op.strides();
3132+
auto dilation = op.dilations();
3133+
3134+
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
3135+
nullptr, init, stride, dilation,
3136+
rewriter);
3137+
}
3138+
};
3139+
3140+
struct SimplifyDepthwiseConvQOp
3141+
: public OpRewritePattern<DepthwiseConv2DNhwcQOp> {
3142+
using OpRewritePattern<DepthwiseConv2DNhwcQOp>::OpRewritePattern;
3143+
3144+
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcQOp op,
3145+
PatternRewriter &rewriter) const override {
3146+
Operation *operation = op.getOperation();
3147+
Value input = op.getInputOperand(0)->get();
3148+
Value kernel = op.getInputOperand(1)->get();
3149+
Value iZp = op.getInputOperand(2)->get();
3150+
Value kZp = op.getInputOperand(3)->get();
3151+
Value init = op.getOutputOperand(0)->get();
3152+
3153+
auto stride = op.strides();
3154+
auto dilation = op.dilations();
3155+
3156+
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
3157+
init, stride, dilation, rewriter);
3158+
}
3159+
};
3160+
30483161
} // namespace
30493162

30503163
#define LINALGOP_FOLDERS(XXX) \
@@ -3070,5 +3183,6 @@ LINALGOP_FOLDERS(GenericOp)
30703183

30713184
void LinalgDialect::getCanonicalizationPatterns(
30723185
RewritePatternSet &results) const {
3073-
results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
3186+
results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
3187+
SimplifyDepthwiseConvQOp>(getContext());
30743188
}

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,3 +1004,27 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
10041004
return %r2 : index
10051005
}
10061006

1007+
// -----
1008+
1009+
// CHECK-LABEL: @depthwise_conv
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
  // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
  // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
  // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
  %0 = linalg.depthwise_conv2D_nhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
  return %0 : tensor<?x?x?x?x1xf32>
}
1018+
1019+
1020+
// -----
1021+
1022+
// CHECK-LABEL: @depthwise_conv_q
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
  // CHECK-DAG: %[[KERNEL:.+]] = linalg.tensor_collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
  // CHECK-DAG: %[[INIT:.+]] = linalg.tensor_collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
  // CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv2D_nhw_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
  // CHECK: %[[OUT:.+]] = linalg.tensor_expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
  %0 = linalg.depthwise_conv2D_nhwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
  return %0 : tensor<?x?x?x?x1xi32>
}

0 commit comments

Comments
 (0)