[mlir][Linalg] Deprecate linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes #91878

Merged

Conversation

MaheshRavishankar
Contributor

@MaheshRavishankar MaheshRavishankar commented May 12, 2024

The implementations of these methods are legacy, and they are removed in favor of the scf::tileUsingSCF method. To bring the latter on par with the requirements of the deprecated methods, the tiling now allows one to specify the maximum number of tiles to use instead of specifying the tile sizes. When tiling to scf.forall, this specification is used to generate the num_threads version of the operation.

A slight deviation from the previous implementation is that the deprecated methods always generated the num_threads variant of the scf.forall operation. Now this is driven by the tiling options specified, which reduces the indexing math generated when the tile sizes are specified.
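As a plain-C++ sketch (an illustration only, with made-up helper names — not MLIR code) of the two inter-tile loop forms for a single loop dimension, showing where the extra indexing math in the numThreads form comes from:

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// tileSizes form: the induction variable steps from lb to ub by tileSize,
// so the induction variable *is* the tile offset; no extra math per step.
std::vector<long> tileOffsetsFromTileSize(long lb, long ub, long tileSize) {
  std::vector<long> offsets;
  for (long iv = lb; iv < ub; iv += tileSize)
    offsets.push_back(iv);
  return offsets;
}

// numThreads form: the induction variable is a thread id in [0, nt), and
// every iteration multiplies the id by the tile size to recover the offset
// (this is the additional indexing math the tileSizes form avoids).
std::vector<long> tileOffsetsFromNumThreads(long lb, long ub, long nt) {
  long tileSize = (ub - lb + nt - 1) / nt; // ceilDiv of the trip count
  std::vector<long> offsets;
  for (long tid = 0; tid < nt; ++tid)
    offsets.push_back(lb + tid * tileSize);
  return offsets;
}
```

Both forms visit the same tile offsets when the thread count is derived from the tile size; only the per-iteration arithmetic differs.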

Moving from linalg::tileToForallOp to scf::tileUsingSCF

OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> numThreads;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result = linalg::tileToForallOp(b, op, numThreads, mapping);

can be replaced by

scf::SCFTilingOptions options;
options.setNumThreads(numThreads);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setMapping(mapping.getValue()); /* note the difference: setMapping takes an ArrayRef<Attribute> */
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);

This generates the numThreads version of the scf.forall for the inter-tile loops, i.e.

... = scf.forall (%arg0, %arg1) in (%nt0, %nt1) shared_outs(...)

Moving from linalg::tileToForallOpUsingTileSizes to scf::tileUsingSCF

OpBuilder b;
TilingInterface op;
ArrayRef<OpFoldResult> tileSizes;
ArrayAttr mapping;
FailureOr<ForallTilingResult> result = linalg::tileToForallOpUsingTileSizes(b, op, tileSizes, mapping);

can be replaced by

scf::SCFTilingOptions options;
options.setTileSizes(tileSizes);
options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
options.setMapping(mapping.getValue()); /* note the difference: setMapping takes an ArrayRef<Attribute> */
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(b, op, options);

Also note that linalg::tileToForallOpUsingTileSizes would effectively call linalg::tileToForallOp, computing numThreads from the op and tileSizes, and would generate the numThreads version of the scf.forall. That is no longer the case. Instead, this now directly generates the tileSizes version of the scf.forall op:

... = scf.forall (%arg0, %arg1) = (%lb0, %lb1) to (%ub0, %ub1) step (%step0, %step1) shared_outs(...)

If you actually want the numThreads version, it is up to the caller to compute numThreads and call options.setNumThreads instead of options.setTileSizes. Note that there is a slight difference between the numThreads version and the tileSizes version: the numThreads version requires an additional affine.max on the computed tile size to ensure the tile sizes are non-negative. When the numThreads version is derived from tile sizes, this affine.max is not needed, since by construction the tile sizes are non-negative. In the previous implementation, the numThreads version generated by linalg::tileToForallOpUsingTileSizes would avoid generating the affine.max operation. To get to the same state, downstream users will additionally have to normalize the scf.forall operation.
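To make the affine.max point concrete, here is a plain-C++ sketch (an illustration with made-up names; the actual lowering emits affine.min/affine.max operations, not scalar arithmetic) of the per-thread tile-size computation in the numThreads form:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Per-thread tile size in the numThreads form: the trailing threads can see
// a negative remainder when numThreads overshoots the iteration count, so a
// clamp to zero (the affine.max discussed above) is required in general.
int64_t perThreadTileSize(int64_t tid, int64_t tileSize, int64_t niters) {
  int64_t offset = tid * tileSize;
  return std::max<int64_t>(std::min(tileSize, niters - offset), int64_t(0));
}
```

When numThreads is derived as ceilDiv(niters, tileSize), every offset is strictly below niters, so min(tileSize, niters - offset) is already non-negative and the clamp can be dropped — which is what the deprecated method exploited.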

Changes to transform.structured.tile_using_forall

The transform dialect op that called into linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes has been modified to call scf::tileUsingSCF. The transform dialect op always generates the numThreads version of the scf.forall op, so when tile_sizes are specified, the tile_sizes version of the scf.forall is first generated by scf::tileUsingSCF and then normalized to get back to the numThreads form. So there is no functional change to transform.structured.tile_using_forall: it always generates the numThreads version of the scf.forall op, as it did before this change.

@MaheshRavishankar MaheshRavishankar force-pushed the deprecate_tile_to_forall branch from ff45ad2 to 8d74612 Compare May 12, 2024 05:45
@MaheshRavishankar MaheshRavishankar changed the title [mlir][SCF] Allow tiling by specifying maximum number of tiles. [mlir][SCF] Deprecate linalg::tileToForallOp and linalg::tileToForallOpUsingTileSizes May 12, 2024
Contributor

@nicolasvasilache nicolasvasilache left a comment


Thanks for making progress on this!

I see this seems to single out tiling by tileSizes, what is the plan for numThreads?
I want to make sure this is not an attempt to remove it because it would be a non-starter: it connects to further transforms for which the information is important.

@MaheshRavishankar
Contributor Author

Thanks for making progress on this!

I see this seems to single out tiling by tileSizes, what is the plan for numThreads?
I want to make sure this is not an attempt to remove it because it would be a non-starter: it connects to further transforms for which the information is important.

It's not being removed. The num-tiles option added here translates to num threads when using scf.forall (as mentioned in the description). I haven't updated the lit tests yet because I wanted to gather some initial feedback before going and fixing them (they all run and I have visually verified they look OK, but some existing lit tests have some strange semantics, like zero-sized tensor slices).

I'd like to understand more about what is load-bearing in the num-threads approach as opposed to just using tile sizes. They seem interchangeable to me. I have a mental model of how you can use tile sizes and make it work all the way. From what I have seen, the num-threads path creates more indexing math and some weird corner cases (there is a lit test producing zero-sized slices, for example, which is particularly scary).
Can you provide some information about what transformations need the num-threads information and cannot derive it when needed (e.g. by just getting the number of iterations of the loop)?

@nicolasvasilache
Contributor

nicolasvasilache commented May 17, 2024

Can you provide some information about what transformations need the num threads information that cannot be derived when needed (like just getting the number of iterations of the loop)

Not immediately as this is out of my immediate working memory but TL;DR:

  • logic to re-infer the number of threads from the IR, in particular for 1-trip-count loops, is extremely brittle as soon as things don't divide evenly. IIRC it may be quite brittle even in cases where things do divide, but don't quote me on this
  • tile dynamically sized ops by dynamic, adaptive tile sizes and things become even more hairy
  • rediscovering a static number of threads after we threw the information away is not worth it

@MaheshRavishankar MaheshRavishankar force-pushed the deprecate_tile_to_forall branch from 8d74612 to 0bb323b Compare May 22, 2024 06:22
@MaheshRavishankar MaheshRavishankar marked this pull request as ready for review May 22, 2024 06:58
@llvmbot
Member

llvmbot commented May 22, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: None (MaheshRavishankar)

Changes

The implementations of these methods are legacy, and they are removed in favor of the scf::tileUsingSCF method. To bring the latter on par with the requirements of the deprecated methods, the tiling now allows one to specify the maximum number of tiles to use instead of specifying the tile sizes. When tiling to scf.forall, this specification is used to generate the num_threads version of the operation.

A slight deviation from the previous implementation is that the deprecated methods always generated the num_threads variant of the scf.forall operation. Now this is driven by the tiling options specified, which reduces the indexing math generated when the tile sizes are specified.


Patch is 59.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/91878.diff

11 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h (+5-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (-24)
  • (modified) mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h (+28-7)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+35-13)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp (-182)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+244-75)
  • (modified) mlir/test/Dialect/Linalg/tile-to-forall.mlir (+25-28)
  • (modified) mlir/test/Dialect/Linalg/transform-op-tile.mlir (+11-18)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir (+5-5)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir (+25-25)
  • (modified) mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp (+1-5)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
index 3af642752724c..db25c9b241734 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h
@@ -30,6 +30,10 @@ class GenericOp;
 class LinalgOp;
 } // namespace linalg
 
+namespace scf {
+struct SCFTilingResult;
+} // namespace scf
+
 namespace tensor {
 class InsertSliceOp;
 class PackOp;
@@ -60,7 +64,7 @@ tileToForallOpImpl(RewriterBase &rewriter, transform::TransformState &state,
                    ArrayRef<OpFoldResult> mixedNumThreads,
                    ArrayRef<OpFoldResult> mixedTileSizes,
                    std::optional<ArrayAttr> mapping,
-                   linalg::ForallTilingResult &tilingResult);
+                   scf::SCFTilingResult &tilingResult);
 
 } // namespace transform
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index f77c19ed0fcce..73fd6a469d0e7 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -846,30 +846,6 @@ FailureOr<StaticMultiSizeSpecification>
 computeStaticMultiTileSizes(LinalgOp op, unsigned dimension, int64_t targetSize,
                             int64_t divisor);
 
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`, applying
-/// tiling by `numThreads`.
-/// If non-empty, the `mapping` is added as an attribute to the
-/// resulting `scf.forall`.
-/// Zero tile sizes indicate that the dimension is not tiled, and can be
-/// thought of as tiling by the full size of data. It is the user's
-/// responsibility to ensure that `numThreads` is a valid tiling specification
-/// (i.e. that only tiles parallel dimensions, e.g. in the Linalg case).
-struct ForallTilingResult {
-  Operation *tileOp;
-  Operation *tiledOp;
-};
-FailureOr<ForallTilingResult> tileToForallOp(RewriterBase &builder,
-                                             TilingInterface op,
-                                             ArrayRef<OpFoldResult> numThreads,
-                                             std::optional<ArrayAttr> mapping);
-
-/// Same as `tileToForallOp`, but calculate the number of threads
-/// required using the given tileSizes.
-FailureOr<ForallTilingResult>
-tileToForallOpUsingTileSizes(RewriterBase &builder, TilingInterface op,
-                             ArrayRef<OpFoldResult> tileSizes,
-                             std::optional<ArrayAttr> mapping);
-
 /// Transformation information returned after reduction tiling.
 struct ForallReductionTilingResult {
   /// The partial reduction tiled op generated.
diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
index 965ef9e203be2..f6858c3cddf46 100644
--- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
+++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h
@@ -31,9 +31,11 @@ using SCFTileSizeComputationFunction =
 
 /// Options to use to control tiling.
 struct SCFTilingOptions {
-  /// Computation function that returns the tile sizes for each operation.
-  /// Delayed construction of constant tile sizes should occur to interoperate
-  /// with folding.
+  /// Computation function that returns the tile sizes to use for each loop.
+  /// Returning a tile size of zero implies no tiling for that loop. If the
+  /// size of the returned vector is smaller than the number of loops, the inner
+  /// loops are not tiled. If the size of the returned vector is larger, then
+  /// the vector is truncated to number of loops.
   SCFTileSizeComputationFunction tileSizeComputationFunction = nullptr;
 
   SCFTilingOptions &
@@ -44,7 +46,27 @@ struct SCFTilingOptions {
   /// Convenience function to set the `tileSizeComputationFunction` to a
   /// function that computes tile sizes at the point they are needed. Allows
   /// proper interaction with folding.
-  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> ts);
+  SCFTilingOptions &setTileSizes(ArrayRef<OpFoldResult> tileSizes);
+
+  /// Computation function that returns the number of threads to use for
+  /// each loop. Returning a num threads of zero implies no tiling for that
+  /// loop. If the size of the returned vector is smaller than the number of
+  /// loops, the inner loops are not tiled. If the size of the returned vector
+  /// is larger, then the vector is truncated to number of loops. Note: This
+  /// option is only supported with loopType set to `LoopType::ForallOp`. If the
+  /// tile size function is not specified while the num threads computation is,
+  /// then the tile size is determined automatically to map at most one tile per
+  /// thread.
+  SCFTileSizeComputationFunction numThreadsComputationFunction = nullptr;
+
+  SCFTilingOptions &
+  setNumThreadsComputationFunction(SCFTileSizeComputationFunction fun) {
+    numThreadsComputationFunction = std::move(fun);
+    return *this;
+  }
+  /// Convenience function to set the `tileSizeComputationFunction` to a
+  /// function that computes tile sizes at the point they are needed.
+  SCFTilingOptions &setNumThreads(ArrayRef<OpFoldResult> numThreads);
 
   /// The interchange vector to reorder the tiled loops.
   SmallVector<int64_t> interchangeVector = {};
@@ -66,9 +88,8 @@ struct SCFTilingOptions {
   /// when using loop constructs that dont support such a mapping (like
   /// `scf.for`)
   SmallVector<Attribute> mappingVector = {};
-  SCFTilingOptions &setMapping(ArrayRef<DeviceMappingAttrInterface> mapping) {
-    mappingVector = llvm::map_to_vector(
-        mapping, [](auto attr) -> Attribute { return attr; });
+  SCFTilingOptions &setMapping(ArrayRef<Attribute> mapping) {
+    mappingVector = llvm::to_vector(mapping);
     return *this;
   }
 };
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 13582a140a965..e9bcbee5c1a8a 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2917,7 +2917,7 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     TransformOpInterface transformOp, Operation *target,
     ArrayRef<OpFoldResult> mixedNumThreads,
     ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
-    linalg::ForallTilingResult &tilingResult) {
+    scf::SCFTilingResult &tilingResult) {
   // Transform all targets one by one.
   auto tileableOp = dyn_cast<TilingInterface>(target);
   if (!tileableOp) {
@@ -2928,18 +2928,39 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl(
     return diag;
   }
   rewriter.setInsertionPoint(tileableOp);
-  FailureOr<linalg::ForallTilingResult> maybeTilingResult = failure();
+  scf::SCFTilingOptions options;
+  options.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp);
   if (!mixedNumThreads.empty()) {
-    maybeTilingResult =
-        linalg::tileToForallOp(rewriter, tileableOp, mixedNumThreads, mapping);
+    options.setNumThreads(mixedNumThreads);
   } else {
-    maybeTilingResult = linalg::tileToForallOpUsingTileSizes(
-        rewriter, tileableOp, mixedTileSizes, mapping);
+    SmallVector<Range> loopRanges = tileableOp.getIterationDomain(rewriter);
+    unsigned nLoops = loopRanges.size();
+    SmallVector<OpFoldResult> numThreads;
+    numThreads.reserve(nLoops);
+    AffineExpr s0, s1;
+    bindSymbols(rewriter.getContext(), s0, s1);
+    AffineExpr divExpr = s0.ceilDiv(s1);
+    for (int i = 0, e = std::min(mixedTileSizes.size(), loopRanges.size());
+         i < e; ++i) {
+      OpFoldResult numTiles = mixedTileSizes[i];
+      if (!isConstantIntValue(numTiles, 0))
+        numTiles = affine::makeComposedFoldedAffineApply(
+            rewriter, tileableOp.getLoc(), divExpr,
+            {loopRanges[i].size, numTiles});
+      numThreads.push_back(numTiles);
+    }
+    options.setNumThreads(numThreads);
+    options.setTileSizes(mixedTileSizes);
+  }
+  if (mapping) {
+    options.setMapping(mapping.value().getValue());
   }
+  FailureOr<scf::SCFTilingResult> maybeTilingResult =
+      scf::tileUsingSCF(rewriter, tileableOp, options);
 
   if (failed(maybeTilingResult))
     return transformOp.emitDefaultSilenceableFailure(tileableOp);
-  rewriter.replaceOp(tileableOp, maybeTilingResult->tileOp->getResults());
+  rewriter.replaceOp(tileableOp, maybeTilingResult->replacements);
 
   tilingResult = *maybeTilingResult;
   return DiagnosedSilenceableFailure::success();
@@ -2975,14 +2996,14 @@ DiagnosedSilenceableFailure transform::TileUsingForallOp::apply(
     return status;
 
   for (Operation *target : state.getPayloadOps(getTarget())) {
-    linalg::ForallTilingResult tilingResult;
+    scf::SCFTilingResult tilingResult;
     DiagnosedSilenceableFailure diag = tileToForallOpImpl(
         rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes,
         getMapping(), tilingResult);
     if (!diag.succeeded())
       return diag;
-    tileOps.push_back(tilingResult.tileOp);
-    tiledOps.push_back(tilingResult.tiledOp);
+    tileOps.push_back(tilingResult.loops.front());
+    tiledOps.append(tilingResult.tiledOps);
   }
 
   transformResults.set(cast<OpResult>(getForallOp()), tileOps);
@@ -3460,7 +3481,7 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
 
   // OpBuilder only used to compute attributes.
   OpBuilder b(getContext());
-  linalg::ForallTilingResult tilingResult;
+  scf::SCFTilingResult tilingResult;
   DiagnosedSilenceableFailure diag = tileToForallOpImpl(
       /*rewriter=*/rewriter,
       /*state=*/state,
@@ -3473,8 +3494,9 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   if (!diag.succeeded())
     return diag;
 
-  results.push_back(tilingResult.tileOp);
-  results.push_back(tilingResult.tiledOp);
+  results.push_back(tilingResult.loops.front());
+  for (auto op : tilingResult.tiledOps)
+    results.push_back(op);
   return DiagnosedSilenceableFailure::success();
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
index df4089d61bfd7..30031a443f7d8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp
@@ -304,188 +304,6 @@ static void calculateTileOffsetsAndSizes(
   }
 }
 
-/// Returns a vector of bools representing if, for each axis, `op` can be tiled
-/// without incurring in a race condition and thus it is thread-safe to do the
-/// tiling. This is checked by iterating over numThreads and ensuring that the
-/// corresponding iterator type is "parallel". If it is not, then we know that
-/// such dimension is unsafe to tile.
-SmallVector<bool> safeToTileToForall(mlir::MLIRContext *ctx, LinalgOp linalgOp,
-                                     ArrayRef<OpFoldResult> numThreads) {
-  auto iterators = linalgOp.getIteratorTypesArray();
-  SmallVector<bool> safeToTile(numThreads.size(), true);
-
-  for (unsigned i = 0, e = numThreads.size(); i != e; i++) {
-    if (auto attr = llvm::dyn_cast_if_present<Attribute>(numThreads[i])) {
-      if (cast<IntegerAttr>(attr).getValue().getSExtValue() > 1) {
-        safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-      }
-    } else {
-      safeToTile[i] = iterators[i] == utils::IteratorType::parallel;
-    }
-  }
-  return safeToTile;
-}
-
-/// Rewrite a TilingInterface `op` to a tiled `scf.forall`. The
-/// tiling is specified by the number of tiles/threads `numThreads` and the
-/// optional nominal tile size `nominalTileSizes`. If `nominalTilSizes` is
-/// not specified, then  it is derived from `numThreads` as `ceilDiv(dimSize[i],
-/// numThreads[i])`. If non-empty, the `mapping` is added as an
-/// attribute to the resulting `scf.forall`. A zero tile sizes indicate
-/// that the dimension is not tiled, and can be thought of as tiling by the full
-/// size of data.
-/// It is the user's responsibility to ensure that `numThreads` is a valid
-/// tiling specification (i.e. that only tiles parallel dimensions, e.g. in the
-/// Linalg case). If the dimension is not parallelizable, a warning is issued to
-/// notify the user that the generated code is not safe to parallelize. If
-/// `omitTileOffsetBoundsCheck` is true, then the function will assume that
-/// `tileSize[i] * (numThread[i] -1) <= dimSize[i]` holds.
-static FailureOr<ForallTilingResult> tileToForallOpImpl(
-    RewriterBase &b, TilingInterface op, ArrayRef<OpFoldResult> numThreads,
-    std::optional<ArrayRef<OpFoldResult>> nominalTileSizes,
-    std::optional<ArrayAttr> mapping, bool omitTileOffsetBoundsCheck) {
-  Location loc = op->getLoc();
-  OpBuilder::InsertionGuard g(b);
-
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  if (loopRanges.empty())
-    return op->emitOpError("expected non-empty loop ranges");
-  auto hasStrideOne = [](Range r) { return !isConstantIntValue(r.stride, 1); };
-  if (llvm::any_of(loopRanges, hasStrideOne))
-    return op->emitOpError("only stride-1 supported atm");
-
-  // Gather destination tensors.
-  SmallVector<Value> dest;
-  if (failed(tensor::getOrCreateDestinations(b, loc, op, dest)))
-    return op->emitOpError("failed to get destination tensors");
-
-  SmallVector<OpFoldResult> nonZeroNumThreads =
-      llvm::to_vector(llvm::make_filter_range(numThreads, [](OpFoldResult ofr) {
-        return !isConstantIntValue(ofr, 0);
-      }));
-  SmallVector<Value> materializedNonZeroNumThreads =
-      llvm::to_vector(llvm::map_range(nonZeroNumThreads, [&](OpFoldResult ofr) {
-        return getValueOrCreateConstantIndexOp(b, loc, ofr);
-      }));
-
-  LinalgOp linalgOp = dyn_cast<LinalgOp>(op.getOperation());
-  if (linalgOp) {
-    // Check if tiling is thread safe and print a warning if not.
-    SmallVector<bool> tilingSafety =
-        safeToTileToForall(b.getContext(), linalgOp, numThreads);
-    for (size_t i = 0; i < tilingSafety.size(); i++)
-      if (!tilingSafety[i])
-        op.emitWarning() << "tiling is not thread safe at axis #" << i;
-  }
-
-  // 1. Create the ForallOp. We don't use the lambda body-builder
-  // version because we require the use of RewriterBase in the body, so we
-  // manually move the insertion point to the body below.
-  scf::ForallOp forallOp = b.create<scf::ForallOp>(
-      loc, getAsOpFoldResult((materializedNonZeroNumThreads)), dest, mapping);
-
-  // 2. Fill out the ForallOp body.
-  SmallVector<OpFoldResult> tiledOffsets, tiledSizes;
-  calculateTileOffsetsAndSizes(b, loc, forallOp, numThreads, loopRanges,
-                               omitTileOffsetBoundsCheck, nominalTileSizes,
-                               tiledOffsets, tiledSizes);
-
-  // 3. Clone the tileable op and update its destination operands to use the
-  // output bbArgs of the ForallOp.
-  ArrayRef<BlockArgument> destBbArgs = forallOp.getRegionIterArgs();
-  Operation *tiledOp = nullptr;
-  SmallVector<Value> tiledValues;
-  {
-    // 3.a. RAII guard, inserting within forallOp, before terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-    Operation *clonedOp = b.clone(*op.getOperation());
-    auto destinationStyleOp = dyn_cast<DestinationStyleOpInterface>(clonedOp);
-    if (destinationStyleOp) {
-      for (OpOperand &outOperand : destinationStyleOp.getDpsInitsMutable()) {
-        // Swap tensor inits with the corresponding block argument of the
-        // scf.forall op. Memref inits remain as is.
-        if (isa<TensorType>(outOperand.get().getType())) {
-          auto *it = llvm::find(dest, outOperand.get());
-          assert(it != dest.end() && "could not find destination tensor");
-          unsigned destNum = std::distance(dest.begin(), it);
-          outOperand.set(destBbArgs[destNum]);
-        }
-      }
-    }
-
-    // 4. Tile the cloned op and delete the clone.
-    FailureOr<TilingResult> tilingResult =
-        cast<TilingInterface>(clonedOp).getTiledImplementation(b, tiledOffsets,
-                                                               tiledSizes);
-    if (failed(tilingResult))
-      return clonedOp->emitError("Failed to tile op: ");
-    if (tilingResult->tiledOps.size() != 1) {
-      return clonedOp->emitError("expected a single produced tiled op, got ")
-             << tilingResult->tiledOps.size();
-    }
-
-    b.eraseOp(clonedOp);
-    tiledOp = tilingResult->tiledOps.front();
-    tiledValues = tilingResult->tiledValues;
-  }
-
-  // 5. Parallel insert back into the result tensor.
-  for (auto it : llvm::zip(llvm::seq(unsigned(0), unsigned(dest.size())),
-                           tiledValues, destBbArgs)) {
-    // 5.a. Partial subset information is inserted just before the terminator.
-    OpBuilder::InsertionGuard g(b);
-    b.setInsertionPoint(forallOp.getTerminator());
-
-    SmallVector<OpFoldResult> resultOffsets, resultSizes;
-    if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
-                                        tiledSizes, resultOffsets,
-                                        resultSizes)))
-      return op->emitOpError("output offsets couldn't be calculated");
-    SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
-
-    // 5.b. Parallel insertions are inserted at the end of the combining
-    // terminator.
-    b.setInsertionPointToEnd(forallOp.getTerminator().getBody());
-    b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
-                                            std::get<2>(it), resultOffsets,
-                                            resultSizes, strides);
-  }
-  return ForallTilingResult{forallOp, tiledOp};
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOp(RewriterBase &b, TilingInterface op,
-                       ArrayRef<OpFoldResult> numThreads,
-                       std::optional<ArrayAttr> mapping) {
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/std::nullopt, mapping,
-                            /*omitTileOffsetBoundsCheck=*/false);
-}
-
-FailureOr<ForallTilingResult>
-linalg::tileToForallOpUsingTileSizes(RewriterBase &b, TilingInterface op,
-                                     ArrayRef<OpFoldResult> tileSizes,
-                                     std::optional<ArrayAttr> mapping) {
-  SmallVector<Range> loopRanges = op.getIterationDomain(b);
-  unsigned nLoops = loopRanges.size();
-  SmallVector<OpFoldResult> numThreads;
-  numThreads.reserve(nLoops);
-  AffineExpr s0, s1;
-  bindSymbols(b.getContext(), s0, s1);
-  AffineExpr divExpr = s0.ceilDiv(s1);
-  for (const auto &it : llvm::zip(tileSizes, loopRanges)) {
-    OpFoldResult numTiles = std::get<0>(it);
-    if (!isConstantIntValue(numTiles, 0))
-      numTiles = makeComposedFoldedAffineApply(
-          b, op.getLoc(), divExpr, {std::get<1>(it).size, std::get<0>(it)});
-    numThreads.push_back(numTiles);
-  }
-  return tileToForallOpImpl(b, op, numThreads,
-                            /*nominalTileSizes=*/tileSizes, mapping,
-                            /*omitTileOffsetBoundsCheck=*/true);
-}
-
 template <typename LoopTy>
 static FailureOr<TiledLinalgOp>
 tileLinalgOpImpl(RewriterBase &b, LinalgOp op, ArrayRef<OpFoldResult> tileSizes,
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1a84a59ddb69d..c0878c42da3b1 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -41,6 +41,16 @@ scf::SCFTilingOptions::setTileSizes(ArrayRef<OpFoldResult> ts) {
   return *this;
 }
 
+scf::SCFTilingOptions &
+scf::SCFTilingOptions::setNumThreads(ArrayRef<OpFoldResult> nt) {
+  assert(!numThreadsComputationFunction && "num tiles already set");
+  auto numThreads = llvm::to_vector(nt);
+  numThreadsComputationFunction = [numThreads](OpBuilder &b, Operation *op) {
+    return numThreads;
+  };
+  return *this;
+}
+
 /// Helper method to adjust the interchange vector to match the iteration
 /// domain.
 static SmallVector<int64_t>
@@ -60,7 +70,117 @@ fillInterchangeVector(ArrayRef<int64_t> interchangeVector,
 // tileUsingSCF implementation.
 //===...
[truncated]

@MaheshRavishankar
Contributor Author

Can you provide some information about what transformations need the num threads information that cannot be derived when needed (like just getting the number of iterations of the loop)

Not immediately as this is out of my immediate working memory but TL;DR:

  • logic to re-infer the number of threads from the IR, in particular for 1-trip-count loops, is extremely brittle as soon as things don't divide evenly. IIRC it may be quite brittle even in cases where things do divide, but don't quote me on this
  • tile dynamically sized ops by dynamic, adaptive tile sizes and things become even more hairy
  • rediscovering a static number of threads after we threw the information away is not worth it

@nicolasvasilache I think I got a handle on what was done, and this was a good pointer. Most of the tests pass; there are only a few corner-case lit tests that are failing for me locally (CI says something else, I need to triage that), but this is ready for review.

Contributor

@qedawkins qedawkins left a comment


LGTM modulo some nits. I don't have historical context or the context of other users of these APIs so please wait for another approval from someone who does.

Contributor

@nicolasvasilache nicolasvasilache left a comment


I'll try to make a pass over the C++ before EoW

@MaheshRavishankar
Contributor Author

@nicolasvasilache and @qedawkins I'll look into the CHECK-LABEL issue, but I have always found that CHECK-LABEL with $..., which keeps names "global", is extremely hard to debug. So I just give up on CHECK-LABEL and the use of $.... As long as the tests pass, can we be a bit flexible on CHECK vs CHECK-LABEL?

@MaheshRavishankar MaheshRavishankar force-pushed the deprecate_tile_to_forall branch from 7b26f1e to 8b05a95 Compare May 24, 2024 04:41
@stellaraccident
Contributor

Thanks for doing this, Mahesh.

@MaheshRavishankar MaheshRavishankar force-pushed the deprecate_tile_to_forall branch from 0d08ebe to 3de4212 Compare May 24, 2024 05:20
Contributor

@nicolasvasilache nicolasvasilache left a comment

First round of comment, thanks for pushing this forward!

// Compute the tile sizes from the iteration domain and number
// of tiles as follows
// - niters = ceilDiv(ub - lb, step)
// - tileSize = ceilDiv(niters, numThreads)
Contributor

A few things have shifted around from the previous implementations:

  1. we used to derive the iteration domain directly from op.iterationDomain; here we separate the two at the API boundary: any particular reason, or can we just avoid the extra ArrayRef<Range> iterationDomain argument to getTileSizes?
  2. the tile size computation logic used to be centralized in a single place and now it is split across two places. Can you explain / document the need for this forall-specific (I think?) precomputation?
  3. the index logic seems to have changed here, as you have more divisions followed by more multiplications than I seem to see in calculateTileOffsetsAndSizes. Is the expectation that they cancel out, or are we hitting previously untested portions of the code?

Contributor Author

Re 1: The iterationDomain computation could insert new operations. I want to avoid calling it multiple times and creating the same operations repeatedly.

Re 2: I don't follow, because the tile size computation is centralized now. It is computed in getTileSizes. Could you clarify which part you think is split out?

Re 3: I don't think there are any extra divisions. Most of the lit tests stay the same. The only ones that changed are the dynamic-shape tests, which I think were originally wrong (you can't really avoid generating the max operation).

Contributor

My general thinking in such large PRs is to look out for potential causes of differences.
The ideal scenario is when tests don't change(i.e. the PR is NFC) and the APIs don't change too drastically.
So what I am evaluating here is whether this is good to go despite changes to test files, or whether something material was changed.

Re 1: we are trading off increase in API complexity for fewer duplicate ops. I think this is a good change and not materially different, in principle this LGTM. Side note, one way to effect such improvements without significant API changes is to pass an extra struct to memoize the quantities you already created and want to reuse. This minimizes changes and reduces the probability of error while getting you the same effect.

Beyond the API complexity increase / reduction of redundant IR, there may be a catch: I see this PR introduces --cse in one of the test. My immediate gut feeling is we are introducing more IR in other places. If --cse can be dropped then I think this first point is resolved.

Re 2: calculateTileOffsetsAndSizes has related logic in the same code location. Now the logic seems split across getTileOffsetAndSizes and getTileSizes/getUserTileSizesAndNumThreads. The rename to getUserTileSizesAndNumThreads helps. So it seems we are doing a getUserTileSizesAndNumThreads before any loop is generated, and later a refinement in getTileOffsetAndSizes within the body (passed via a lambda). The part that puts all this within the same function is what was confusing, I believe. Can we hoist out the lambdas as helper functions and reduce the complexity here? It was really not clear to me that we had an inline lambda to wrangle with: my recollections from the refactoring 6 months ago do not include this.

Re 3: this should not be in this PR IMO. This is a material change that looks important by itself (bug fixing), I don't think we should have it in the middle of such a big PR with so much code complexity.


if (useNumThreads) {
// Prune the zero numthreads.
SmallVector<OpFoldResult> nonZeroNumThreads;
Contributor

Something with filter_range should let us have a 1-liner here.

Still feels like common indexing utils should exist / be added for this: this is similar to lbs = 0, steps =1, ubs = numThreads.

Contributor Author

What utility are you looking at? The difference here is the way the scf.forall op is set up.
I have been thinking about how to bridge the subtle semantic difference between the tile-sizes version of scf.forall and the num-threads version of scf.forall. I think I have a solution, but that is probably better described by an RFC. I'll send it out shortly.

@MaheshRavishankar
Contributor Author

@nicolasvasilache ready for another round of reviews

residualTileSizeExpr = s2 - (d0 + d1 * s0 * s1);
bindSymbols(rewriter.getContext(), s0, s1);
offsetExpr = d0 + d1 * s0;
residualTileSizeExpr = s1 - (d0 + d1 * s0);
Contributor

Yes, it makes sense to keep such extensions for a future PR if necessary.
Out of curiosity, do you already encounter cases with strides != 1 in the wild?

Contributor Author

There was some fly-by comment that we need to support non-unit strides in tiling, but I don't have full context on it. I'll keep an eye out for this.

@MaheshRavishankar
Contributor Author

My general thinking in such large PRs is to look out for potential causes of differences.
The ideal scenario is when tests don't change(i.e. the PR is NFC) and the APIs don't change too drastically.
So what I am evaluating here is whether this is good to go despite changes to test files of whether something material was changed.

The change to the test file is minimal. I think the only change (and please do verify) is the case where I think the existing lowering was wrong. Essentially, in a purely dynamic case you need to generate the affine.max. I don't know how the existing path avoided it, but it looks wrong. So I fixed the test. Please do correct my understanding if I am off here.

Beyond the API complexity increase / reduction of redundant IR, there may be a catch: I see this PR introduces --cse in one of the test. My immediate gut feeling is we are introducing more IR in other places. If --cse can be dropped then I think this first point is resolved.

CSE was added just to account for constants. The constants are really annoying with %[[C0]] and %[[C0_0]] and %[[C0_1]] and their ordering.

Re 2: calculateTileOffsetsAndSizes has related logic in the same code location. Now the logic seems split across getTileOffsetAndSizes and getTileSizes/getUserTileSizesAndNumThreads. The rename to getUserTileSizesAndNumThreads helps. So it seems we are doing a getUserTileSizesAndNumThreads before any loop is generated and later a refinemenent of getTileOffsetAndSizes within the body (passed via. lambda). The part that is putting all this within the same function is what was confusing I believe. Can we hoist out the lambdas as helper functions and reduce the complexity here? It was really not clear to me that we had an inline lambda to wrangle with: my recollections from the refactoring 6 months ago do not include this.

OK, I'll move the inline lambda out, but the lambda here is essentially playing the same role as lambdas do for creating the loop body when used with rewriter.create<scf::For> or rewriter.create<scf::ForAll>. It is the loop body of the innermost tiled loop. I actually wanted to add this as a method on LoopLikeOpInterface, but there were concerns that this is really just needed for tiling and does not belong in the interface. I don't want to add to the diff here though. I'll keep it as an inline lambda for now and change it later. I am not a fan of inline lambdas either, but in this case it seemed warranted; happy to change it as a follow-up.

Re 3: this should not be in this PR IMO. This is a material change that looks important by itself (bug fixing), I don't think we should have it in the middle of such a big PR with so much code complexity.

I don't know what the original implementation was doing, or where it was going wrong. I didn't try to fix the bug, but the refactoring showed a change in the lit test, and then I verified that the current PR is doing the right thing and the old path is wrong. I don't see much value in triaging a bug on a path that is being deleted.

@nicolasvasilache
Contributor

nicolasvasilache commented Jun 1, 2024

I don't know what the original implementation was doing, or where it was going wrong. I didn't try to fix the bug, but the refactoring showed a change in the lit test, and then I verified that the current PR is doing the right thing and the old path is wrong. I don't see much value in triaging a bug on a path that is being deleted.

Help me understand this better: you started refactoring things, then found a bug by some other means, changed behavior, now it is correct in your opinion but cannot pinpoint what was wrong.

Please elaborate on how the bug showed up: was it some IREE execution test? Is this something that did not show up before this refactoring, or was it always there? Or is it just that "the IR changed" when the refactoring happened?

In any case, the MO here can be improved.
I agree there is little to no value in fixing the bug before the refactoring.
My opinion is there is strong negative value in conflating a material change we don't understand with a deep refactoring.

Your path forward here is to land the refactoring without changes to the affine.max part and then to address the bug separately after we understand what is going wrong and have a proper test for it.

@MaheshRavishankar
Contributor Author

I don't know what the original implementation was doing, or where it was going wrong. I didn't try to fix the bug, but the refactoring showed a change in the lit test, and then I verified that the current PR is doing the right thing and the old path is wrong. I don't see much value in triaging a bug on a path that is being deleted.

Help me understand this better: you started refactoring things, then found a bug by some other means, changed behavior, now it is correct in your opinion but cannot pinpoint what was wrong.

As far as I know I didn't change any behavior. I didn't try to fix a bug. I didn't even know there was a bug. I was trying to see differences in the lit test, and I think the output of the lit test with the refactoring is correct. What was there before was wrong, as far as I can tell.

Please elaborate on how the bug showed up: was it some IREE execution test? Is this something that did not show up before this refactoring, or was it always there? Or is it just that "the IR changed" when the refactoring happened?

In any case, the MO here can be improved.
I agree there is little to no value in fixing the bug before the refactoring.
My opinion is there is strong negative value in conflating a material change we don't understand with a deep refactoring.

Your path forward here is to land the refactoring without changes to the affine.max part and then to address the bug separately after we understand what is going wrong and have a proper test for it.

There is no change to affine.max. This should be just refactoring. What you are suggesting is essentially that I find the bug and reintroduce it with the refactoring? Please take a look at the changes in the lit tests (which are relatively minor). The changed tests are correct. I have no idea why the old path did something wrong, but it seems it did. And, as you also agree above, there is no value in triaging that.


github-actions bot commented Jun 14, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@MaheshRavishankar
Contributor Author

After deeper investigation, I see that the tests that change are not the ones with "num_threads"; all changed tests are only related to the classical "tile_sizes" case...

Before this change, when tiling with tile sizes, we used to call:

return tileToForallOpImpl(b, op, numThreads,
                             /*nominalTileSizes=*/tileSizes, mapping,
                             /*omitTileOffsetBoundsCheck=*/true);

I.e. we explicitly omitted the affine.max because we statically know that tiling with tile sizes does not require it. This is the same reasoning for tiling with scf.for: we do not emit affine.max on the lower bound.

This change introduces affine.max in "regular tiling" with tile sizes and I believe this is redundant and unnecessary.

@MaheshRavishankar coming back to you point about:

Re 3: I don't think there are any extra divisions. Most of the lit tests stay the same. The only ones that changed are the dynamic-shape tests, which I think were originally wrong (you can't really avoid generating the max operation).

Can you explain why you think they were originally wrong? My position is that you can avoid generating the max operation when tiling to scf.forall with tile sizes, for the same reason that you don't generate affine.max when tiling with scf.for.

Edit: this is the minimal test that now introduces an affine.max that was not here before:

func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C: tensor<?x?xf32>) -> tensor<?x?xf32> {
  %0 = linalg.matmul ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>)
                    outs(%C : tensor<?x?xf32>) -> (tensor<?x?xf32>)
  return %0 : tensor<?x?xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op
    %sz = transform.param.constant 10 : i64 -> !transform.param<i64>
    %1:2 = transform.structured.tile_using_forall %0 tile_sizes [%sz, 20]
           : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
    transform.yield
  }
}

I tried to replace transform.structured.tile_using_forall with transform.structured.tile_using_for on your branch to confirm we have different behavior (i.e. that we do not produce the affine.max in the scf.for case), but this results in a crash in TileUsingForOp (I don't know if it was introduced by this PR; I did not have time to investigate more):

#11 0x00007f60582315d8 mlir::transform::TileUsingForOp::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) /home/nico/llvm-project/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp:2801:7
#12 0x00007f6058189ca1 mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Model<mlir::transform::TileUsingForOp>::apply(mlir::transform::detail::TransformOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) /home/nico/llvm-project/build-Debug/tools/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h.inc:477:56
#13 0x00007f604995a0fe mlir::transform::TransformOpInterface::apply(mlir::transform::TransformRewriter&, mlir::transform::TransformResults&, mlir::transform::TransformState&) /home/nico/llvm-project/build-Debug/tools/mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.cpp.inc:61:14

@nicolasvasilache I updated the change; now all the lit tests are consistent with the move to using scf::tileUsingSCF. With regard to the small test above, it works for me with this output:

#map = affine_map<()[s0] -> (s0 ceildiv 10)>
#map1 = affine_map<()[s0] -> (s0 ceildiv 20)>
#map2 = affine_map<(d0) -> (d0 * 10)>
#map3 = affine_map<(d0) -> (d0 * 20)>
#map4 = affine_map<(d0)[s0] -> (d0 * -10 + s0, 10)>
#map5 = affine_map<(d0)[s0] -> (d0 * -20 + s0, 20)>
module {
  func.func @matmul_tile_size_dynamic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
    %dim_1 = tensor.dim %arg1, %c1 : tensor<?x?xf32>
    %0 = affine.apply #map()[%dim]
    %1 = affine.apply #map1()[%dim_1]
    %2 = scf.forall (%arg3, %arg4) in (%0, %1) shared_outs(%arg5 = %arg2) -> (tensor<?x?xf32>) {
      %3 = affine.apply #map2(%arg3)
      %4 = affine.apply #map3(%arg4)
      %5 = affine.min #map4(%arg3)[%dim]
      %6 = affine.min #map5(%arg4)[%dim_1]
      %extracted_slice = tensor.extract_slice %arg0[%3, 0] [%5, %dim_0] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %extracted_slice_2 = tensor.extract_slice %arg1[0, %4] [%dim_0, %6] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %extracted_slice_3 = tensor.extract_slice %arg5[%3, %4] [%5, %6] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
      %7 = linalg.matmul ins(%extracted_slice, %extracted_slice_2 : tensor<?x?xf32>, tensor<?x?xf32>) outs(%extracted_slice_3 : tensor<?x?xf32>) -> tensor<?x?xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %7 into %arg5[%3, %4] [%5, %6] [1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
      }
    }
    return %2 : tensor<?x?xf32>
  }
  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %0 = transform.structured.match ops{["linalg.matmul"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %1 = transform.param.constant 10 : i64 -> !transform.param<i64>
      %tiled_op, %forall_op = transform.structured.tile_using_forall %0 tile_sizes [%1, 20] : (!transform.any_op, !transform.param<i64>) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}

And no affine.max as you pointed out.

@MaheshRavishankar
Contributor Author

@nicolasvasilache I am hoping to land this next week, as all concerns have been addressed AFAICS. PTAL.

Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar force-pushed the deprecate_tile_to_forall branch from 76f79cf to 3375498 Compare July 30, 2024 20:42
@MaheshRavishankar
Contributor Author

@nicolasvasilache I am coming back to this and would like to land it. All the lit test changes are just reorderings of instructions, so it is fully NFC. I have also tested with IREE and there are no changes there either. I plan to land this on Friday if I don't hear from you.

Contributor

@nicolasvasilache nicolasvasilache left a comment

thanks!

MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jul 30, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jul 30, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jul 31, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
MaheshRavishankar added a commit to MaheshRavishankar/iree that referenced this pull request Jul 31, 2024
Signed-off-by: MaheshRavishankar <[email protected]>
@MaheshRavishankar MaheshRavishankar merged commit 6740d70 into llvm:main Jul 31, 2024
7 checks passed
bjacob pushed a commit to iree-org/iree that referenced this pull request Aug 1, 2024
Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes`

Signed-off-by: MaheshRavishankar <[email protected]>
bjacob pushed a commit to iree-org/iree that referenced this pull request Aug 1, 2024
Deprecate `linalg::tileToForallOp` and `linalg::tileToForallOpUsingTileSizes`

Signed-off-by: Mahesh Ravishankar <[email protected]>
7 participants