Skip to content

Commit 0aa459f

Browse files
authored
[MLIR][Tensor] Add Destination style RewritePattern for DimOp. (#65780)
This enables canonicalization to fold away unnecessary tensor.dim ops which in turn enables folding away of other operations, as can be seen in conv_tensors_dynamic where affine.min operations were folded away.
1 parent 2744d9b commit 0aa459f

File tree

4 files changed

+54
-19
lines changed

4 files changed

+54
-19
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -579,11 +579,32 @@ struct DimOfCastOp : public OpRewritePattern<DimOp> {
579579
return success();
580580
}
581581
};
582+
583+
/// Fold dim of a destination passing style op into the dim of the corresponding
584+
/// init.
585+
struct DimOfDestStyleOp : public OpRewritePattern<DimOp> {
586+
using OpRewritePattern<DimOp>::OpRewritePattern;
587+
588+
LogicalResult matchAndRewrite(DimOp dimOp,
589+
PatternRewriter &rewriter) const override {
590+
auto source = dimOp.getSource();
591+
auto destOp = source.getDefiningOp<DestinationStyleOpInterface>();
592+
if (!destOp)
593+
return failure();
594+
595+
auto resultIndex = source.cast<OpResult>().getResultNumber();
596+
auto initOperand = destOp.getDpsInitOperand(resultIndex);
597+
598+
rewriter.updateRootInPlace(
599+
dimOp, [&]() { dimOp.getSourceMutable().assign(initOperand->get()); });
600+
return success();
601+
}
602+
};
582603
} // namespace
583604

584605
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
585606
MLIRContext *context) {
586-
results.add<DimOfCastOp>(context);
607+
results.add<DimOfCastOp, DimOfDestStyleOp>(context);
587608
}
588609

589610
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,8 @@ func.func @fold_static_pad_fill() -> tensor<412x276xf32> {
397397

398398
// CHECK-DAG: %[[I1:.+]] = arith.constant 1 : index
399399
// CHECK-DAG: %[[F0:.+]] = arith.constant 0.000000e+00 : f32
400-
// CHECK: %[[OF:.+]] = linalg.fill ins(%[[F0]] : f32) outs(%[[SRC]] : tensor<8x?x16x32xf32>)
401400
// CHECK: %[[S0:.+]] = affine.apply #[[MAP0]]()[%[[LOW0]]]
402-
// CHECK: %[[DIM1:.+]] = tensor.dim %[[OF]], %[[I1]] : tensor<8x?x16x32xf32>
401+
// CHECK: %[[DIM1:.+]] = tensor.dim %[[SRC]], %[[I1]] : tensor<8x?x16x32xf32>
403402
// CHECK: %[[S1:.+]] = affine.apply #[[MAP1]]()[%[[DIM1]]]
404403
// CHECK: %[[S2:.+]] = affine.apply #[[MAP2]]()[%[[HIGH2]]]
405404
// CHECK: %[[S3:.+]] = affine.apply #[[MAP3]]()[%[[LOW3]], %[[HIGH3]]]
@@ -908,3 +907,24 @@ func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
908907
ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
909908
return %arg0 : tensor<16x64x256xf32>
910909
}
910+
911+
// -----
912+
913+
// CHECK-LABEL: func @canonicalize_dim_of_dest_style_op
914+
// CHECK: tensor.dim
915+
// CHECK: tensor.dim
916+
// CHECK-NOT: tensor.dim
917+
// CHECK: return
918+
func.func @canonicalize_dim_of_dest_style_op(%arg0 : tensor<?x?xf32>) -> tensor<?x?xf32> {
919+
%c0 = arith.constant 0 : index
920+
%c1 = arith.constant 1 : index
921+
%dim0_0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
922+
%dim1_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
923+
%0 = tensor.empty(%dim0_0, %dim1_0) : tensor<?x?xf32>
924+
%1 = linalg.copy ins(%arg0 : tensor<?x?xf32>) outs(%0 : tensor<?x?xf32>) -> tensor<?x?xf32>
925+
%dim0_1 = tensor.dim %1, %c0 : tensor<?x?xf32>
926+
%dim1_1 = tensor.dim %1, %c1 : tensor<?x?xf32>
927+
%2 = tensor.empty(%dim0_1, %dim1_1) : tensor<?x?xf32>
928+
%3 = linalg.copy ins(%1 : tensor<?x?xf32>) outs(%2 : tensor<?x?xf32>) -> tensor<?x?xf32>
929+
return %3: tensor<?x?xf32>
930+
}

mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,10 +197,8 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
197197
// CHECK: #[[BOUND16_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 16)>
198198
// CHECK: #[[X2_MAP:.+]] = affine_map<(d0) -> (d0 * 2)>
199199
// CHECK: #[[INPUT_BOUND:.+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * -2 + s0 * 2 + s1 - 2, d1 * 2 + s1 - 2)>
200-
// CHECK: #[[BOUND16_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 16)>
201200
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
202201
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
203-
// CHECK: #[[BOUND4_MAP_2:.+]] = affine_map<(d0)[s0, s1] -> (-d0 + s1, -d0 + s0, 4)>
204202
// CHECK: #[[BOUND2_MAP_2:.+]] = affine_map<(d0, d1)[s0, s1] -> (-d0 + s0, -d1 + s1, 2)>
205203

206204
// CHECK: func @conv_tensors_dynamic
@@ -225,23 +223,19 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
225223
// CHECK-DAG: %[[FILTER_OC:.+]] = tensor.dim %[[FILTER]], %[[C3]] : tensor<?x?x?x?xf32>
226224
// CHECK-DAG: %[[INPUT_N:.+]] = tensor.dim %[[INPUT]], %[[C0]] : tensor<?x?x?x?xf32>
227225
// CHECK-DAG: %[[INPUT_C:.+]] = tensor.dim %[[INPUT]], %[[C3]] : tensor<?x?x?x?xf32>
228-
// CHECK-DAG: %[[FILL_H:.+]] = tensor.dim %[[FILL]], %[[C1]] : tensor<?x?x?x?xf32>
229-
// CHECK-DAG: %[[FILL_W:.+]] = tensor.dim %[[FILL]], %[[C2]] : tensor<?x?x?x?xf32>
230226

231227
// CHECK: scf.for %[[IV0:.+]] = %{{.+}} to %[[ELEM_N]] step %{{.+}} iter_args(%{{.+}} = %[[FILL]])
232228
// CHECK-NEXT: %[[SIZE_ELEM_N:.+]] = affine.min #[[BOUND8_MAP]](%[[IV0]])[%[[ELEM_N]]]
233229
// CHECK-NEXT: %[[SIZE_INPUT_N:.+]] = affine.min #[[BOUND8_MAP_2]](%[[IV0]])[%[[INPUT_N]], %[[ELEM_N]]]
234230
// CHECK-NEXT: scf.for %[[IV1:.+]] = %{{.+}} to %[[ELEM_OH]]
235231
// CHECK-NEXT: %[[SIZE_ELEM_OH:.+]] = affine.min #[[BOUND16_MAP]](%[[IV1]])[%[[ELEM_OH]]]
236232
// CHECK-NEXT: %[[OFFSET_OH:.+]] = affine.apply #[[X2_MAP]](%[[IV1]])
237-
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[FILL_H]], %[[FILTER_H]]]
238-
// CHECK-NEXT: %[[SIZE_ELEM_OH_2:.+]] = affine.min #[[BOUND16_MAP_2]](%[[IV1]])[%[[FILL_H]], %[[ELEM_OH]]]
233+
// CHECK-NEXT: %[[SIZE_INPUT_H:.+]] = affine.min #[[INPUT_BOUND]](%[[IV1]], %[[SIZE_ELEM_OH]])[%[[ELEM_OH]], %[[FILTER_H]]]
239234
// CHECK-NEXT: scf.for %[[IV2:.+]] = %{{.+}} to %[[ELEM_OW]]
240235
// CHECK-NEXT: %[[SIZE_ELEM_OW:.+]] = affine.min #[[BOUND4_MAP]](%[[IV2]])[%[[ELEM_OW]]]
241236
// CHECK-NEXT: %[[SIZE_ELEM_OC:.+]] = affine.min #[[BOUND2_MAP]](%[[IV2]])[%[[ELEM_OC]]]
242237
// CHECK-NEXT: %[[OFFSET_OW:.+]] = affine.apply #[[X2_MAP]](%[[IV2]])
243-
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[FILL_W]], %[[FILTER_W]]]
244-
// CHECK-NEXT: %[[SIZE_ELEM_OW_2:.+]] = affine.min #[[BOUND4_MAP_2]](%[[IV2]])[%[[FILL_W]], %[[ELEM_OW]]]
238+
// CHECK-NEXT: %[[SIZE_INPUT_W:.+]] = affine.min #[[INPUT_BOUND]](%[[IV2]], %[[SIZE_ELEM_OW]])[%[[ELEM_OW]], %[[FILTER_W]]]
245239
// CHECK-NEXT: %[[ST_INPUT:.+]] = tensor.extract_slice %[[INPUT]][%[[IV0]], %[[OFFSET_OH]], %[[OFFSET_OW]], 0]
246240
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_INPUT_H]], %[[SIZE_INPUT_W]], %[[INPUT_C]]]
247241
// CHECK-NEXT: scf.for %[[IV3:.+]] = %{{.+}} to %[[ELEM_OC]] step %{{.+}} iter_args(%[[ARG:[a-z0-9]+]]
@@ -253,7 +247,7 @@ func.func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?
253247
// CHECK-NEXT: %[[ST_FILTER:.+]] = tensor.extract_slice %[[FILTER]][0, 0, 0, %[[IV3]]]
254248
// CHECK-SAME: [%[[FILTER_H]], %[[FILTER_W]], %[[FILTER_IC]], %[[SIZE_ELEM_OC_2]]]
255249
// CHECK-NEXT: %[[ST_FILL:.+]] = tensor.extract_slice %[[FILL]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]]
256-
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH_2]], %[[SIZE_ELEM_OW_2]], %[[SIZE_ELEM_OC_2]]]
250+
// CHECK-SAME: [%[[SIZE_INPUT_N]], %[[SIZE_ELEM_OH]], %[[SIZE_ELEM_OW]], %[[SIZE_ELEM_OC_2]]]
257251
// CHECK-NEXT: %[[ST_CONV:.+]] = linalg.conv_2d_nhwc_hwcf
258252
// CHECK-SAME: ins(%[[ST_INPUT]], %[[ST_FILTER]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
259253
// CHECK-SAME: outs(%[[ST_FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>

mlir/test/Dialect/Linalg/transform-tile-reduction.mlir

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ transform.sequence failures(propagate) {
4343
// CHECK: arith.addf
4444
// CHECK: linalg.yield
4545
// CHECK: } -> tensor<?x?xf32>
46-
// CHECK: %[[D3:.*]] = tensor.dim %[[PR]], %[[C0]] : tensor<?x?xf32>
47-
// CHECK: %[[D4:.*]] = tensor.dim %[[PR]], %[[C1]] : tensor<?x?xf32>
48-
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
46+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D0]], %[[PS]]] [1, 1] : tensor<?x?xf32> into tensor<?x5xf32>
4947
// CHECK: scf.yield %[[INS]] : tensor<?x5xf32>
5048
// CHECK: }
5149
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "reduction"]} ins(%[[L]] : tensor<?x5xf32>) outs(%[[ARG1]] : tensor<?xf32>) {
@@ -76,14 +74,16 @@ transform.sequence failures(propagate) {
7674
by tile_sizes = [5, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op)
7775
}
7876

77+
// CHECK-DAG: #[[MAP0:.*]] = affine_map<(d0)[s0] -> (-d0 + s0, 5)>
78+
// CHECK-DAG: #[[MAP1:.*]] = affine_map<(d0, d1) -> (d0, d1)>
79+
// CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1) -> (d1)>
7980
// CHECK: func @reduction_tile_transpose
8081
// CHECK: tensor.empty(%{{.*}}) : tensor<5x?xf32>
8182
// CHECK: linalg.fill {{.*}} : tensor<5x?xf32>) -> tensor<5x?xf32>
8283
// CHECK: scf.for
83-
// CHECK: linalg.generic
84-
// CHECK: %[[D3:.*]] = tensor.dim %{{.*}}, %[[C0]] : tensor<?x?xf32>
85-
// CHECK: %[[D4:.*]] = tensor.dim %{{.*}}, %[[C1]] : tensor<?x?xf32>
86-
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[PR]] into %[[ARG3]][0, 0] [%[[D3]], %[[D4]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
84+
// CHECK: %[[EXT:.*]] = tensor.extract_slice %[[ARG3:.*]][0, 0] [%[[D0:.*]], %[[D1:.*]]] [1, 1] : tensor<5x?xf32> to tensor<?x?xf32>
85+
// CHECK: %[[R:.*]] = linalg.generic {indexing_maps = [#[[MAP1]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[L:.*]] : tensor<?x?xf32>) outs(%[[EXT]] : tensor<?x?xf32>)
86+
// CHECK: %[[INS:.*]] = tensor.insert_slice %[[R]] into %[[ARG3]][0, 0] [%[[D0]], %[[D1]]] [1, 1] : tensor<?x?xf32> into tensor<5x?xf32>
8787
// CHECK: scf.yield {{.*}} : tensor<5x?xf32>
8888
// CHECK: }
8989
// CHECK: linalg.generic

0 commit comments

Comments
 (0)