[mlir] [tensor] Add patterns to remove whole slicing of tensors #107046

Status: Open · wants to merge 3 commits into base: main
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h
@@ -102,6 +102,11 @@ using ControlFoldFn = std::function<bool(OpOperand *)>;
void populateRewriteAsConstantPatterns(RewritePatternSet &patterns,
                                       const ControlFoldFn &controlFn);

/// Appends patterns that eliminate whole-slice `tensor.extract_slice` and
/// `tensor.insert_slice` ops, i.e. slices whose offsets are all zeros, whose
/// strides are all ones, and whose sizes match the shape of the sliced
/// tensor.
void populateEliminateWholeSlicingPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
// Transform helpers
//===----------------------------------------------------------------------===//
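For illustration (not part of the patch), a minimal sketch of the IR these patterns fold, using assumed static shapes and value names:

// Both ops below are whole-slice, no-op views: offsets all zero, strides all
// one, sizes equal to the full shape. The patterns forward the underlying
// value, replacing uses of %e with %src and uses of %i with %val.
%e = tensor.extract_slice %src[0, 0] [16, 32] [1, 1]
    : tensor<16x32xf32> to tensor<16x32xf32>
%i = tensor.insert_slice %val into %dst[0, 0] [16, 32] [1, 1]
    : tensor<16x32xf32> into tensor<16x32xf32>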
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Tensor/Transforms/CMakeLists.txt
@@ -11,6 +11,7 @@ add_mlir_dialect_library(MLIRTensorTransforms
  RewriteAsConstant.cpp
  SwapExtractSliceWithProducerPatterns.cpp
  SubsetInsertionOpInterfaceImpl.cpp
  EliminateWholeSlicePatterns.cpp

  ADDITIONAL_HEADER_DIRS
  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tensor/Transforms
98 changes: 98 additions & 0 deletions mlir/lib/Dialect/Tensor/Transforms/EliminateWholeSlicePatterns.cpp
@@ -0,0 +1,98 @@
//===- EliminateWholeSlicePatterns.cpp - Patterns to remove whole slices --===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns that eliminate tensor.extract_slice and
// tensor.insert_slice ops that cover the whole of their operand tensor.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;
using namespace mlir::tensor;

namespace {

/// Returns true if `sliceOp` is a whole-slice view, i.e. all offsets are
/// zeros, all strides are ones, and `smallerTensor` has the same shape as
/// `largerTensor`.
bool checkEliminateOK(PatternRewriter &rewriter,
                      OffsetSizeAndStrideOpInterface sliceOp,
                      TypedValue<RankedTensorType> smallerTensor,
                      TypedValue<RankedTensorType> largerTensor) {
  auto srcType = largerTensor.getType();
  auto resultType = smallerTensor.getType();
  if (!isSameTypeWithoutEncoding(srcType, resultType)) {
    // Fast failure path when the input and output types do not match.
    return false;
  }
  // Both types are guaranteed to have the same rank at this point.
  for (int64_t i = 0; i < resultType.getRank(); ++i) {
    // The offsets of the slice must be all-zero...
    if (sliceOp.isDynamicOffset(i) || sliceOp.getStaticOffset(i) != 0)
      return false;
    // ...and its strides must be all-one.
    if (sliceOp.isDynamicStride(i) || sliceOp.getStaticStride(i) != 1)
      return false;
  }
  // Check that the dynamic sizes match. getMixedSize may fold the tensor.dim
  // it builds into an existing SSA value, so comparing the underlying Values
  // is meaningful.
  if (resultType.getNumDynamicDims() != 0) {
    for (int64_t i = 0; i < resultType.getRank(); ++i) {
      if (resultType.isDynamicDim(i)) {
        OpFoldResult largeDim =
            getMixedSize(rewriter, sliceOp.getLoc(), largerTensor, i);
        if (llvm::dyn_cast_if_present<Value>(largeDim) !=
            sliceOp.getDynamicSize(i))
          return false;
      }
    }
  }
  // For static shapes, isSameTypeWithoutEncoding has already verified that
  // the shapes match.
  return true;
}

struct EliminateWholeSliceExtractSliceOp
    : public OpRewritePattern<ExtractSliceOp> {
  using OpRewritePattern<ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getResult(),
                          sliceOp.getSource()))
      return failure();
    // All checks passed; forward the source in place of the slice.
    rewriter.replaceOp(sliceOp, sliceOp.getSource());
    return success();
  }
};

struct EliminateWholeSliceInsertSliceOp
    : public OpRewritePattern<InsertSliceOp> {
  using OpRewritePattern<InsertSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    if (!checkEliminateOK(rewriter, sliceOp, sliceOp.getSource(),
                          sliceOp.getDest()))
      return failure();
    // All checks passed; the insert overwrites the whole destination, so the
    // result is just the source value.
    rewriter.replaceOp(sliceOp, sliceOp.getSource());
    return success();
  }
};

} // namespace

void mlir::tensor::populateEliminateWholeSlicingPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<EliminateWholeSliceExtractSliceOp, EliminateWholeSliceInsertSliceOp>(
          patterns.getContext());
}
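As a usage sketch (mirroring the applyEliminateWholeSlicingPatterns helper in the test driver further below; the wrapper name is illustrative, not part of this patch), a downstream pass can feed these patterns to the greedy rewrite driver:

// Hypothetical helper: collect the new patterns and run the greedy driver.
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

static void eliminateWholeSlices(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  mlir::tensor::populateEliminateWholeSlicingPatterns(patterns);
  // Best-effort application; a convergence failure is deliberately ignored.
  (void)mlir::applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}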
194 changes: 194 additions & 0 deletions mlir/test/Dialect/Tensor/eliminate-whole-slicing.mlir
@@ -0,0 +1,194 @@
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-eliminate-whole-slicing-patterns -canonicalize -mlir-print-local-scope %s | FileCheck %s

//////////////////////////////
// Tests for insert_slice
//////////////////////////////

func.func @elim_dyn_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
%c0f = arith.constant 0.0 : bf16
%3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}

// CHECK-LABEL: func.func @elim_dyn_insert
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
// CHECK: return %[[INSERT]]

func.func @elim_static_insert(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%c0f = arith.constant 0.0 : bf16
%3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}

// CHECK-LABEL: func.func @elim_static_insert
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: %[[INSERT:.+]] = tensor.insert_slice %[[FILL]] into %[[SOURCE]]
// CHECK: return %[[INSERT]]

func.func @fail_dyn_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
%c0f = arith.constant 0.0 : bf16
%3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<?x32x32x32xbf16>) -> tensor<?x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<?x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}
// Fails to optimize: the dynamic insert size does not match.
// CHECK-LABEL: func.func @fail_dyn_insert_shape
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]]
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: tensor.insert_slice
// CHECK: %[[INSERT:.+]] = tensor.insert_slice
// CHECK: return %[[INSERT]]

func.func @fail_static_insert_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%3 = tensor.empty() : tensor<14x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<14x32x32x32xbf16> into tensor<15x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}
// Fails to optimize: the static insert shape does not match.
// CHECK-LABEL: func.func @fail_static_insert_shape
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: tensor.empty()
// CHECK: tensor.insert_slice
// CHECK: %[[INSERT:.+]] = tensor.insert_slice
// CHECK: return %[[INSERT]]

func.func @fail_dyn_insert_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%c0f = arith.constant 0.0 : bf16
%3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, %arg2] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}
// Fails to optimize due to the dynamic stride.
// CHECK-LABEL: func.func @fail_dyn_insert_stride
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: tensor.insert_slice
// CHECK: %[[INSERT:.+]] = tensor.insert_slice
// CHECK: return %[[INSERT]]

// Fails to optimize due to a non-zero offset.
func.func @fail_static_insert_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<32x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%c0f = arith.constant 0.0 : bf16
%3 = linalg.fill ins(%c0f : bf16) outs(%extracted_slice : tensor<15x32x32x32xbf16>) -> tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %3 into %extracted_slice[0, 0, 0, 1] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<15x32x32x32xbf16>
%inserted_slice_3 = tensor.insert_slice %inserted_slice into %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<32x32x32x32xbf16>
return %inserted_slice_3 : tensor<32x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_static_insert_offset
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK: tensor.insert_slice
// CHECK: %[[INSERT:.+]] = tensor.insert_slice
// CHECK: return %[[INSERT]]

//////////////////////////////
// Tests for extract_slice
//////////////////////////////
func.func @elim_dyn_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
return %extracted_slice2 : tensor<?x32x32x32xbf16>
}
// CHECK-LABEL: func.func @elim_dyn_extract
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [%[[OFFSET1]], 32, 32, 32]
// CHECK: return %[[EXTRACT]]


func.func @elim_static_extract(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
return %extracted_slice2 : tensor<15x32x32x32xbf16>
}
// CHECK-LABEL: func.func @elim_static_extract
// CHECK-SAME: (%[[SOURCE:.+]]: tensor<32x32x32x32xbf16>, %[[OFFSET0:.+]]: index, %[[OFFSET1:.+]]: index
// CHECK: %[[EXTRACT:.+]] = tensor.extract_slice %[[SOURCE]][%[[OFFSET0]], 0, 0, 0] [15, 32, 32, 32]
// CHECK: return %[[EXTRACT]]

// Fails to optimize: the dynamic extract size does not match.
func.func @fail_dyn_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg2, 32, 32, 32] [1, 1, 1, 1] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
return %extracted_slice2 : tensor<?x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_dyn_extract_shape
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: return

// Fails to optimize: the static extract shape does not match.
func.func @fail_static_extract_shape(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<14x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [14, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<14x32x32x32xbf16>
return %extracted_slice2 : tensor<14x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_static_extract_shape
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: return

// Fails to optimize due to the non-unit stride.
func.func @fail_extract_stride(%arg0: tensor<32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<?x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<?x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0] [%arg3, 32, 32, 32] [1, 1, 1, 3] : tensor<?x32x32x32xbf16> to tensor<?x32x32x32xbf16>
return %extracted_slice2 : tensor<?x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_extract_stride
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: return

// Fails to optimize due to a non-zero offset.
func.func @fail_static_extract_offset(%arg0: tensor<32x32x32x32xbf16>, %arg2: index) -> tensor<15x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<32x32x32x32xbf16> to tensor<15x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[%arg2, 0, 0, 0] [15, 32, 32, 32] [1, 1, 1, 1] : tensor<15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
return %extracted_slice2 : tensor<15x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_static_extract_offset
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: return



//////////////////////////////
// Tests for rank-expanding/reducing slices
//////////////////////////////
func.func @fail_extract_reduce(%arg0: tensor<1x32x32x32x32xbf16>, %arg2: index, %arg3: index) -> tensor<15x32x32x32xbf16> {
%extracted_slice = tensor.extract_slice %arg0[0, %arg2, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x32x32x32x32xbf16> to tensor<1x15x32x32x32xbf16>
%extracted_slice2 = tensor.extract_slice %extracted_slice[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<1x15x32x32x32xbf16> to tensor<15x32x32x32xbf16>
return %extracted_slice2 : tensor<15x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_extract_reduce
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: return

func.func @fail_insert_expand(%arg0: tensor<1x15x32x32x32xbf16>, %arg1: tensor<1x15x32x32x32xbf16>, %arg2: index) -> tensor<1x15x32x32x32xbf16> {
%empty = tensor.empty() : tensor<15x32x32x32xbf16>
%inserted_slice = tensor.insert_slice %empty into %arg0[0, 0, 0, 0, 0] [1, 15, 32, 32, 32] [1, 1, 1, 1, 1] : tensor<15x32x32x32xbf16> into tensor<1x15x32x32x32xbf16>
return %inserted_slice : tensor<1x15x32x32x32xbf16>
}
// CHECK-LABEL: func.func @fail_insert_expand
// CHECK: tensor.empty
// CHECK: tensor.insert_slice
// CHECK: return
14 changes: 14 additions & 0 deletions mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp
@@ -98,6 +98,12 @@ struct TestTensorTransforms
      *this, "test-tracking-listener",
      llvm::cl::desc("Test tensor TrackingListener for the transform dialect"),
      llvm::cl::init(false)};

  Option<bool> testEliminateWholeSlicingPatterns{
      *this, "test-eliminate-whole-slicing-patterns",
      llvm::cl::desc("Test patterns to eliminate whole-slicing extract_slice "
                     "and insert_slice"),
      llvm::cl::init(false)};
};
} // namespace

@@ -154,6 +160,12 @@ static void applySimplifyPackUnpackPatterns(Operation *rootOp) {
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

static void applyEliminateWholeSlicingPatterns(Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  tensor::populateEliminateWholeSlicingPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(rootOp, std::move(patterns));
}

namespace {
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -406,6 +418,8 @@ void TestTensorTransforms::runOnOperation() {
            applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
      return signalPassFailure();
  }
  if (testEliminateWholeSlicingPatterns)
    applyEliminateWholeSlicingPatterns(rootOp);
  if (testTrackingListener)
    if (failed(testTrackingListenerReplacements(rootOp)))
      return signalPassFailure();