Skip to content

Commit 0623d1c

Browse files
committed
[MLIR][Tensor] Add Destination style RewritePattern for DimOp.
Fold the dim of a destination-passing-style op's result into the dim of the corresponding init operand. This enables canonicalization to fold away unnecessary tensor.dim ops, which in turn enables folding away other operations, as can be seen in conv_tensors_dynamic, where affine.min operations were folded away.
1 parent 7d7dcc1 commit 0623d1c

File tree

4 files changed

+54
-19
lines changed

4 files changed

+54
-19
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
579579
return success();
580580
}
581581
};
582+
583+
/// Fold dim of a destination passing style op into the dim of the corresponding
584+
/// init.
///
/// The pattern matches a `DimOp` whose source is defined by an op implementing
/// `DestinationStyleOpInterface`, and redirects the `DimOp` source to the init
/// operand looked up with the result number of that source value. The `DimOp`
/// is updated in place; no new op is created.
585+
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
586+
using OpRewritePattern<DimOp>::OpRewritePattern;
587+
588+
LogicalResult matchAndRewrite(DimOp dimOp,
589+
PatternRewriter &rewriter) const override {
590+
auto source = dimOp.getSource();
591+
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
592+
// Bail out unless the source is produced by a destination style op.
if (!destOp)
593+
return failure();
594+
595+
// The source is an op result here (it has a defining op), so the cast to
// OpResult is expected to hold; its result number selects the init operand.
auto resultIndex = source.cast<OpResult>().getResultNumber();
596+
auto initOperand = destOp.getDpsInitOperand(resultIndex);
597+
598+
// Rewire the dim op to read from the init operand instead of the op result,
// notifying the rewriter that the root op is modified in place.
rewriter.updateRootInPlace(
599+
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
600+
return success();
601+
}
602+
};
582603
} // namespace
583604

584605
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
585606
MLIRContext *context) {
586-
results.add<DimOfCastOp>(context);
607+
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
587608
}
588609

589610
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> {
397397

398398
// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
399399
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
400-
// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
401400
// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
402-
// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
401+
// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
403402
// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
404403
// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
405404
// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
@@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
908907
ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
909908
return %arg0 : tensor<16x64x256xf32>
910909
}
910+
911+
// -----
912+
913+
// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
914+
// CHECK: tensor.dim
915+
// CHECK: tensor.dim
916+
// CHECK-NOT: tensor.dim
917+
// CHECK: return
918+
// Test input: two chained linalg.copy ops, each writing into a tensor.empty
// whose dynamic sizes come from tensor.dim ops.
func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
919+
%c0 = arith.constant 0 : index
920+
%c1 = arith.constant 1 : index
921+
// Dims of the function argument; they size the first tensor.empty.
%dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
922+
%dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
923+
%0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
924+
%1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
925+
// Dims of the first copy's result (its outs operand is %0); they size the
// second tensor.empty.
%dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
926+
%dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
927+
%2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
928+
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
929+
return %3: tensor<?x?xf32>
930+
}

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
197197
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
198198
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
199199
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)>
200-
// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
201200
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
202201
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
203-
// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)>
204202
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)>
205203

206204
// CHECK: func @conv_tensors_dynamic
@@ -225,23 +223,19 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
225223
// CHECK-DAG: %[[FILTER_OC:.+]] = tensor.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
226224
// CHECK-DAG: %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
227225
// CHECK-DAG: %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
228-
// CHECK-DAG: %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
229-
// CHECK-DAG: %[[FILL_W:.+]] = tensor.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
230226

231227
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
232228
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
233229
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
234230
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
235231
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
236232
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
237-
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[FILL_H]], %[[FILTER_H]]]
238-
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
233+
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]], %[[FILTER_H]]]
239234
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
240235
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
241236
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
242237
// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
243-
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
244-
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
238+
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]], %[[FILTER_W]]]
245239
// CHECK-NEXT: %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
246240
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
247241
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
@@ -253,7 +247,7 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
253247
// CHECK-NEXT: %[[ST_FILTER:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, %[[IV3]]]
254248
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
255249
// CHECK-NEXT: %[[ST_FILL:.+]] = tensor.extract_slice %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
256-
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_2]]]
250+
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC_2]]]
257251
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_nhwc_hwcf
258252
// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
259253
// CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ transform.sequence failures(propagate) {
4343
// CHECK: arith.addf
4444
// CHECK: linalg.yield
4545
// CHECK: } -> tensor<?x?xf32>
46-
// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
47-
// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
48-
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
46+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
4947
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
5048
// CHECK: }
5149
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
@@ -76,14 +74,16 @@ transform.sequence failures(propagate) {
7674
by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
7775
}
7876

77+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
78+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
79+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
7980
// CHECK: func @reduction_tile_transpose
8081
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
8182
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
8283
// CHECK: scf.for
83-
// CHECK: linalg.generic
84-
// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
85-
// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
86-
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
84+
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
85+
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
86+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
8787
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
8888
// CHECK: }
8989
// CHECK: linalg.generic

0 commit comments

Comments
 (0)