Skip to content

Commit 70eb0e3

Browse files
authored
[mlir][tensor] Fix tensor.pad to remove newly static values (#79938)
The canonicalization incrementally converts foldable dynamic hi/lo padding to static hi/lo values. During this canonicalization the static-fied valued should be removed from the dynamic values.
1 parent 198652a commit 70eb0e3

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3158,19 +3158,23 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
31583158

31593159
// Extract the static info from the high and low operands.
31603160
SmallVector<int64_t> constOperandsLow;
3161+
SmallVector<Value> newLows;
31613162
for (auto operand : padTensorOp.getLow()) {
31623163
APSInt intOp;
31633164
if (!matchPattern(operand, m_ConstantInt(&intOp))) {
31643165
constOperandsLow.push_back(ShapedType::kDynamic);
3166+
newLows.push_back(operand);
31653167
continue;
31663168
}
31673169
constOperandsLow.push_back(intOp.getExtValue());
31683170
}
31693171
SmallVector<int64_t> constOperandsHigh;
3172+
SmallVector<Value> newHighs;
31703173
for (auto operand : padTensorOp.getHigh()) {
31713174
APSInt intOp;
31723175
if (!matchPattern(operand, m_ConstantInt(&intOp))) {
31733176
constOperandsHigh.push_back(ShapedType::kDynamic);
3177+
newHighs.push_back(operand);
31743178
continue;
31753179
}
31763180
constOperandsHigh.push_back(intOp.getExtValue());
@@ -3222,7 +3226,7 @@ struct FoldStaticPadding : public OpRewritePattern<PadOp> {
32223226
newOutDims, padTensorOp.getType().getElementType());
32233227
auto newOp = rewriter.create<PadOp>(
32243228
padTensorOp->getLoc(), newResultType, input, staticLow, staticHigh,
3225-
padTensorOp.getLow(), padTensorOp.getHigh(), padTensorOp.getNofold(),
3229+
newLows, newHighs, padTensorOp.getNofold(),
32263230
getPrunedAttributeList(padTensorOp, PadOp::getAttributeNames()));
32273231

32283232
IRMapping mapper;

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ func.func @pad_same_static_shape(%arg0: tensor<5x6xf32>, %a: index)
13611361
// CHECK-LABEL: func @pad_fold_static(
13621362
// CHECK-SAME: %[[INPUT:.*]]: tensor<?x64x?x?xf32>) -> tensor<?x?x?x?xf32> {
13631363
// CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
1364-
// CHECK: %[[PADDING:.*]] = arith.constant 4 : index
1364+
// CHECK-NOT: arith.constant 4 : index
13651365
// CHECK: %[[PADDED:.*]] = tensor.pad %[[INPUT]]
13661366
// CHECK-SAME: low[0, 4, 1, 1] high[0, 4, 1, 1] {
13671367
// CHECK: ^bb0(%[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index):

0 commit comments

Comments
 (0)