Skip to content

Commit 060208b

Browse files
[mlir][NFC] Move SubTensorOp and SubTensorInsertOp to TensorDialect
The main goal of this commit is to remove the dependency of Standard dialect on the Tensor dialect. * Rename SubTensorOp -> tensor.extract_slice, SubTensorInsertOp -> tensor.insert_slice. * Some helper functions are (already) duplicated between the Tensor dialect and the MemRef dialect. To keep this commit smaller, this will be cleaned up in a separate commit. * Additional dialect dependencies: Shape --> Tensor, Tensor --> Standard * Remove dialect dependencies: Standard --> Tensor * Move canonicalization test cases to correct dialect (Tensor/MemRef). Note: This is a fixed version of https://reviews.llvm.org/D104499, which was reverted due to a missing update to two CMakeFile.txt. Differential Revision: https://reviews.llvm.org/D104676
1 parent c97cf73 commit 060208b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+1864
-1831
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
367367
let dependentDialects = [
368368
"memref::MemRefDialect",
369369
"StandardOpsDialect",
370-
"scf::SCFDialect"
370+
"scf::SCFDialect",
371+
"tensor::TensorDialect"
371372
];
372373
}
373374

@@ -504,7 +505,7 @@ def TosaToSCF : Pass<"tosa-to-scf"> {
504505

505506
def TosaToStandard : Pass<"tosa-to-standard"> {
506507
let summary = "Lower TOSA to the Standard dialect";
507-
let dependentDialects = ["StandardOpsDialect"];
508+
let dependentDialects = ["StandardOpsDialect", "tensor::TensorDialect"];
508509
let description = [{
509510
Pass that converts TOSA operations to the equivalent operations using the
510511
operations in the Standard dialect.

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -579,12 +579,11 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
579579

580580
Tensor-based version:
581581

582-
The body region of the loop contains `subtensor` operations applied to
582+
The body region of the loop contains `extract_slice` operations applied to
583583
every tensor argument of TiledLoopOp.
584584

585585
The body region must contain exactly one block that terminates with
586-
`linalg.yield` with the operands resulting from `subtensor_insert`
587-
operations.
586+
`linalg.yield` with the operands resulting from `insert_slice` operations.
588587

589588
Example:
590589

@@ -594,16 +593,16 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [
594593
outs(%out : tensor<24x64xi8>)
595594
iterators("parallel")
596595
distribution("block_x") {
597-
%lhs_sub = subtensor %lhs[%i, 0] [%c4, %c64] [1, 1]
596+
%lhs_sub = tensor.extract_slice %lhs[%i, 0] [%c4, %c64] [1, 1]
598597
: tensor<24x64xi8> to tensor<?x?xi8>
599-
%rhs_sub = subtensor %rhs[%i, 0] [%c4, %c64] [1, 1]
598+
%rhs_sub = tensor.extract_slice %rhs[%i, 0] [%c4, %c64] [1, 1]
600599
: tensor<24x64xi8> to tensor<?x?xi8>
601-
%out_sub = subtensor %out[%i, 0] [%c4, %c64] [1, 1]
600+
%out_sub = tensor.extract_slice %out[%i, 0] [%c4, %c64] [1, 1]
602601
: tensor<24x64xi8> to tensor<?x?xi8>
603602

604603
%result_sub = linalg.generic ...
605604

606-
%result = subtensor_insert %result_sub into %out[%i, 0][%c4, %c64][1, 1]
605+
%result = tensor.insert_slice %result_sub into %out[%i, 0][%c4, %c64][1, 1]
607606
: tensor<?x?xi8> into tensor<24x64xi8>
608607
linalg.yield %result : tensor<24x64xi8>
609608
}

mlir/include/mlir/Dialect/Linalg/Transforms/Hoisting.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
4747
/// If hoistPaddingOnTensors is called with `nLoops` = 2 on the following IR.
4848
/// ```
4949
/// scf.for (%i, %j, %k)
50-
/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
50+
/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
5151
/// %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
5252
/// ^bb0( ... ):
5353
/// linalg.yield %pad
@@ -61,16 +61,17 @@ void hoistRedundantVectorTransfersOnTensor(FuncOp func);
6161
/// scf.for (%i) {
6262
/// %packed_init = linalg.init_tensor range(%j) : tensor<?x4x8xf32>
6363
/// %packed = scf.for (%k) iter_args(%p : %packed_init) {
64-
/// %st0 = subtensor f(%i, %k) : ... to tensor<?x?xf32>
64+
/// %st0 = tensor.extract_slice f(%i, %k) : ... to tensor<?x?xf32>
6565
/// %0 = linalg.pad_tensor %st0 low[0, 0] high[...] {
6666
/// ^bb0( ... ):
6767
/// linalg.yield %pad
6868
/// } : tensor<?x?xf32> to tensor<4x8xf32>
69-
/// %1 = subtensor_insert %0 ... : tensor<4x8xf32> to tensor<?x4x8xf32>
69+
/// %1 = tensor.insert_slice %0 ...
70+
/// : tensor<4x8xf32> to tensor<?x4x8xf32>
7071
/// scf.yield %1: tensor<?x4x8xf32>
7172
/// } -> tensor<?x4x8xf32>
7273
/// scf.for (%j, %k) {
73-
/// %st0 = subtensor %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
74+
/// %st0 = tensor.extract_slice %packed [%k, 0, 0][1, 4, 8][1, 1, 1] :
7475
/// tensor<?x4x8xf32> to tensor<4x8xf32>
7576
/// compute(%st0)
7677
/// }

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Linalg/Utils/Utils.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1414
#include "mlir/Dialect/SCF/Utils.h"
15+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1516
#include "mlir/Dialect/Vector/VectorOps.h"
1617
#include "mlir/IR/Identifier.h"
1718
#include "mlir/IR/PatternMatch.h"
@@ -1077,12 +1078,12 @@ LogicalResult applyStagedPatterns(
10771078
const FrozenRewritePatternSet &stage2Patterns,
10781079
function_ref<LogicalResult(Operation *)> stage3Lambda = nullptr);
10791080

1080-
/// Rewrite subtensor(pad_tensor(x)) into pad_tensor(subtensor(x)).
1081-
struct SubTensorOfPadTensorSwapPattern
1082-
: public OpRewritePattern<SubTensorOp> {
1083-
using OpRewritePattern<SubTensorOp>::OpRewritePattern;
1081+
/// Rewrite extract_slice(pad_tensor(x)) into pad_tensor(extract_slice(x)).
1082+
struct ExtractSliceOfPadTensorSwapPattern
1083+
: public OpRewritePattern<tensor::ExtractSliceOp> {
1084+
using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;
10841085

1085-
LogicalResult matchAndRewrite(SubTensorOp subTensorOp,
1086+
LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
10861087
PatternRewriter &rewriter) const override;
10871088
};
10881089

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ bool isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
7878
bool isFusableInto(const LinalgDependenceGraph &graph, LinalgOp consumer,
7979
Value consumedView, LinalgOp producer);
8080

81-
/// Creates subtensor/subview ops for all `tiledOperands` of the given
81+
/// Creates extract_slice/subview ops for all `tiledOperands` of the given
8282
/// `linalgOp` with `builder`, assuming `linalgOp` is being fused into a loop
8383
/// nest for tiling with the given induction variables `ivs` and tile sizes
8484
/// `tileSizes`. `sizeBounds` are the iteration space bounds for *all* the
@@ -118,15 +118,17 @@ Optional<FusionInfo> fuseProducerOfBuffer(OpBuilder &b,
118118
const LinalgDependenceGraph &graph);
119119
/// Tensor counterpart of `fuseProducerOfBuffer`.
120120
/// This implements the fusion part of the "tileAndFuse on tensors"
121-
/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
122-
/// op (generally obtained by applying the tiling transformation).
121+
/// transformation and thus requires the `consumerOpOperand` to be a
122+
/// `extract_slice` op (generally obtained by applying the tiling
123+
/// transformation).
123124
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
124125
OpOperand &consumerOpOperand);
125126
/// Tensor counterpart of `fuseProducerOfBuffer`.
126127
/// This implements the fusion part of the "tileAndFuse on tensors"
127-
/// transformation and thus requires the `consumerOpOperand` to be a `subtensor`
128-
/// op (generally obtained by applying the tiling transformation).
129-
/// Assumes `producerOfTensor` is a Linalg op that produces `consumerOpOperand`.
128+
/// transformation and thus requires the `consumerOpOperand` to be a
129+
/// `extract_slice` op (generally obtained by applying the tiling
130+
/// transformation). Assumes `producerOfTensor` is a Linalg op that produces
131+
/// `consumerOpOperand`.
130132
Optional<FusionInfo> fuseProducerOfTensor(OpBuilder &b,
131133
OpResult producerOpResult,
132134
OpOperand &consumerOpOperand);

mlir/include/mlir/Dialect/Shape/IR/Shape.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#ifndef MLIR_SHAPE_IR_SHAPE_H
1515
#define MLIR_SHAPE_IR_SHAPE_H
1616

17+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1718
#include "mlir/IR/BuiltinOps.h"
1819
#include "mlir/IR/Dialect.h"
1920
#include "mlir/IR/OpDefinition.h"

mlir/include/mlir/Dialect/Shape/IR/ShapeBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def ShapeDialect : Dialect {
3535
}];
3636

3737
let cppNamespace = "::mlir::shape";
38+
let dependentDialects = ["tensor::TensorDialect"];
3839

3940
let hasConstantMaterializer = 1;
4041
let hasOperationAttrVerify = 1;

mlir/include/mlir/Dialect/StandardOps/IR/Ops.h

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2424
#include "mlir/Interfaces/SideEffectInterfaces.h"
2525
#include "mlir/Interfaces/VectorInterfaces.h"
26-
#include "mlir/Interfaces/ViewLikeInterface.h"
2726

2827
// Pull in all enum type definitions and utility function declarations.
2928
#include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc"
@@ -34,12 +33,6 @@ class Builder;
3433
class FuncOp;
3534
class OpBuilder;
3635
class PatternRewriter;
37-
38-
/// Return the list of Range (i.e. offset, size, stride). Each Range
39-
/// entry contains either the dynamic value or a ConstantIndexOp constructed
40-
/// with `b` at location `loc`.
41-
SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
42-
OpBuilder &b, Location loc);
4336
} // namespace mlir
4437

4538
#define GET_OP_CLASSES

0 commit comments

Comments
 (0)