Skip to content

Commit 14e7846

Browse files
[mlir][Tensor] Fold destination-style ops into tensor.unpack operation. (#71468)
The destination operand of the `tensor.unpack` operation is only needed to carry shape information. So if the producer of the destination operand implements the `DestinationStyleOpInterface`, then fold it into the `tensor.unpack` operation by replacing the destination operand with the destination for the source.
1 parent 11c1827 commit 14e7846

File tree

2 files changed

+37
-10
lines changed

2 files changed

+37
-10
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3922,18 +3922,29 @@ UnPackOp UnPackOp::createTransposedClone(OpBuilder &b, Location loc,
39223922
metadata.outerDimsPerm);
39233923
}
39243924

3925-
/// pack(unpack(x)) -> x
39263925
LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
39273926
PatternRewriter &rewriter) {
3928-
PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>();
3929-
if (!packOp || packOp.getDestType() != unPackOp.getSourceType())
3930-
return failure();
3931-
if (packOp.getPaddingValue() ||
3932-
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
3933-
!haveSameTiles(packOp, unPackOp))
3934-
return failure();
3935-
rewriter.replaceOp(unPackOp, packOp.getSource());
3936-
return success();
3927+
/// pack(unpack(x)) -> x
3928+
if (PackOp packOp = unPackOp.getSource().getDefiningOp<tensor::PackOp>()) {
3929+
if (packOp.getDestType() != unPackOp.getSourceType())
3930+
return failure();
3931+
if (packOp.getPaddingValue() ||
3932+
!hasSameInnerOuterAttribute(packOp, unPackOp) ||
3933+
!haveSameTiles(packOp, unPackOp))
3934+
return failure();
3935+
rewriter.replaceOp(unPackOp, packOp.getSource());
3936+
return success();
3937+
}
3938+
/// unpack(destinationStyleOp(x)) -> unpack(x)
3939+
if (auto dstStyleOp =
3940+
unPackOp.getDest().getDefiningOp<DestinationStyleOpInterface>()) {
3941+
auto destValue = unPackOp.getDest().cast<OpResult>();
3942+
Value newDest = dstStyleOp.getDpsInits()[destValue.getResultNumber()];
3943+
rewriter.updateRootInPlace(
3944+
unPackOp, [&]() { unPackOp.setDpsInitOperand(0, newDest); });
3945+
return success();
3946+
}
3947+
return failure();
39373948
}
39383949

39393950
bool UnPackOp::isLikeUnPad() {

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1861,3 +1861,19 @@ func.func @invalid_empty_negative_size() -> (tensor<4x5x?xf32>) {
18611861
%1 = tensor.empty(%0) : tensor<4x5x?xf32>
18621862
return %1 : tensor<4x5x?xf32>
18631863
}
1864+
1865+
// -----
1866+
1867+
// Fold DstStyleOp -> tensor.unpack operations.
1868+
func.func @fold_dst_style_ops_into_unpack(%arg0 : tensor<?x?x16x64xf32>, %init : tensor<?x?xf32>) -> tensor<?x?xf32> {
1869+
%cst = arith.constant 0.0 : f32
1870+
%fill = linalg.fill ins(%cst : f32) outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
1871+
%unpack = tensor.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [16, 64] into %fill : tensor<?x?x16x64xf32> -> tensor<?x?xf32>
1872+
return %unpack : tensor<?x?xf32>
1873+
}
1874+
// CHECK-LABEL: func @fold_dst_style_ops_into_unpack
1875+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x16x64xf32>
1876+
// CHECK-SAME: %[[INIT:.+]]: tensor<?x?xf32>
1877+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[ARG0]]
1878+
// CHECK-SAME: into %[[INIT]]
1879+
// CHECK: return %[[UNPACK]]

0 commit comments

Comments
 (0)