Commit b006902
[mlir] Fold trivial subtensor / subtensor_insert ops.
Static subtensor / subtensor_insert ops of the same size as the source / destination tensor, rooted at offsets [0..0] with strides [1..1], are folded away. Differential revision: https://reviews.llvm.org/D96991
1 parent b7e05c8 commit b006902
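
As a quick illustration of the fold this commit enables (a minimal sketch; the 8x16 shape and the value names %t / %0 are assumed for illustration, the commit's own tests below use 4x6x16x32 tensors):

  // Offsets [0, 0], sizes equal to the full extents [8, 16] and strides [1, 1]
  // make this subtensor an identity; folding replaces all uses of %0 with %t,
  // after which the op can be erased.
  %0 = subtensor %t[0, 0] [8, 16] [1, 1] : tensor<8x16xf32> to tensor<8x16xf32>

The same applies to a subtensor_insert that overwrites its entire destination: it folds to the inserted source.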

4 files changed, +83 −0 lines

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 6 additions & 0 deletions

@@ -364,6 +364,12 @@ bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs,
 /// comparison predicates.
 bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs,
                        const APFloat &rhs);
+
+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatches that come from the fact that
+/// there is no IndexAttr and that IndexType has no bitwidth.
+bool isEqualConstantIntOrValue(OpFoldResult ofr1, OpFoldResult ofr2);
 } // end namespace mlir

 #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H
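
The declaration above matters for the folder because offsets, sizes and strides can be supplied either as static integer attributes or as SSA values. A sketch of the consequence at the IR level (shape and value names assumed, not taken from the commit): offsets given as constant index values still compare equal to the static attribute 0, so the op below folds to %t as well.

  %c0 = constant 0 : index
  // Dynamic offsets defined by constant 0 are treated like static offsets [0, 0].
  %0 = subtensor %t[%c0, %c0] [8, 16] [1, 1] : tensor<8x16xf32> to tensor<8x16xf32>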

mlir/include/mlir/Dialect/StandardOps/IR/Ops.td

Lines changed: 2 additions & 0 deletions

@@ -2928,6 +2928,7 @@ def SubTensorOp : BaseOpWithOffsetSizesAndStrides<
   }];

   let hasCanonicalizer = 1;
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//
@@ -3026,6 +3027,7 @@ def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<
     /// and `strides` operands.
     static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 2; }
   }];
+  let hasFolder = 1;
 }

 //===----------------------------------------------------------------------===//

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 56 additions & 0 deletions

@@ -59,6 +59,27 @@ static void dispatchIndexOpFoldResults(ArrayRef<OpFoldResult> ofrs,
     dispatchIndexOpFoldResult(ofr, dynamicVec, staticVec, sentinel);
 }

+/// Return true if ofr1 and ofr2 are the same integer constant attribute values
+/// or the same SSA value.
+/// Ignore integer bitwidth and type mismatches that come from the fact that
+/// there is no IndexAttr and that IndexType has no bitwidth.
+bool mlir::isEqualConstantIntOrValue(OpFoldResult op1, OpFoldResult op2) {
+  auto getConstantIntValue = [](OpFoldResult ofr) -> llvm::Optional<int64_t> {
+    Attribute attr = ofr.dyn_cast<Attribute>();
+    // Note: isa+cast-like pattern allows writing the condition below as 1 line.
+    if (!attr && ofr.get<Value>().getDefiningOp<ConstantOp>())
+      attr = ofr.get<Value>().getDefiningOp<ConstantOp>().getValue();
+    if (auto intAttr = attr.dyn_cast_or_null<IntegerAttr>())
+      return intAttr.getValue().getSExtValue();
+    return llvm::None;
+  };
+  auto cst1 = getConstantIntValue(op1), cst2 = getConstantIntValue(op2);
+  if (cst1 && cst2 && *cst1 == *cst2)
+    return true;
+  auto v1 = op1.dyn_cast<Value>(), v2 = op2.dyn_cast<Value>();
+  return v1 && v2 && v1 == v2;
+}
+
 //===----------------------------------------------------------------------===//
 // StandardOpsDialect Interfaces
 //===----------------------------------------------------------------------===//
@@ -3557,6 +3578,34 @@ void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                               context);
 }

+// Return success if `op` has identity offsets/strides and full-extent sizes.
+static LogicalResult
+foldIdentityOffsetSizeAndStrideOpInterface(OffsetSizeAndStrideOpInterface op,
+                                           ShapedType shapedType) {
+  OpBuilder b(op.getContext());
+  for (OpFoldResult ofr : op.getMixedOffsets())
+    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(0)))
+      return failure();
+  // Rank-reducing noops only need to inspect the leading dimensions: llvm::zip
+  // is appropriate.
+  auto shape = shapedType.getShape();
+  for (auto it : llvm::zip(op.getMixedSizes(), shape))
+    if (!isEqualConstantIntOrValue(std::get<0>(it),
+                                   b.getIndexAttr(std::get<1>(it))))
+      return failure();
+  for (OpFoldResult ofr : op.getMixedStrides())
+    if (!isEqualConstantIntOrValue(ofr, b.getIndexAttr(1)))
+      return failure();
+  return success();
+}
+
+OpFoldResult SubTensorOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->source();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // SubTensorInsertOp
 //===----------------------------------------------------------------------===//
@@ -3597,6 +3646,13 @@ void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result,
   build(b, result, source, dest, offsetValues, sizeValues, strideValues);
 }

+OpFoldResult SubTensorInsertOp::fold(ArrayRef<Attribute>) {
+  if (getSourceType() == getType() &&
+      succeeded(foldIdentityOffsetSizeAndStrideOpInterface(*this, getType())))
+    return this->source();
+  return OpFoldResult();
+}
+
 //===----------------------------------------------------------------------===//
 // TensorLoadOp
 //===----------------------------------------------------------------------===//
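
For contrast, a sketch (assumed tensor<8x16xf32> operand %t, not part of the commit) of cases the folders above deliberately leave alone, since the result type must equal the source type before the identity check even runs:

  // Rank-reducing: offsets and strides are identities, but the result type
  // differs from the source type, so SubTensorOp::fold bails out.
  %0 = subtensor %t[0, 0] [1, 16] [1, 1] : tensor<8x16xf32> to tensor<16xf32>
  // Partial extraction: the result type differs from the source type, and the
  // size 4 would fail the full-extent check in any case.
  %1 = subtensor %t[0, 0] [4, 16] [1, 1] : tensor<8x16xf32> to tensor<4x16xf32>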

mlir/test/Dialect/Standard/canonicalize.mlir

Lines changed: 19 additions & 0 deletions

@@ -157,3 +157,22 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
     memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
   return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
 }
+
+// CHECK-LABEL: func @trivial_subtensor
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %0 = subtensor %arg0[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<4x6x16x32xi8>
+  return %0 : tensor<4x6x16x32xi8>
+}
+
+// CHECK-LABEL: func @trivial_subtensor_insert
+// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
+// CHECK-NOT: subtensor
+// CHECK: return %[[ARG0]] : tensor<4x6x16x32xi8>
+func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
+  %0 = subtensor_insert %arg0 into %arg1[0, 0, 0, 0] [4, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> into tensor<4x6x16x32xi8>
+  return %0 : tensor<4x6x16x32xi8>
+}
+
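
A hypothetical negative test (not part of this commit) could complement the two positive checks above by asserting that a subtensor covering only part of its source survives canonicalization:

  // CHECK-LABEL: func @non_trivial_subtensor
  // CHECK: subtensor
  func @non_trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<2x6x16x32xi8> {
    %0 = subtensor %arg0[0, 0, 0, 0] [2, 6, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<2x6x16x32xi8>
    return %0 : tensor<2x6x16x32xi8>
  }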
