[MLIR][SCF] Add an API to fuse consumer to a producer within scf loop #88712

Merged 13 commits on Jun 1, 2024

Abhishek-Varma

-- This commit adds an API to fuse a consumer into its producer within an scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma [email protected]

@llvmbot

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-scf

Author: Abhishek Varma (Abhishek-Varma)

Changes

-- This commit adds an API to fuse a consumer into its producer within an scf.for/scf.forall loop.

Signed-off-by: Abhishek Varma <[email protected]>


Patch is 41.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88712.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+13)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+55)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+75-21)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+511)
  • (added) mlir/test/Interfaces/TilingInterface/fuse-consumer.mlir (+110)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+53)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td (+21)
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be28..75b48a2cdd8dc3 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -126,6 +126,19 @@ struct SCFTileAndFuseOptions {
   }
 };
 
+/// Fuse the consumer of the source of `candidateSliceOp` by computing the
+/// required slice of the consumer in-place.  Note that the method
+/// replaces the uses of `candidateSliceOp` with the tiled and fused consumer
+/// value but does not delete the slice operation.
+struct SCFFuseConsumerOfSliceResult {
+  Operation *origConsumer;     // Original untiled consumer.
+  Value tiledAndFusedConsumer; // Tiled and fused consumer value.
+  SmallVector<Operation *> tiledOps;
+};
+FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSlice(RewriterBase &rewriter, Operation *candidateSliceOp,
+                           bool useSCFFor);
+
 /// Fuse the producer of the source of `candidateSliceOp` by computing the
 /// required slice of the producer in-place.  Note that the method
 /// replaces the uses of `candidateSliceOp` with the tiled and fused producer
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return iterator domain position computed by the
+          input operand position.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand position.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..01bf19764b0938 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
+  void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
+                              AffineMap indexingMap,
+                              ArrayRef<OpFoldResult> offsets,
+                              ArrayRef<OpFoldResult> sizes,
+                              SmallVector<OpFoldResult> &mappedOffsets,
+                              SmallVector<OpFoldResult> &mappedSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+    auto numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &range : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[range.index()] = range.value().offset;
+        mappedSizes[range.index()] = range.value().size;
+      }
+    }
+    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition =
+          cast<AffineDimExpr>(resultExpr.value()).getPosition();
+      mappedOffsets[dimPosition] = offsets[resultExpr.index()];
+      mappedSizes[dimPosition] = sizes[resultExpr.index()];
+    }
+  }
+
+  // Return the iteration domain tile (offsets and sizes) that corresponds
+  // to the given tile of the operand.
+  LogicalResult getIterDomainTilePositionFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVector<OpFoldResult> &iterDomainOffsets,
+      SmallVector<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return op->emitOpError(
+          "unhandled get iter domain position when operand is not "
+          "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
   // Return the details of the output tile generated by the tiled
   // implementation.
   LogicalResult
@@ -160,6 +213,20 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandPosition(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterDomainTilePositionFromOperandPosition(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return op->emitOpError(
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(op, b, indexingMap, offsets, sizes, mappedOffsets,
+                           mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69df..45c8f8362ad581 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/IR/Dominance.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/DestinationStyleOpInterface.h"
@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 // tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+                                         unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = dyn_cast<BlockArgument>(source->get());
+  Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,463 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << *containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result"
+            << "\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface"
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.forall as init"
+                 << "\n";
+    return failure();
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = parallelInsertSliceOp;
+      } else {
+        llvm::outs() << "containingOp's result inserted multi time"
+                     << "\n";
+        return failure();
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    llvm::outs() << "containingOp's result was not inserted"
+                 << "\n";
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check all insert strides are 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result yield with stride"
+                 << "\n";
+    return failure();
+  }
+
+  Location loc = forallOp.getLoc();
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    llvm::outs() << "can't get iter domain position from input position"
+                 << "\n";
+    return failure();
+  }
+
+  // Try to get all containing op result's position from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      llvm::outs()
+          << "can't get result domain position from iter domain position"
+          << "\n";
+      return failure();
+    }
+  }
+
+  // All checks passed, try to fuse consumer.
+  // Create tiled implementation of containing op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    llvm::outs() << "get tiled implementation failed"
+                 << "\n";
+    return failure();
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+    llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+    return failure();
+  }
+
+  // Replace tiled op's operand.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) ...
[truncated]
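
For readers following the `getMappedOffsetAndSize` helper in the patch above, here is a minimal standalone sketch of the idea it implements: mapping a tile of an operand to a tile of the iteration domain when the operand is accessed through a projected permutation. It deliberately avoids MLIR types. `Range`, `Tile`, `mapOperandTileToIterDomain`, and the `operandDimToLoop` vector are hypothetical stand-ins (for the iteration domain, the offset/size pair, the helper, and the results of the indexing map), and plain integers replace `OpFoldResult`. Unlike the patch, which fills in the full iteration domain only when the map is not a full permutation, the sketch always starts from the full domain; the outcome should be the same because a full permutation overwrites every entry.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one dimension of an iteration domain.
struct Range {
  int64_t offset;
  int64_t size;
};

// Hypothetical stand-in for the offsets and sizes describing a tile.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Maps a tile of an operand to a tile of the iteration domain.
// `operandDimToLoop[i]` is the loop that indexes operand dimension `i`,
// i.e. the results of a projected-permutation indexing map.
Tile mapOperandTileToIterDomain(const std::vector<Range> &iterationDomain,
                                const std::vector<unsigned> &operandDimToLoop,
                                const std::vector<int64_t> &operandOffsets,
                                const std::vector<int64_t> &operandSizes) {
  assert(operandOffsets.size() == operandDimToLoop.size());
  assert(operandSizes.size() == operandDimToLoop.size());

  // Start from the full extent of every loop, so loops that the operand
  // does not index keep their whole range.
  Tile tile;
  for (const Range &range : iterationDomain) {
    tile.offsets.push_back(range.offset);
    tile.sizes.push_back(range.size);
  }

  // Overwrite the loops that the operand does index with the operand tile.
  for (std::size_t i = 0; i < operandDimToLoop.size(); ++i) {
    unsigned loop = operandDimToLoop[i];
    assert(loop < iterationDomain.size());
    tile.offsets[loop] = operandOffsets[i];
    tile.sizes[loop] = operandSizes[i];
  }
  return tile;
}
```

As a usage example under the same assumptions: for a matmul-like op with loops (m, n, k) of extents 128, 64, 32 and an operand indexed by (m, k), calling the function with `operandDimToLoop = {0, 2}`, offsets `{16, 0}` and sizes `{8, 32}` gives iteration-domain offsets `{16, 0, 0}` and sizes `{8, 64, 32}`. The `n` loop keeps its full extent because the operand does not constrain it.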

@llvmbot

llvmbot commented Apr 15, 2024

@llvm/pr-subscribers-mlir

@@ -798,6 +799,59 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
 // tileConsumerAndFuseProducersUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFFor(OpOperand &source,
+                                      unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.for.
+  unsigned yieldOperandNumber = source.getOperandNumber();
+  Value resultingValue =
+      source.getOwner()->getParentOp()->getResult(yieldOperandNumber);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
+static std::tuple<Operation *, std::optional<OpOperand *>>
+getUntiledConsumerFromSliceDestSCFForall(OpOperand *source,
+                                         unsigned &operandNumber) {
+  // Step 1. Fetch the corresponding output
+  // TODO(avarma): Make it generic for multiple values yielding scf.forall.
+  auto iterArg = dyn_cast<BlockArgument>(source->get());
+  Value resultingValue = iterArg.getOwner()->getParentOp()->getResult(0);
+
+  // Step 2. Get users.
+  std::optional<OpOperand *> destinationIterArg;
+  Operation *untiledConsumer = nullptr;
+  for (Operation *user : resultingValue.getUsers()) {
+    // TODO(avarma): Address the case where the consumer op itself can return
+    //               more than one result.
+    for (Value operand : user->getOperands()) {
+      if (operand == resultingValue) {
+        untiledConsumer = user;
+        break;
+      }
+      operandNumber++;
+    }
+    break;
+  }
+  return {untiledConsumer, destinationIterArg};
+}
+
 /// Return the untiled producer whose slice is used in a tiled consumer. The
 /// method traverses the tile loop nest (`loops`) if needed, and returns the
 /// `iter_args` of the outer most that is encountered. Traversing the iter_args
@@ -820,6 +874,463 @@ getUntiledProducerFromSliceSource(OpOperand *source,
   return {dyn_cast<OpResult>(source->get()), destinationIterArg};
 }
 
+/// Implementation of fusing consumer of a single slice by computing the
+/// slice of the consumer in-place for scf.forall.
+static FailureOr<scf::SCFFuseConsumerOfSliceResult>
+tileAndFuseConsumerOfSliceSCFForall(
+    RewriterBase &rewriter, tensor::ParallelInsertSliceOp candidateSliceOp) {
+  // 1. Get the consumer of the source.
+  unsigned operandNumber = 0;
+  auto [consumerOp, destinationInitArg] =
+      getUntiledConsumerFromSliceDestSCFForall(
+          &candidateSliceOp.getDestMutable(), operandNumber);
+  if (!consumerOp)
+    return failure();
+  OpBuilder::InsertionGuard g(rewriter);
+  // Using candidateSliceOp->getParentOp() because we have the following case :-
+  // scf.forall.in_parallel {
+  //   tensor.parallel_insert_slice ...
+  // }
+  rewriter.setInsertionPoint(candidateSliceOp->getParentOp());
+
+  Operation *containingOp = candidateSliceOp->getParentOp()->getParentOp();
+  // Check consumer has tiling interface.
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    llvm::outs() << "consumer is not a TileableInterface: " << *consumerOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    llvm::outs() << "containing op is not a scf.forall: " << *containingOp
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        llvm::outs()
+            << "consumer's operand use more than one containingOp's result"
+            << "\n";
+        return failure();
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    llvm::outs() << "consumer op should have destination style op interface"
+                 << "\n";
+    return failure();
+  }
+
+  // Check consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    llvm::outs() << "consumer op take result of scf.forall as init"
+                 << "\n";
+    return failure();
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = parallelInsertSliceOp;
+      } else {
+        llvm::outs() << "containingOp's result inserted multi time"
+                     << "\n";
+        return failure();
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    llvm::outs() << "containingOp's result was not inserted"
+                 << "\n";
+    return failure();
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    llvm::outs() << "containingOp's result yield with stride"
+                 << "\n";
+    return failure();
+  }
+
+  Location loc = forallOp.getLoc();
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    llvm::outs() << "can't get iter domain position from input position"
+                 << "\n";
+    return failure();
+  }
+
+  // Try to get all containing op result's position from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      llvm::outs()
+          << "can't get result domain position from iter domain position"
+          << "\n";
+      return failure();
+    }
+  }
+
+  // All check passed, try to fuse consumer.
+  // Create tiled implementation of containing op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    llvm::outs() << "get tiled implementation failed"
+                 << "\n";
+    return failure();
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+    llvm::outs() << "failed to tile consumer op: " << *tileableConsumer << "\n";
+    return failure();
+  }
+
+  // Replace tiled op's operand.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(consumerOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) ...
[truncated]

github-actions bot commented Apr 15, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Great start! I reviewed this to some extent. I left one main comment about how to restructure this to make this close to tileConsumerAndFuseProducer that makes it easy to handle some of the destination args correctly. Let me know if you want to chat about this more in a setting with higher-bandwidth.

@Abhishek-Varma
Contributor Author

Thank you so much @MaheshRavishankar for such detailed implementation pointers! It was super helpful!

I have addressed other comments as well.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Thanks @Abhishek-Varma ! This is looking much closer. I have another round of comments, but I think this is looking pretty close to being landable....

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Thanks a lot @Abhishek-Varma . I left a comment, if it is unclear we can do some quick pair programming to get this working, or I can take this PR and get a better idea of the state of the code at each point you are seeing here.

@Abhishek-Varma
Contributor Author

Abhishek-Varma commented Apr 19, 2024

Thank you @MaheshRavishankar @ftynse for your review comments! The one involving OpOperand & was extremely helpful and made the code super clean! Thank you.

I've addressed all your comments except the ones related to TilingInterface.td and TilingInterfaceImpl.cpp - since this PR is based on the in-flight PR 88528 that will add these.
Perhaps it's best handled there?

The major changes in the latest push in this PR are :-

  1. Add ability to deal with consumer yielding multiple values - previously I was constraining it to be a single yielding consumer. (You may check the new lit tests I've added).
  2. I have updated all the assumptions with necessary checks and returning failure() in case it fails.
  3. The newer algo mentioned here.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Ok, I think I have seen this enough and I am comfortable with this landing. Over time it might be better to collapse the scf.for and scf.forall implementations into one using the loop interface, maybe adding appropriate interface methods.

Thanks @Abhishek-Varma

@MaheshRavishankar
Contributor

Thank you @MaheshRavishankar @ftynse for your review comments! The one involving OpOperand & was extremely helpful and made the code super clean! Thank you.

I've addressed all your comments except the ones related to TilingInterface.td and TilingInterfaceImpl.cpp - since this PR is based on the in-flight PR 88528 that will add these. Perhaps it's best handled there?

The major changes in the latest push in this PR are :-

  1. Add ability to deal with consumer yielding multiple values - previously I was constraining it to be a single yielding consumer. (You may check the new lit tests I've added).
  2. I have updated all the assumptions with necessary checks and returning failure() in case it fails.
  3. The newer algo mentioned here.

Thanks @Abhishek-Varma, let's land this PR after #85528 lands.

Contributor

@nicolasvasilache nicolasvasilache left a comment

First of all, thanks for your contribution @Abhishek-Varma

Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order".

We're now past tech-debt reduction time:

  1. first let's address [mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874,
  2. then let's address the reuse as also pointed out by @ftynse
All of this code looks very similar to the scf.for case. Is it possible to hoist it out into a helper function, potentially with templates?

More generally, the entire 1,2,3,4, ... thing here could be a sequence of calls to appropriately named and documented functions. But not blocking on this one, only hoisting out long blocks of common code.

Thanks for helping move this forward!

At this time, more copy-pasta of code is unwelcome. Let's first reduce the duplication

Contributor

@qedawkins qedawkins left a comment

LGTM, thanks for pushing on this!

ftynse previously requested changes May 21, 2024
Member

@ftynse ftynse left a comment

Please bear with me. This is a rather big and complex patch, I need time to properly re-review this, in particular design comments and tests.

operandToReplace.set(sliceOp.getResult());
} else if (auto sliceOp =
dyn_cast<tensor::ParallelInsertSliceOp>(candidateSliceOp)) {
operandToReplace.set(sliceOp.getSource());
Contributor

This seems off. Here the source of the sliceOp is the tiled operation, but this is used to replace the operand of the cloned consumer which is untiled. This is also potentially related to the extract_slice issue pointed out by @qedawkins earlier. This is what I think this should be doing

  1. Create a new tensor.insert_slice that represent the insertion. In the scf.for case this is a clone of the candidateSliceOp. In the scf.forall case this is created from the tensor.parallel_insert_slice except now it has a result.
  2. Clone the consumer op (already done here). Replace the operand in the consumer along which fusion is happening with the result of the tensor.insert_slice.
  3. Call the method tensor::replaceInsertSliceWithTiledConsumer with the cloned insert slice op and the cloned consumer op
  4. Now replace all tensor.extract_slice uses of the cloned tensor.insert_slice with the source of the cloned tensor.insert_slice. By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.

Contributor Author

Thank you so much @MaheshRavishankar ! This made so much sense and has really helped clean up the issue with tensor.extract_slice!

I've addressed it in the latest push.

Contributor

By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.

Where does this guarantee come from? Looking at the implementation of replaceInsertSliceWithTiledConsumer it is just calling getTiledImplementationFromOperandTile without passing in the source of the insert_slice.

Contributor Author

Adding here an explanation in case it helps :-

State of the IR after step 3 of the algo in the above thread :

         %clone_insert_slice = cloned tensor.insert_slice :
                                    %source<32> into %dest<64> | OFFSET | STRIDES | SIZES
                           
         %tiled_operand = tensor.extract_slice:
                           %clone_insert_slice<64> to <32> | OFFSET | STRIDES | SIZES
          
         %res = tiled_consumer_op %tiled_operand

NOTE:

  1. The OFFSETS | STRIDES | SIZES of the tensor.insert_slice and the tensor.extract_slice are the SAME, so this is a valid intermediate IR yielded by the getTiledImplementationFromOperandTile API, which in fact takes that OFFSETS/STRIDES info. This basically means that the tensor.insert_slice and the tensor.extract_slice are just inverses of each other.
  2. Right before calling the API to generate tiled consumer op, we have
    %res = consumer_op %clone_insert_slice<64>.

Now, perhaps what @MaheshRavishankar meant above for step 4 is that by design the use of %tiled_operand is going to be exactly the same shape as the "source" of %clone_insert_slice. Thus we can simply replace all uses of %tiled_operand with %source.
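
The inverse relationship in note 1 can be sketched outside MLIR with a toy 1-D model. This is only an illustration under stated assumptions: `Tensor`, `insertSlice`, `extractSlice`, and `addOne` are hypothetical stand-ins (a plain `std::vector<int>` with unit stride), not the tensor dialect ops or any MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy 1-D stand-in for a tensor value; not an MLIR type.
using Tensor = std::vector<int>;

// Models a unit-stride tensor.insert_slice: returns a copy of `dest` with
// `source` written at [offset, offset + source.size()).
Tensor insertSlice(const Tensor &source, const Tensor &dest,
                   std::size_t offset) {
  assert(offset + source.size() <= dest.size() && "slice out of bounds");
  Tensor result = dest;
  for (std::size_t i = 0; i < source.size(); ++i)
    result[offset + i] = source[i];
  return result;
}

// Models a unit-stride tensor.extract_slice: returns the `size` elements of
// `input` starting at `offset`.
Tensor extractSlice(const Tensor &input, std::size_t offset,
                    std::size_t size) {
  assert(offset + size <= input.size() && "slice out of bounds");
  return Tensor(input.begin() + offset, input.begin() + offset + size);
}

// Hypothetical elementwise consumer, standing in for the tiled consumer op.
Tensor addOne(const Tensor &in) {
  Tensor out = in;
  for (int &v : out)
    v += 1;
  return out;
}
```

In this model, extracting with the same offset and size that were used for the insertion gives back the source, so a consumer fed the extracted slice computes the same values as a consumer fed the source directly. That is the property step 4 relies on when it replaces the uses of `%tiled_operand` with `%source`.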

Contributor

My main question is where in the tiling interface do we require that the tiled implementation generates this exact %tiled_operand = tensor.extract_slice | OFFSET | STRIDES | SIZES extract slice? The only requirement I see in getTiledImplementation is that it produces a result tile that corresponds to the given offsets/sizes/strides. I don't see any requirements given about how the operands are used.

Contributor Author

Oh okay.
Well, not sure about the doc comments there, but implementation-wise getTiledImplementation does in fact operate on the operands too via makeTiledShapes here.

Contributor

That's for linalg. Other operations can do whatever they want (not trying to be difficult, just trying to make sure we at least clarify any implicit interface API requirements like this).

Contributor

By construction all the slice uses of the tensor.insert_slice are exactly the same shape as the source.

Where does this guarantee come from? Looking at the implementation of replaceInsertSliceWithTiledConsumer it is just calling getTiledImplementationFromOperandTile without passing in the source of the insert_slice.

You are using the slice of the operand to compute the slice of the iteration space that computes that operand. That is only possible if that is a bijection. Then the tile of the operand computed from the iteration space is going to be the same. Things have to be consistent.

Yes, we are indexing on tensor.extract_slice and tensor.insert_slice, which we can maybe eventually switch to an interface that allows different kinds of "slices", but this is the current state right now. If we miss it, then it's a "missed optimization" and not a correctness issue.
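
The operand-tile to iteration-space-tile mapping discussed above can be sketched without MLIR. This is a minimal toy model, assuming the operand is accessed through a projected permutation. `Range`, `mapOperandTileToIterationTile`, and the `operandDimToLoop` encoding of the indexing map are hypothetical names for illustration. Loops that the operand does not index get their full extent, as the comment in the patch describes.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy offset/size pair for one dimension; not the MLIR Range type.
struct Range {
  int64_t offset;
  int64_t size;
};

// `operandDimToLoop[i]` is the loop that indexes operand dimension i. For a
// projected permutation, each loop appears at most once in this list.
std::vector<Range>
mapOperandTileToIterationTile(const std::vector<unsigned> &operandDimToLoop,
                              const std::vector<Range> &iterationDomain,
                              const std::vector<Range> &operandTile) {
  assert(operandDimToLoop.size() == operandTile.size() &&
         "one tile range per operand dimension");
  // Start from the full extent of every loop, so that loops the operand does
  // not access keep their whole range.
  std::vector<Range> iterationTile = iterationDomain;
  for (std::size_t i = 0; i < operandTile.size(); ++i) {
    unsigned loop = operandDimToLoop[i];
    assert(loop < iterationTile.size() && "loop index out of range");
    // The operand dimension and the loop that indexes it share the tile.
    iterationTile[loop] = operandTile[i];
  }
  return iterationTile;
}
```

For a matmul-like iteration space `(m, n, k)` whose left operand is accessed as `(m, k)`, a tile of that operand fixes the `m` and `k` ranges and leaves `n` at its full extent.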

Contributor

Ok, I'll defer to you then, but I think it would be good to document this implicit dependence on extract_slice in the SCF API.

@Abhishek-Varma Abhishek-Varma force-pushed the avarma_fuse_consumer branch from 60a1b7b to 7e9f0b5 Compare May 22, 2024 08:15
Contributor

@cxy-1993 cxy-1993 left a comment

this patch LGTM now

Contributor

@Yun-Fly Yun-Fly left a comment

Now when we come to tile and fuse of producer, the flip of this is all we should need. IIRC this is what this patch initially started with but through many many rounds of reviews has deflected from initial intent (which is now lost in all the comment threads).

As @MaheshRavishankar said, from this view, this PR LGTM overall.

Also, I have the same concerns as you regarding the tiling interface change. However, personally, I would prefer that we open an RFC and another PR to formally resolve it.

@Abhishek-Varma Abhishek-Varma force-pushed the avarma_fuse_consumer branch from 611cb4e to 17037e9 Compare May 30, 2024 05:36
@Abhishek-Varma
Contributor Author

Hi all.

Please let me know if the current state of the PR is okay to land, so that we can then carry on the discussion in an RFC.

The growing discussion thread is making this quite difficult to follow and manage. I sifted through the comments to the best of my understanding and have accordingly reverted the PR to its previous state (with fewer conflicts). I believe we should now be able to unblock landing this PR and move the discussion to an RFC.

CC: @MaheshRavishankar @ftynse @qedawkins @nicolasvasilache @Yun-Fly @cxy-1993

@ftynse
Member

ftynse commented May 30, 2024

Also, I have the same concerns as you regarding the tiling interface change.

Which specific concerns?

Member

@ftynse ftynse left a comment

The growing discussion thread is making this quite difficult to follow and manage. I sifted through the comments to the best of my understanding and have accordingly reverted the PR to its previous state (with fewer conflicts). I believe we should now be able to unblock landing this PR and move the discussion to an RFC.

Please add a note in the documentation of getTiledImplementationFromResultTile that, for most operations, it should be a trivial composition of two other methods. This will significantly simplify the life of any downstream (other than yours) that will want to adopt your change.

And make sure you are comfortable with the level of understanding of the changes you are about to submit....

I am unblocking this. My arguments against the design still stand, but they date back to https://reviews.llvm.org/D127809, where I should have requested an RFC for changing something in lib/Interfaces. I won't make that mistake again, but @Abhishek-Varma is not the one who introduced that, so it's not up to him to resolve it.
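
The "trivial composition of two other methods" mentioned above can be sketched with a toy, non-MLIR model. Assumptions: `ToyElementwiseOp`, `Tile`, and all three method names are hypothetical. They only mirror the shape of the Linalg implementation in this patch, which first maps the operand tile to an iteration-domain tile and then requests the tiled implementation for that tile.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Toy offset/size pair for one dimension; not the MLIR Range type.
struct Range {
  int64_t offset;
  int64_t size;
};
using Tile = std::vector<Range>;

// Hypothetical elementwise op: every operand is indexed exactly like the
// iteration space, so an operand tile is also an iteration-space tile.
struct ToyElementwiseOp {
  unsigned numOperands;

  // Method 1: map an operand tile to an iteration-space tile. Fails (returns
  // std::nullopt) when the operand number is invalid.
  std::optional<Tile>
  getIterationTileFromOperandTile(unsigned operandNumber,
                                  const Tile &operandTile) const {
    if (operandNumber >= numOperands)
      return std::nullopt;
    return operandTile;
  }

  // Method 2: "tile" the op for an iteration-space tile. The toy only
  // reports the tile it would compute.
  Tile getTiledImplementation(const Tile &iterationTile) const {
    for (const Range &r : iterationTile)
      assert(r.size >= 0 && "tile sizes must be non-negative");
    return iterationTile;
  }

  // The composed method: method 1 followed by method 2.
  std::optional<Tile>
  getTiledImplementationFromOperandTile(unsigned operandNumber,
                                        const Tile &operandTile) const {
    std::optional<Tile> iterationTile =
        getIterationTileFromOperandTile(operandNumber, operandTile);
    if (!iterationTile)
      return std::nullopt;
    return getTiledImplementation(*iterationTile);
  }
};
```

An op whose mapping step cannot be expressed this way would have to provide its own implementation of the composed method instead of relying on this pattern.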

@ftynse ftynse self-requested a review May 30, 2024 08:23
@Yun-Fly
Contributor

Yun-Fly commented May 30, 2024

Which specific concerns?

The direction in which to refactor getTiledImplementation and the other related interface methods. For instance, what getTiledImplementation is responsible for: its definition has become a little confusing and muddled since we introduced consumer fusion.

My thinking here is that getTiledImplementation does too much. It should not be emitting IR that computes a slice of each operand. Instead, it should be given such slices.

I guess you mean that the caller prepares the tiled operands in advance and then calls getTiledImplementationGivenOperandTiles, as you mentioned. In this way, it looks more like a wrapper around cloning the operation with the given tiled operands.

I have strong concerns with this. getTiledImplementation is necessarily black boxed. I would not want to assume that you can produce an implementation of tile from the slice. Also it seems backwards to me to compute the slice of the operands, then pass it to the tiling implementation to say produce the tile.

On the other hand, I guess @MaheshRavishankar insists that how to take the slice is also part of how to tile an operation, and so should not be handed over to the caller. BTW, I think we can in any case create a dummy insert_slice, as he suggested, to pass a dummy operand with the original size to the tiling implementation.

%tiled_v = tiledProducer ...

// create a dummy `insert_slice` to align with what the current `getTiledImplementation` expects
%dummy_insert = tensor.insert_slice %tiled_v into %?  [OFFSET] [SIZE]: tensor<tiled_shape> into tensor<original_shape>

// clonedConsumer should use %dummy_insert as its new operand

It is really hard to say which one (or even some other new proposal) is the general solution before we have a much clearer definition of the responsibility of getTiledImplementation, and that is the direction I meant at the beginning.

@ftynse ftynse dismissed their stale review May 30, 2024 09:32

Abstaining from approval

@MaheshRavishankar
Contributor

The growing discussion thread is making this quite difficult to follow and manage. I sifted through the comments to the best of my understanding and have accordingly reverted the PR to its previous state (with fewer conflicts). I believe we should now be able to unblock landing this PR and move the discussion to an RFC.

Please add a note in the documentation of getTiledImplementationFromResultTile that, for most operations, it should be a trivial composition of two other methods. This will significantly simplify the life of any downstream (other than yours) that will want to adopt your change.

A follow up here is indeed what would be good to add to the interface... I will update the documentation for TilingInterface to make things a bit more clear in terms of how the interface is setup.

And make sure you are comfortable with the level of understanding of the changes you are about to submit....

I consider myself the de facto maintainer of the TilingInterface (and its use in the tiling algorithm under the SCF dialect using the interface). I have spent enough time here to understand what is implemented, and I am comfortable with this landing. Please redirect any concerns/issues you see to me, and I will try to address/contextualize them.

I am unblocking this. My arguments against the design still stand, but they date back to https://reviews.llvm.org/D127809, where I should have requested an RFC for changing something in lib/Interfaces. I won't make that mistake again, but @Abhishek-Varma is not the one who introduced that, so it's not up to him to resolve it.

Since this was originally introduced by me, I am happy to chat more. I value your concerns a lot here. We might be approaching this from different angles, but there is no reason why we can't find a meeting point. I am happy to set up some time to discuss your concerns offline and reach a shared understanding.

@MaheshRavishankar
Contributor

First of all, thanks for your contribution @Abhishek-Varma

Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order".

We're now past tech-debt reduction time:

  1. first let's address [mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874,
  2. then let's address the reuse as also pointed out by @ftynse
All of this code looks very similar to the scf.for case. Is it possible to hoist it out into a helper function, potentially with templates?

More generally, the entire 1,2,3,4, ... thing here could be a sequence of calls to appropriately named and documented functions. But not blocking on this one, only hoisting out long blocks of common code.

Thanks for helping move this forward!

At this time, more copy-pasta of code is unwelcome. Let's first reduce the duplication

I think #91878 addresses the concerns in this comment. That is also landing soon. @nicolasvasilache please unblock if you agree.

@nicolasvasilache
Contributor

Will look tomorrow, thanks for flagging!

@nicolasvasilache nicolasvasilache dismissed their stale review May 31, 2024 15:08

Assuming #91878 gets finished and landed, my main concern does not hold anymore.

@MaheshRavishankar
Contributor

CI failures seem unrelated. Merging this change.

@MaheshRavishankar MaheshRavishankar merged commit 2b2ce50 into llvm:main Jun 1, 2024
5 of 7 checks passed