Skip to content

Commit 1407f5b

Browse files
Max191Max Dawkins
and
Max Dawkins
authored
[mlir] Canonicalize extract_slice(unpack) (#133777)
Canonicalizes a chain of `linalg.unpack -> tensor.extract_slice` into a `linalg.unpack` with reduced dest sizes. This will only happen when the unpack op's only user is a non rank-reducing slice with zero offset and unit strides. --------- Signed-off-by: Max Dawkins <[email protected]> Signed-off-by: Max Dawkins <[email protected]> Co-authored-by: Max Dawkins <[email protected]>
1 parent 0e3049c commit 1407f5b

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "mlir/IR/AffineExprVisitor.h"
3030
#include "mlir/IR/AffineMap.h"
3131
#include "mlir/IR/Attributes.h"
32+
#include "mlir/IR/Builders.h"
3233
#include "mlir/IR/BuiltinAttributes.h"
3334
#include "mlir/IR/BuiltinTypeInterfaces.h"
3435
#include "mlir/IR/Matchers.h"
@@ -5243,6 +5244,29 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
52435244
[&]() { unPackOp.setDpsInitOperand(0, newDest); });
52445245
return success();
52455246
}
5247+
/// extract_slice(unpack(x into y)) -> unpack(x into extract_slice(y))
5248+
if (unPackOp->hasOneUse()) {
5249+
auto extractSliceUser =
5250+
dyn_cast<tensor::ExtractSliceOp>(*unPackOp->getUsers().begin());
5251+
if (extractSliceUser &&
5252+
areAllConstantIntValue(extractSliceUser.getMixedOffsets(), 0) &&
5253+
areAllConstantIntValue(extractSliceUser.getMixedStrides(), 1) &&
5254+
extractSliceUser.getSourceType().getRank() ==
5255+
extractSliceUser.getResultType().getRank()) {
5256+
OpBuilder::InsertionGuard g(rewriter);
5257+
rewriter.setInsertionPoint(unPackOp);
5258+
auto newDest = rewriter.create<tensor::ExtractSliceOp>(
5259+
unPackOp->getLoc(), unPackOp.getDest(),
5260+
extractSliceUser.getMixedOffsets(), extractSliceUser.getMixedSizes(),
5261+
extractSliceUser.getMixedStrides());
5262+
rewriter.modifyOpInPlace(unPackOp, [&]() {
5263+
unPackOp.setDpsInitOperand(0, newDest);
5264+
unPackOp.getResult().setType(newDest.getType());
5265+
});
5266+
rewriter.replaceOp(extractSliceUser, unPackOp);
5267+
return success();
5268+
}
5269+
}
52465270

52475271
// Insert tensor.cast ops if static shape inference is available..
52485272
SmallVector<int64_t> srcShape, destShape;

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1772,3 +1772,78 @@ func.func @fold_cast_unpack_dynamic_tile_size(
17721772
into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
17731773
return %unpack : tensor<7x?xi32>
17741774
}
1775+
1776+
// -----
1777+
1778+
//===----------------------------------------------------------------------===//
1779+
// linalg.unpack + tensor.extract_slice
1780+
//===----------------------------------------------------------------------===//
1781+
1782+
func.func @fold_extract_slice_into_unpack(
1783+
%src : tensor<28x2x?x16x16xf32>, %dest : tensor<28x32x?xf32>, %size : index
1784+
) -> tensor<28x28x?xf32> {
1785+
%unpack = linalg.unpack %src
1786+
outer_dims_perm = [0, 1, 2]
1787+
inner_dims_pos = [1, 2]
1788+
inner_tiles = [16, 16]
1789+
into %dest : tensor<28x2x?x16x16xf32> -> tensor<28x32x?xf32>
1790+
%extracted_slice = tensor.extract_slice %unpack
1791+
[0, 0, 0] [28, 28, %size] [1, 1, 1] : tensor<28x32x?xf32> to tensor<28x28x?xf32>
1792+
return %extracted_slice : tensor<28x28x?xf32>
1793+
}
1794+
1795+
// CHECK-LABEL: func @fold_extract_slice_into_unpack
1796+
// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x?x16x16xf32>
1797+
// CHECK-SAME: %[[DEST:.+]]: tensor<28x32x?xf32>
1798+
// CHECK-SAME: %[[SIZE:.+]]: index
1799+
// CHECK: %[[DEST_SLICE:.+]] = tensor.extract_slice %[[DEST]]
1800+
// CHECK-SAME: [0, 0, 0] [28, 28, %[[SIZE]]] [1, 1, 1]
1801+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1802+
// CHECK-SAME: into %[[DEST_SLICE]]
1803+
// CHECK: return %[[UNPACK]]
1804+
1805+
// -----
1806+
1807+
func.func @no_fold_extract_slice_into_unpack_rank_reducing(
1808+
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
1809+
) -> tensor<28xf32> {
1810+
%unpack = linalg.unpack %src
1811+
outer_dims_perm = [0, 1]
1812+
inner_dims_pos = [1]
1813+
inner_tiles = [16]
1814+
into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
1815+
%extracted_slice = tensor.extract_slice %unpack
1816+
[0, 0] [1, 28] [1, 1] : tensor<28x32xf32> to tensor<28xf32>
1817+
return %extracted_slice : tensor<28xf32>
1818+
}
1819+
1820+
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_rank_reducing
1821+
// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1822+
// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1823+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1824+
// CHECK-SAME: into %[[DEST]]
1825+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1826+
// CHECK: return %[[SLICE]]
1827+
1828+
// -----
1829+
1830+
func.func @no_fold_extract_slice_into_unpack_non_zero_offset(
1831+
%src : tensor<28x2x16xf32>, %dest : tensor<28x32xf32>
1832+
) -> tensor<28x28xf32> {
1833+
%unpack = linalg.unpack %src
1834+
outer_dims_perm = [0, 1]
1835+
inner_dims_pos = [1]
1836+
inner_tiles = [16]
1837+
into %dest : tensor<28x2x16xf32> -> tensor<28x32xf32>
1838+
%extracted_slice = tensor.extract_slice %unpack
1839+
[0, 1] [28, 28] [1, 1] : tensor<28x32xf32> to tensor<28x28xf32>
1840+
return %extracted_slice : tensor<28x28xf32>
1841+
}
1842+
1843+
// CHECK-LABEL: func @no_fold_extract_slice_into_unpack_non_zero_offset
1844+
// CHECK-SAME: %[[SRC:.+]]: tensor<28x2x16xf32>
1845+
// CHECK-SAME: %[[DEST:.+]]: tensor<28x32xf32>
1846+
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[SRC]]
1847+
// CHECK-SAME: into %[[DEST]]
1848+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]]
1849+
// CHECK: return %[[SLICE]]

0 commit comments

Comments
 (0)