Extend TilingInterface to allow more flexible tiling #95422

Open: wants to merge 2 commits into main

srinathava

Ref: discourse thread

Problem:

The current version of transform.structured.fuse relies on ops implementing the TilingInterface. The tiling methods of an op that implements this interface return a TilingResult, defined as:

/// Container for result values of tiling.
/// - `tiledOps` contains operations created by the tiling implementation that
/// are returned to the caller for further transformations.
/// - `tiledValues` contains the tiled value corresponding to the result of the
/// untiled operation.
struct TilingResult {
  SmallVector<Operation *> tiledOps;
  SmallVector<Value> tiledValues;
};

As currently implemented, the fusion algorithm considers only the last operation in tiledOps: it scans that op's operands for tensor.extract_slice producers and uses those as the candidates for further fusion.

Where it breaks down is when we implement a TilingInterface for the tosa.concat operation like so (MLIR pseudo-code):

%slice = scf.if (%offset < size(t1)) {
    scf.yield tensor.extract_slice %arg1 ...
} else {
    scf.yield tensor.extract_slice %arg2 ...
}

Even if both scf.yield ops are returned in the tiledOps field, only the last one is scanned, so only one of the two tensor.extract_slice ops is fused with its upstream producer.
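To make the failure mode concrete, here is a minimal standalone sketch. It does not use the MLIR API: `ToyOp`, `ConcatTile`, and `slicesFromOperands` are hypothetical names that only mimic the shape of the operand scan described above, under the assumption that each `scf.yield` takes its slice as a direct operand.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy stand-in for an IR operation. This is NOT mlir::Operation.
struct ToyOp {
  std::string name;
  std::vector<ToyOp *> operands; // ops that define this op's operands
  std::vector<ToyOp *> nested;   // ops inside regions owned by this op
};

// Mimics the operand-only scan: look at the direct operands of a single op
// and collect the ones defined by a tensor.extract_slice.
std::vector<ToyOp *> slicesFromOperands(const ToyOp &op) {
  std::vector<ToyOp *> slices;
  for (ToyOp *def : op.operands)
    if (def && def->name == "tensor.extract_slice")
      slices.push_back(def);
  return slices;
}

// Toy version of the tiled concat: an scf.if whose two branches each yield
// a different slice. The slices live inside the regions of the scf.if.
struct ConcatTile {
  ToyOp slice1{"tensor.extract_slice", {}, {}};
  ToyOp slice2{"tensor.extract_slice", {}, {}};
  ToyOp yield1{"scf.yield", {&slice1}, {}};
  ToyOp yield2{"scf.yield", {&slice2}, {}};
  ToyOp ifOp{"scf.if", {}, {&slice1, &yield1, &slice2, &yield2}};
};
```

In this toy model, scanning the operands of `ifOp` finds no slices at all, and scanning `yield2` (the last op if both yields are returned) finds only `slice2`, so `slice1` never becomes a fusion candidate.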

This PR extends TilingResult with an extractSliceOps field listing the tensor.extract_slice ops created while generating the tiled implementation, so the interface implementation can return them directly, including slices nested in regions owned by tiledOps. The list is plumbed through TilingResult -> SCFTilingResult -> SCFFuseProducerOfSliceResult, and tileConsumerAndFuseProducersUsingSCF now builds its worklist of candidate slices from it instead of scanning the operands of the last tiled op. The existing Linalg implementations (LinalgOpTilingInterface and SoftmaxOp) are updated to populate extractSliceOps.
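The worklist behaviour described above can be sketched in standalone C++. This is a toy model under stated assumptions, not the MLIR implementation: `ToyProducer`, `ToySlice`, and `fuseAll` are hypothetical names, and "fusing" a producer is reduced to recording its name and enqueueing the slices it reports.

```cpp
#include <cassert>
#include <deque>
#include <string>
#include <vector>

struct ToySlice;

// Toy stand-in for a producer op. Its extractSliceOps plays the role of the
// new field: the slices created when this producer is tiled and fused.
struct ToyProducer {
  std::string name;
  std::vector<ToySlice *> extractSliceOps;
};

// Toy stand-in for a tensor.extract_slice; source is the producer of the
// sliced tensor, or nullptr if there is nothing to fuse.
struct ToySlice {
  ToyProducer *source;
};

// Seeds the worklist from an explicit slice list and processes it
// breadth-first. Returns the producer names in the order they were "fused".
std::vector<std::string> fuseAll(const std::vector<ToySlice *> &seed) {
  std::deque<ToySlice *> candidates(seed.begin(), seed.end());
  std::vector<std::string> fused;
  while (!candidates.empty()) {
    ToySlice *slice = candidates.front();
    candidates.pop_front();
    if (!slice->source)
      continue; // slice of a function argument or similar: nothing to fuse
    fused.push_back(slice->source->name);
    // The slices reported by the fused producer become new candidates,
    // regardless of where they sit relative to the tiled op's operands.
    for (ToySlice *next : slice->source->extractSliceOps)
      candidates.push_back(next);
  }
  return fused;
}
```

Because the candidates come from an explicit list rather than from the operands of one op, every slice the tiling implementation reports is visited, which is what lets both branches of the concat example participate in fusion.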

llvmbot commented Jun 13, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: Srinath Avadhanula (srinathava)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/95422.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+2)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.h (+4)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+7-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+9-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+19-18)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index dac79111af3c9..fecd33193eb0d 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -85,6 +85,7 @@ struct SCFTilingResult {
   /// Values to use as replacements for the untiled op. Is the same size as the
   /// number of results of the untiled op.
   SmallVector<Value> replacements;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 /// Method to tile an op that implements the `TilingInterface` using
@@ -135,6 +136,7 @@ struct SCFFuseProducerOfSliceResult {
   OpResult origProducer;       // Original untiled producer.
   Value tiledAndFusedProducer; // Tile and fused producer value.
   SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> extractSliceOps;
 };
 std::optional<SCFFuseProducerOfSliceResult>
 tileAndFuseProducerOfSlice(RewriterBase &rewriter,
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.h b/mlir/include/mlir/Interfaces/TilingInterface.h
index ca570490ccf5b..e5ed016d53fc1 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.h
+++ b/mlir/include/mlir/Interfaces/TilingInterface.h
@@ -28,9 +28,13 @@ namespace mlir {
 /// are returned to the caller for further transformations.
 /// - `tiledValues` contains the tiled value corresponding to the result of the
 /// untiled operation.
+/// - `extractSliceOps` contains all the `tensor.extract_slice` ops used in
+/// generating the `tiledOps`. Usually these are operands to the `tiledOps`
+/// but they can be embedded in regions owned by `tiledOps`.
 struct TilingResult {
   SmallVector<Operation *> tiledOps;
   SmallVector<Value> tiledValues;
+  SmallVector<Operation *> extractSliceOps;
 };
 
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index b79afebfa8158..5198e0bceaa6e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2501,7 +2501,13 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
   Operation *tiledOp =
       mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
 
-  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+  SmallVector<Operation *> sliceOps;
+  for (Value operand : tiledOperands)
+    if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+      sliceOps.push_back(sliceOp);
+
+  return TilingResult{
+      {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
 }
 
 LogicalResult SoftmaxOp::getResultTilePosition(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index c3ab3cecfada7..f25ccc38ba0a3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -129,7 +129,13 @@ struct LinalgOpTilingInterface
     Operation *tiledOp = clone(b, linalgOp, resultTensorTypes, tiledOperands);
     offsetIndices(b, cast<LinalgOp>(tiledOp), offsets);
 
-    return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
+    SmallVector<Operation *> sliceOps;
+    for (Value operand : tiledOperands)
+      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
+        sliceOps.push_back(sliceOp);
+
+    return TilingResult{
+        {tiledOp}, SmallVector<Value>(tiledOp->getResults()), sliceOps};
   }
 
   /// Utility to fetch the offsets and sizes when applied as per the indexing
@@ -247,7 +253,8 @@ struct LinalgOpTilingInterface
 
     return TilingResult{
         tilingResult->tiledOps,
-        SmallVector<Value>{tilingResult->tiledValues[resultNumber]}};
+        SmallVector<Value>{tilingResult->tiledValues[resultNumber]},
+        tilingResult->extractSliceOps};
   }
 
   /// Method to generate the tiled implementation of an operation from the tile
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..fb3ec2a5fa0a8 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -619,7 +619,8 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
     if (llvm::all_of(tileSizes, isZeroIndex)) {
       tiledResults.append(clonedOp->result_begin(), clonedOp->result_end());
       tilingResult =
-          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults()};
+          TilingResult{/*tiledOps=*/{clonedOp}, clonedOp->getResults(),
+                       /*extractSliceOps=*/{}};
       return success();
     }
 
@@ -675,12 +676,14 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op,
   // op.
   if (loops.empty()) {
     return scf::SCFTilingResult{tilingResult->tiledOps, loops,
-                                tilingResult->tiledValues};
+                                tilingResult->tiledValues,
+                                tilingResult->extractSliceOps};
   }
 
   SmallVector<Value> replacements = llvm::map_to_vector(
       loops.front()->getResults(), [](OpResult r) -> Value { return r; });
-  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements};
+  return scf::SCFTilingResult{tilingResult->tiledOps, loops, replacements,
+                              tilingResult->extractSliceOps};
 }
 
 FailureOr<scf::SCFReductionTilingResult>
@@ -931,9 +934,9 @@ mlir::scf::tileAndFuseProducerOfSlice(
         ->getOpOperands()[destinationInitArg.value()->getOperandNumber()]
         .set(origDestinationTensors[resultNumber]);
   }
-  return scf::SCFFuseProducerOfSliceResult{fusableProducer,
-                                           tileAndFuseResult->tiledValues[0],
-                                           tileAndFuseResult->tiledOps};
+  return scf::SCFFuseProducerOfSliceResult{
+      fusableProducer, tileAndFuseResult->tiledValues[0],
+      tileAndFuseResult->tiledOps, tileAndFuseResult->extractSliceOps};
 }
 
 /// Reconstruct the fused producer from within the tiled-and-fused code.
@@ -962,13 +965,12 @@ LogicalResult mlir::scf::yieldReplacementForFusedProducer(
                   .getDefiningOp<DestinationStyleOpInterface>()) {
         rewriter.setInsertionPoint(tiledDestStyleOp);
         Value newRegionArg = newRegionIterArgs.back();
-        auto destSlice = rewriter.create<tensor::ExtractSliceOp>(
-            sliceOp.getLoc(), newRegionArg, sliceOp.getMixedOffsets(),
-            sliceOp.getMixedSizes(), sliceOp.getMixedStrides());
         unsigned resultNumber = fusableProducer.getResultNumber();
-        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
-          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
-        });
+        auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
+                             .getDefiningOp<tensor::ExtractSliceOp>();
+        if (origSlice) {
+          origSlice.getSourceMutable().set(newRegionArg);
+        }
       }
       Block *block = rewriter.getInsertionPoint()->getBlock();
       rewriter.setInsertionPoint(block->getTerminator());
@@ -1036,15 +1038,14 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
   //    operations. If the producers of the source of the `tensor.extract_slice`
   //    can be tiled such that the tiled value is generated in-place, that
   //    effectively tiles + fuses the operations.
-  auto addCandidateSlices = [](Operation *fusedOp,
+  auto addCandidateSlices = [](const SmallVector<Operation *> &newSliceOps,
                                std::deque<tensor::ExtractSliceOp> &candidates) {
-    for (Value operand : fusedOp->getOperands())
-      if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>())
-        candidates.push_back(sliceOp);
+    for (auto *op : newSliceOps)
+      candidates.push_back(llvm::cast<tensor::ExtractSliceOp>(op));
   };
 
   std::deque<tensor::ExtractSliceOp> candidates;
-  addCandidateSlices(tiledAndFusedOps.back(), candidates);
+  addCandidateSlices(tilingResult->extractSliceOps, candidates);
   OpBuilder::InsertionGuard g(rewriter);
   while (!candidates.empty()) {
     // Traverse the slices in BFS fashion.
@@ -1086,7 +1087,7 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
             fusedResult->tiledAndFusedProducer.getDefiningOp()) {
       fusedProducers.insert(fusedResult->origProducer.getDefiningOp());
       tiledAndFusedOps.insert(tiledAndFusedOp);
-      addCandidateSlices(tiledAndFusedOp, candidates);
+      addCandidateSlices(fusedResult->extractSliceOps, candidates);
     }
   }
 

@srinathava marked this pull request as draft June 13, 2024 15:36
-        rewriter.modifyOpInPlace(tiledDestStyleOp, [&]() {
-          tiledDestStyleOp.getDpsInitsMutable()[resultNumber].set(destSlice);
-        });
+        auto origSlice = tiledDestStyleOp.getDpsInits()[resultNumber]
Author

We were previously creating a new tensor::ExtractSliceOp to set as the destination argument of the tiled op. This worked because tileConsumerAndFuseProducersUsingSCF built its worklist of candidate tensor::ExtractSliceOps by looking at the operands of the tiled op. However, we now use the slice ops returned by the TilingInterfaceOp directly, and they are collected before we call into yieldReplacementForFusedProducer.

We therefore modify the slice op argument of the fused producer in place instead of creating a new tensor::ExtractSliceOp.
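A small standalone sketch of why the in-place update matters once slices are collected up front. This is a toy model, not MLIR code: `ToySlice`, `ToyTiledOp`, `replaceWithNewSlice`, and `retargetInPlace` are hypothetical names, and the string `source` stands in for the value being sliced.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Toy stand-ins: a slice reads from a named source, and a tiled op holds a
// pointer to the slice used as its destination (init) operand.
struct ToySlice {
  std::string source;
};
struct ToyTiledOp {
  ToySlice *init;
};

// Old approach (toy version): create a fresh slice on the new source and
// rewire the tiled op to it. Pointers to the previous slice held elsewhere
// still refer to the old slice and its old source.
ToySlice *replaceWithNewSlice(ToyTiledOp &op, const std::string &newSource,
                              std::vector<std::unique_ptr<ToySlice>> &storage) {
  storage.push_back(std::make_unique<ToySlice>(ToySlice{newSource}));
  op.init = storage.back().get();
  return op.init;
}

// New approach (toy version): keep the same slice object and change only its
// source, so any pointer collected earlier observes the update.
void retargetInPlace(ToyTiledOp &op, const std::string &newSource) {
  if (op.init)
    op.init->source = newSource;
}
```

In the toy model, a pointer saved before `replaceWithNewSlice` ends up describing a slice the tiled op no longer uses, while a pointer saved before `retargetInPlace` stays in sync with the op.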

Contributor

I need to try this out more... I cloned the slice because that made it easier. I think without this I kept running into invalid IR creation, but maybe it was an artifact of something else not working. I will have to try this out in IREE to really stress test it.

Contributor

This PR does need this change, so that the slice ops returned by the TilingInterfaceOp do not become stale after yieldReplacementForFusedProducer. Otherwise, it would need another way to update the slice ops cached in SCFFuseProducerOfSliceResult to the latest ones (created in yieldReplacementForFusedProducer).

BTW, this change will also affect another PR involving yieldReplacementForFusedProducer, which will hopefully be merged before this PR.

Author

Thanks for the heads up @Yun-Fly

@ftynse self-requested a review June 14, 2024 13:40
@srinathava marked this pull request as ready for review June 17, 2024 12:26
@srinathava requested a review from hanhanW as a code owner June 17, 2024 12:26