[mlir] [tensor] Add patterns to remove whole slicing of tensors #107046

Open
wants to merge 3 commits into main

Conversation


@Menooker Menooker commented Sep 3, 2024

Eliminate redundant tensor.extract_slice and tensor.insert_slice ops when the slice size is proven to be the same as that of the source tensor. Dynamic shapes are also supported.

Examples of the extract/insert to be removed:

%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>

@llvmbot
Member

llvmbot commented Sep 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: Menooker (Menooker)

Changes

Eliminate redundant tensor.extract_slice and tensor.insert_slice ops when the slice size is proven to be the same as that of the source tensor. Dynamic shapes are also supported.

Examples of the extract/insert to be removed:

%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>

Patch is 20.23 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/107046.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h (+6)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt (+1)
  • (added) mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp (+98)
  • (added) mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir (+194)
  • (modified) mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp (+14)
diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
index ae695e0326ca1a..4ae782661681f1 100644
--- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,12 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
 void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                        const ControlFoldFn &controlFn);
 
+/// Appends patterns for eliminating whole-slice extract_slice and insert_slice
+/// ops, i.e., slices whose sizes match the source tensor, whose offsets are
+/// all zero, and whose strides are all one.
+void populateEliminateWholeSlicingPatterns(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Transform helpers
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
index ce32dea09bb0b5..d5bbedd13e7acc 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
   RewriteAsConstant.cpp
   SwapExtractSliceWithProducerPatterns.cpp
   SubsetInsertionOpInterfaceImpl.cpp
+  EliminateWholeSlicePatterns.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
diff --git a/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
new file mode 100644
index 00000000000000..52ca6a9e6f65f0
--- /dev/null
+++ b/mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
@@ -0,0 +1,98 @@
+//===- EliminateWholeSlicePatterns.cpp - Patterns to remove whole slices --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/PatternMatch.h"
+
+using namespace mlir;
+using namespace mlir::tensor;
+
+namespace {
+
+bool checkEliminateOK(PatternRewriter &rewriter,
+                      OffsetSizeAndStrideOpInterface sliceOp,
+                      mlir::TypedValue<mlir::RankedTensorType> smallerTensor,
+                      mlir::TypedValue<mlir::RankedTensorType> largerTensor) {
+  auto srcType = largerTensor.getType();
+  auto resultType = smallerTensor.getType();
+  if (!isSameTypeWithoutEncoding(srcType, resultType)) {
+    // fast failure path when in and out types do not match
+    return false;
+  }
+  // both types are ensured to have the same rank
+  for (int64_t i = 0; i < resultType.getRank(); ++i) {
+    // check the slice op's offsets, should be all-zero
+    if (sliceOp.isDynamicOffset(i) || sliceOp.getStaticOffset(i) != 0)
+      return false;
+    // check the slice op's strides, should be all-one
+    if (sliceOp.isDynamicStride(i) || sliceOp.getStaticStride(i) != 1)
+      return false;
+  }
+  // check if the dynamic shapes match
+  if (resultType.getNumDynamicDims() != 0) {
+    for (int64_t i = 0; i < resultType.getRank(); ++i) {
+      if (resultType.isDynamicDim(i)) {
+        auto largeDim =
+            getMixedSize(rewriter, sliceOp.getLoc(), largerTensor, i);
+        auto smallDim = sliceOp.getDynamicSize(i);
+        if (largeDim.dyn_cast<Value>() != smallDim) {
+          return false;
+        }
+      }
+    }
+  }
+  // if the tensor has a static shape, we already checked that the shapes
+  // match via isSameTypeWithoutEncoding
+  return true;
+}
+
+struct EliminateWholeSliceExtractSliceOp
+    : public OpRewritePattern<ExtractSliceOp> {
+  EliminateWholeSliceExtractSliceOp(MLIRContext *ctx)
+      : OpRewritePattern<ExtractSliceOp>(ctx) {}
+
+  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getResult(),
+                          sliceOp.getSource())) {
+      return failure();
+    }
+    // all checks are done. Rewrite the IR
+    rewriter.replaceAllUsesWith(sliceOp, sliceOp.getSource());
+    rewriter.eraseOp(sliceOp);
+    return success();
+  }
+};
+
+struct EliminateWholeSliceInsertSliceOp
+    : public OpRewritePattern<InsertSliceOp> {
+  EliminateWholeSliceInsertSliceOp(MLIRContext *ctx)
+      : OpRewritePattern<InsertSliceOp>(ctx) {}
+
+  LogicalResult matchAndRewrite(InsertSliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getSource(),
+                          sliceOp.getDest())) {
+      return failure();
+    }
+    // all checks are done. Rewrite the IR
+    rewriter.replaceAllUsesWith(sliceOp, sliceOp.getSource());
+    rewriter.eraseOp(sliceOp);
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::tensor::populateEliminateWholeSlicingPatterns(
+    RewritePatternSet &patterns) {
+  patterns
+      .add<EliminateWholeSliceExtractSliceOp, EliminateWholeSliceInsertSliceOp>(
+          patterns.getContext());
+}
diff --git a/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
new file mode 100644
index 00000000000000..077d36c26d4816
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
@@ -0,0 +1,194 @@
+// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-eliminate-whole-slicing-patterns -canonicalize -mlir-print-local-scope %s | FileCheck %s
+
+//////////////////////////////
+// here start the tests for insert_slice
+//////////////////////////////
+
+func.func @elim_dyn_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_dyn_insert
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+//       CHECK:   return %[[INSERT]]
+
+func.func @elim_static_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+
+// CHECK-LABEL: func.func @elim_static_insert
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_dyn_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to unmatched insert shape
+// CHECK-LABEL: func.func @fail_dyn_insert_shape
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_static_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %3 = tensor.empty() : tensor<14x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<14x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to unmatched insert shape
+// CHECK-LABEL: func.func @fail_static_insert_shape
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   tensor.empty()
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+func.func @fail_dyn_insert_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, %arg2] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// fail to optimize due to dynamic stride
+// CHECK-LABEL: func.func @fail_dyn_insert_stride
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+// fail to optimize due to non-zero offset
+func.func @fail_static_insert_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %c0f = arith.constant 0.0 : bf16
+  %3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
+  %inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 1] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
+  %inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
+  return %inserted_slice_3 : tensor<32x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_insert_offset
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice
+//       CHECK:   %[[FILL:.+]] = linalg.fill
+//       CHECK:   tensor.insert_slice
+//       CHECK:   %[[INSERT:.+]] = tensor.insert_slice
+//       CHECK:   return %[[INSERT]]
+
+//////////////////////////////
+// here start the tests for extract_slice
+//////////////////////////////
+func.func @elim_dyn_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_dyn_extract
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [%[[OFFSET1]], 32, 32, 32]
+//       CHECK:   return %[[EXTRACT]]
+
+
+func.func @elim_static_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @elim_static_extract
+//  CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
+//       CHECK:   %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [15, 32, 32, 32]
+//       CHECK:   return %[[EXTRACT]]
+
+// fail to optimize due to unmatched shape
+func.func @fail_dyn_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_dyn_extract_shape
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// fail to optimize due to unmatched shape
+func.func @fail_static_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<14x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<14x32x32x32xbf16>
+  return %extracted_slice2 : tensor<14x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_shape
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// fail to optimize due to stride
+func.func @fail_extract_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 3] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
+  return %extracted_slice2 : tensor<?x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_stride
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+// fail to optimize due to non-zero offset
+func.func @fail_static_extract_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_static_extract_offset
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+
+
+//////////////////////////////
+// here start the tests for expanding/reducing dims
+//////////////////////////////
+func.func @fail_extract_reduce(%arg0: tensor<1x32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
+  %extracted_slice = tensor.extract_slice %arg0[0, %arg2, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x32x32x32x32xbf16> to tensor<1x15x32x32x32xbf16>
+  %extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_extract_reduce
+//       CHECK:   tensor.extract_slice
+//       CHECK:   tensor.extract_slice
+//       CHECK:   return
+
+func.func @fail_insert_expand(%arg0: tensor<1x15x32x32x32xbf16>, %arg1: tensor<1x15x32x32x32xbf16>, %arg2: index) -> tensor<1x15x32x32x32xbf16> {
+  %extracted_slice = tensor.empty(): tensor<15x32x32x32xbf16>
+  %extracted_slice2 = tensor.insert_slice %extracted_slice into %arg0[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<1x15x32x32x32xbf16>
+  return %extracted_slice2 : tensor<1x15x32x32x32xbf16>
+}
+// CHECK-LABEL: func.func @fail_insert_expand
+//       CHECK:   tensor.empty
+//       CHECK:   tensor.insert_slice
+//       CHECK:   return
\ No newline at end of file
diff --git a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
index 34de600132f5de..a4a91ffd3b7660 100644
--- a/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -98,6 +98,12 @@ struct TestTensorTransforms
       *this, "test-tracking-listener",
       llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
       llvm::cl::init(false)};
+
+  Option<bool> testEliminateWholeSlicingPatterns{
+      *this, "test-eliminate-whole-slicing-patterns",
+      llvm::cl::desc("Test patterns to eliminate whole-slicing extract_slice "
+                     "and insert_slice"),
+      llvm::cl::init(false)};
 };
 } // namespace
 
@@ -154,6 +160,12 @@ static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
   (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
 }
 
+static void applyEliminateWholeSlicingPatterns(Operation *rootOp) {
+  RewritePatternSet patterns(rootOp->getContext());
+  tensor::populateEliminateWholeSlicingPatterns(patterns);
+  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
+}
+
 namespace {
 /// Base pattern to rewrite  a `tensor.collapse_shape -> tensor.extract_slice`.
 /// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -406,6 +418,8 @@ void TestTensorTransforms::runOnOperation() {
             applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
       return si...
[truncated]


github-actions bot commented Sep 3, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@Menooker
Author

Menooker commented Sep 3, 2024

The long story of why we need this pattern:

In our downstream project, we have IR for tiling like this:

%a = extract_slice %out[...] : tensor<...>
%result = scf.for(...) iter_args(%a1 = %a) {
     %slice = extract_slice %a1[...]: tensor<...>
     %0 = elementwise_compute(%slice) outs(%a) : tensor<...>
     %inserted = insert_slice %0 into %a1[...] : tensor<...>
     scf.yield %inserted
}
%final = insert_slice %result into %out[...]

If the tile size for the loop happens to be the same as the size of %out, the scf.for has a trip count of 1, and it can be canonicalized to:

%a = extract_slice %out[...] : tensor<...>  // slice a sub-tensor from out
%slice = extract_slice %a[...]: tensor<...>  // whole-slicing over %a. %slice has the same size of %a
%0 = elementwise_compute(%slice) outs(%a) : tensor<...>
%inserted = insert_slice %0 into %a[...] : tensor<...> //  overwriting %a as a whole. %0 has the same size of %a
%final = insert_slice %inserted into %out[...]

We have populateMergeConsecutiveInsertExtractSlicePatterns to remove the redundant extract_slice and insert_slice. However, the shape of %a = extract_slice %out[...] : tensor<...> (as well as %slice) is dynamic, and populateMergeConsecutiveInsertExtractSlicePatterns does not seem to handle that case. It finally generates this IR:

%a = extract_slice %out[...] : tensor<...>  // slice a sub-tensor from out
%0 = elementwise_compute(%a) outs(%a) : tensor<...>
%inserted = insert_slice %0 into %a[...] : tensor<...> //  overwriting %a as a whole. %0 has the same size of %a
%final = insert_slice %inserted into %out[...]

This will cause in-place bufferization of %a to fail, because both elementwise_compute(%a) outs(%a) and insert_slice %0 into %a[...] write to %a.

Hence we introduce this pattern to remove the redundant insert_slice.

The pattern has some overlap with populateMergeConsecutiveInsertExtractSlicePatterns; however, it targets different patterns in the IR.

@joker-eph
Collaborator

Why isn't this just part of tensor.insert_slice folding?

@Menooker
Author

Menooker commented Sep 4, 2024

Why isn't this just part of tensor.insert_slice folding?

I didn't implement this as a fold/canonicalization because bufferization is sensitive to buffer positions (as noted in the comments of populateMergeConsecutiveInsertExtractSlice; we are doing something similar in this PR). There may be cases where this pattern alters the bufferization result in a way the user does not expect, so I provide it as an optional optimization pattern set.
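
To make the opt-in usage concrete, here is a minimal sketch of how a downstream pipeline could apply this pattern set in its own pass instead of relying on the canonicalizer. The helper name is hypothetical; the populate call and greedy driver mirror the test pass added in this PR:

#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Hypothetical downstream helper that opts in to whole-slice elimination.
static void applyWholeSliceCleanup(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::tensor::populateEliminateWholeSlicingPatterns(patterns);
  // Best-effort greedy application; the convergence result is ignored here.
  (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}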

@joker-eph
Collaborator

joker-eph commented Sep 9, 2024

@matthias-springer can you comment here?

To me this is a coupling problem that would indicate the bufferization needs a "prepare" phase independent of canonicalization, but I'm not sure I accept that we need to keep useless redundant information in the IR and can't canonicalize it away just because "maybe somehow that affects this one pass".

@matthias-springer
Member

When bufferizing a tensor operation, the buffer in which the result is materialized is determined by the tensor operands of the operations. So if some other transformation/canonicalization/folding changes the IR, it may bufferize in a different way. (Still in a functionally correct, but maybe less efficient way.) Ideally, we would have a more powerful bufferization analysis that can reason about tensor subsets/aliasing and generate the same efficient IR given two functionally equivalent input programs. (Also see this discussion here.)

I see 3 ways to address this issue (ordered from best to worst):

  1. Improve the bufferization analysis, so that the bufferization framework can generate efficient output IR for a wider variety of input IR. Maybe such an analysis is no longer based on "destination-passing style". (But the reason we have DPS is that such an analysis is difficult and DPS is a way to inject hints into it.) This has come up multiple times, and I haven't seen anyone willing to sign up for the task so far.
  2. Not using the canonicalizer pass in a compilation pipeline (between the part that generates/"prepares" tensor IR and the bufferization pass); only run the patterns/foldings that you actually need for the lowering (see the sketch after this list).
  3. Not adding canonicalization/folding patterns to the canonicalizer pass that could mess with the bufferization. That's what we've been doing so far.
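
As an illustration of option 2, a pipeline could drop the generic canonicalizer between tensor IR preparation and bufferization and run only the slice cleanups it relies on. This is a hedged sketch, not a recommendation from this thread; the particular selection of pattern sets is just an example:

// Narrowly scoped pre-bufferization cleanup: apply only the slice-related
// patterns the pipeline needs, rather than running -canonicalize.
static void runPreBufferizationSliceCleanup(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns);
  mlir::tensor::populateEliminateWholeSlicingPatterns(patterns);
  (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}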

@Menooker
Author

Hi, reviewers! Any more comments regarding the patch itself?
