[mlir] [tensor] Add patterns to remove whole slicing of tensors #107046
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tensor

Author: Menooker (Menooker)

Changes: Eliminate the redundant tensor.extract_slice and tensor.insert_slice when the slice size is proved to be the same as the source tensor. Dynamic shapes are also supported. Examples of the extract/insert to be removed:
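As an illustrative sketch (modeled on the test cases in the diff below, not the PR's original examples), these are the kinds of ops the patterns remove:

// Whole-tensor extract_slice: zero offsets, unit strides, and a result shape equal
// to the source shape, so %1 can be replaced directly by %0.
%1 = tensor.extract_slice %0[0, 0] [16, 32] [1, 1] : tensor<16x32xf32> to tensor<16x32xf32>

// Whole-tensor insert_slice: %src fully overwrites %dest, so the result can be
// replaced directly by %src.
%2 = tensor.insert_slice %src into %dest[0, 0] [16, 32] [1, 1] : tensor<16x32xf32> into tensor<16x32xf32>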
Patch is 20.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107046.diff
5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..4ae782661681f1 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,12 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
const ControlFoldFn &controlFn);
+/// Appends patterns that eliminate whole-slice extract_slice and insert_slice
+/// ops, i.e. slices whose offsets are all zero, whose strides are all one, and
+/// whose size equals the size of the sliced tensor.
+void populateEliminateWholeSlicingPatterns(
+ RewritePatternSet &patterns);
+
//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index ce32dea09bb0b5..d5bbedd13e7acc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
RewriteAsConstant.cpp
SwapExtractSliceWithProducerPatterns.cpp
SubsetInsertionOpInterfaceImpl.cpp
+ EliminateWholeSlicePatterns.cpp
ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
new file mode 100644
index 00000000000000..52ca6a9e6f65f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
@@ -0,0 +1,98 @@
+//===- EliminateWholeSlicePatterns.cpp - Patterns to remove whole slices --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+bool checkEliminateOK(PatternRewriter &rewriter,
+ OffsetSizeAndStrideOpInterface sliceOp,
+ mlir::TypedValue<mlir::RankedTensorType> smallerTensor,
+ mlir::TypedValue<mlir::RankedTensorType> largerTensor) {
+ auto srcType = largerTensor.getType();
+ auto resultType = smallerTensor.getType();
+ if (!isSameTypeWithoutEncoding(srcType, resultType)) {
+    // Fast failure path when the source and result types do not match.
+ return false;
+ }
+ // both types are ensured to have the same rank
+ for (int64_t i = 0; i < resultType.getRank(); ++i) {
+    // Check the slice offsets; they should be all zero.
+ if (sliceOp.isDynamicOffset(i) || sliceOp.getStaticOffset(i) != 0)
+ return false;
+    // Check the slice strides; they should be all one.
+ if (sliceOp.isDynamicStride(i) || sliceOp.getStaticStride(i) != 1)
+ return false;
+ }
+  // Check whether the dynamic shapes match.
+ if (resultType.getNumDynamicDims() != 0) {
+ for (int64_t i = 0; i < resultType.getRank(); ++i) {
+ if (resultType.isDynamicDim(i)) {
+ auto largeDim =
+ getMixedSize(rewriter, sliceOp.getLoc(), largerTensor, i);
+ auto smallDim = sliceOp.getDynamicSize(i);
+ if (largeDim.dyn_cast<Value>() != smallDim) {
+ return false;
+ }
+ }
+ }
+ }
+  // If the tensor has a static shape, we already checked that the shapes match
+  // via isSameTypeWithoutEncoding.
+ return true;
+}
+
+struct EliminateWholeSliceExtractSliceOp
+ : public OpRewritePattern<ExtractSliceOp> {
+ EliminateWholeSliceExtractSliceOp(MLIRContext *ctx)
+ : OpRewritePattern<ExtractSliceOp>(ctx) {}
+
+ LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getResult(),
+ sliceOp.getSource())) {
+ return failure();
+ }
+    // All checks passed; rewrite the IR.
+ rewriter.replaceAllUsesWith(sliceOp, sliceOp.getSource());
+ rewriter.eraseOp(sliceOp);
+ return success();
+ }
+};
+
+struct EliminateWholeSliceInsertSliceOp
+ : public OpRewritePattern<InsertSliceOp> {
+ EliminateWholeSliceInsertSliceOp(MLIRContext *ctx)
+ : OpRewritePattern<InsertSliceOp>(ctx) {}
+
+ LogicalResult matchAndRewrite(InsertSliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getSource(),
+ sliceOp.getDest())) {
+ return failure();
+ }
+    // All checks passed; rewrite the IR.
+ rewriter.replaceAllUsesWith(sliceOp, sliceOp.getSource());
+ rewriter.eraseOp(sliceOp);
+ return success();
+ }
+};
+
+} // namespace
+
+void mlir::tensor::populateEliminateWholeSlicingPatterns(
+ RewritePatternSet &patterns) {
+ patterns
+ .add<EliminateWholeSliceExtractSliceOp, EliminateWholeSliceInsertSliceOp>(
+ patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
new file mode 100644
index 00000000000000..077d36c26d4816
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-eliminate-whole-slicing-patterns -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+//////////////////////////////
+// here start the tests for insert_slice
+//////////////////////////////
+
+func.func @elim_dyn_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_dyn_insert
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+// CHECK: return %[[INSERT]]
+
+func.func @elim_static_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_static_insert
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+// CHECK: return %[[INSERT]]
+
+func.func @fail_dyn_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to unmatched insert shape
+// CHECK-LABEL: func.func @fail_dyn_insert_shape
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+func.func @fail_static_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %3 = tensor.empty() : tensor<14x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<14x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to unmatched insert shape
+// CHECK-LABEL: func.func @fail_static_insert_shape
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: tensor.empty()
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+func.func @fail_dyn_insert_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, %arg2] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to dynamic stride
+// CHECK-LABEL: func.func @fail_dyn_insert_stride
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+// fail to optimize due to non-zero offset
+func.func @fail_static_insert_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %c0f = arith.constant 0.0 : bf16
+ %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+ %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 1] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+ %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+ return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_insert_offset
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
+// CHECK: %[[FILL:.+]] = linalg.fill
+// CHECK: tensor.insert_slice
+// CHECK: %[[INSERT:.+]] = tensor.insert_slice
+// CHECK: return %[[INSERT]]
+
+//////////////////////////////
+// here start the tests for extract_slice
+//////////////////////////////
+func.func @elim_dyn_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_dyn_extract
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [%[[OFFSET1]], 32, 32, 32]
+// CHECK: return %[[EXTRACT]]
+
+
+func.func @elim_static_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_static_extract
+// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [15, 32, 32, 32]
+// CHECK: return %[[EXTRACT]]
+
+// fail to optimize due to unmatched shape
+func.func @fail_dyn_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_dyn_extract_shape
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// fail to optimize due to unmatched shape
+func.func @fail_static_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<14x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<14x32x32x32xbf16>
+ return %extracted_slice2 : tensor<14x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_shape
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// fail to optimize due to stride
+func.func @fail_extract_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 3] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+ return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_stride
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+// fail to optimize due to non-zero offset
+func.func @fail_static_extract_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_offset
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+
+
+//////////////////////////////
+// here start the tests for expanding/reducing dims
+//////////////////////////////
+func.func @fail_extract_reduce(%arg0: tensor<1x32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+ %extracted_slice = tensor.extract_slice %arg0[0, %arg2, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x32x32x32x32xbf16> to tensor<1x15x32x32x32xbf16>
+ %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_reduce
+// CHECK: tensor.extract_slice
+// CHECK: tensor.extract_slice
+// CHECK: return
+
+func.func @fail_insert_expand(%arg0: tensor<1x15x32x32x32xbf16>, %arg1: tensor<1x15x32x32x32xbf16>, %arg2: index) -> tensor<1x15x32x32x32xbf16> {
+ %extracted_slice = tensor.empty(): tensor<15x32x32x32xbf16>
+ %extracted_slice2 = tensor.insert_slice %extracted_slice into %arg0[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<1x15x32x32x32xbf16>
+ return %extracted_slice2 : tensor<1x15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_insert_expand
+// CHECK: tensor.empty
+// CHECK: tensor.insert_slice
+// CHECK: return
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 34de600132f5de..a4a91ffd3b7660 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -98,6 +98,12 @@ struct TestTensorTransforms
*this, "test-tracking-listener",
llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
llvm::cl::init(false)};
+
+ Option<bool> testEliminateWholeSlicingPatterns{
+ *this, "test-eliminate-whole-slicing-patterns",
+ llvm::cl::desc("Test patterns to eliminate whole-slicing extract_slice "
+ "and insert_slice"),
+ llvm::cl::init(false)};
};
} // namespace
@@ -154,6 +160,12 @@ static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
(void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}
+static void applyEliminateWholeSlicingPatterns(Operation *rootOp) {
+ RewritePatternSet patterns(rootOp->getContext());
+ tensor::populateEliminateWholeSlicingPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -406,6 +418,8 @@ void TestTensorTransforms::runOnOperation() {
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
return si...
[truncated]
✅ With the latest revision this PR passed the C/C++ code formatter.
The long story of why we need this pattern: in our downstream project, we have IR for tiling like

%a = extract_slice %out[...] : tensor<...>
%result = scf.for(...) iter_args(%a1 = %a) {
  %slice = extract_slice %a1[...]: tensor<...>
  %0 = elementwise_compute(%slice) outs(%a) : tensor<...>
  %inserted = insert_slice %0 into %a1[...] : tensor<...>
  scf.yield %inserted
}
%final = insert_slice %result into %out[...]

If the tile size for the loop happens to be the same as the size of %a, we end up with IR like

%a = extract_slice %out[...] : tensor<...> // slice a sub-tensor from out
%slice = extract_slice %a[...]: tensor<...> // whole-slicing over %a. %slice has the same size as %a
%0 = elementwise_compute(%slice) outs(%a) : tensor<...>
%inserted = insert_slice %0 into %a[...] : tensor<...> // overwriting %a as a whole. %0 has the same size as %a
%final = insert_slice %inserted into %out[...]

And once the whole-size extract_slice is folded away, we have

%a = extract_slice %out[...] : tensor<...> // slice a sub-tensor from out
%0 = elementwise_compute(%a) outs(%a) : tensor<...>
%inserted = insert_slice %0 into %a[...] : tensor<...> // overwriting %a as a whole. %0 has the same size as %a
%final = insert_slice %inserted into %out[...]

This will cause in-place bufferization to fail for the elementwise computation. Hence we introduce this pattern set to remove the redundant whole-slice extract_slice and insert_slice. The patterns have some overlap with the existing extract_slice/insert_slice canonicalizations.
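To make the intent concrete, here is a rough sketch (extrapolated from the example above, not taken verbatim from the patch) of what the IR reduces to once the whole-slice insert_slice is eliminated:

%a = extract_slice %out[...] : tensor<...> // slice a sub-tensor from out
%0 = elementwise_compute(%a) outs(%a) : tensor<...>
%final = insert_slice %0 into %out[...]

With %inserted replaced by %0, the extra aliasing between %a and %inserted that blocked in-place bufferization is gone.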
Why isn't this just part of fold/canonicalization?
I don’t implement this as part of fold/canonicalize because bufferization is sensitive to buffer positions (as noted in the comments of populateMergeConsecutiveInsertExtractSlice; we are doing something similar in this PR). There may be cases where this pattern alters the bufferization result in a way the user does not expect, so I provide it as an optional optimization pattern set.
@matthias-springer can you comment here? To me this is a coupling problem that would indicate the bufferization needs a "prepare" phase independent of canonicalization, but I am not sure I understand why we need to keep useless redundant information in the IR and can't canonicalize it away just because "maybe somehow that affects this one pass".
When bufferizing a tensor operation, the buffer in which the result is materialized is determined by the tensor operands of the operation. So if some other transformation/canonicalization/folding changes the IR, it may bufferize in a different way. (Still in a functionally correct, but maybe less efficient way.) Ideally, we would have a more powerful bufferization analysis that can reason about tensor subsets/aliasing and generate the same efficient IR given two functionally equivalent input programs. (Also see this discussion here.) I see 3 ways to address this issue (ordered from best to worst):
Hi, reviewers! Any more comments regarding the patch itself?