[mlir][scf] Extend consumer fuse to single nested scf.for #94190

Merged
merged 9 commits on Sep 12, 2024

Conversation

Contributor

@Yun-Fly Yun-Fly commented Jun 3, 2024

Hi, based on the earlier discussion in this thread, this patch aims to extend the consumer-fusion feature to more complex nested loop structures. E.g.

#map = affine_map<(d0) -> (d0 * 128)>
module {
  func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
    %c0 = arith.constant 0 : index
    %c64 = arith.constant 64 : index
    %c128 = arith.constant 128 : index
    %cst = arith.constant 0.000000e+00 : f32
    %dest0 = tensor.empty() : tensor<256x256xf32>
    %dest1 = linalg.fill ins(%cst : f32) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    %1 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %dest1) -> tensor<256x256xf32> {
      %iv0 = affine.apply #map(%arg3)
      %iv1 = affine.apply #map(%arg4)
      %extracted_slice_1 = tensor.extract_slice %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
      %extracted_slice_2 = tensor.extract_slice %arg0[%iv0, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
      %extracted_slice_3 = tensor.extract_slice %arg1[0, %iv1] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
      %2 = scf.for %arg6 = %c0 to %c128 step %c64 iter_args(%arg7 = %extracted_slice_1) -> (tensor<128x128xf32>) {
        %3 = scf.for %arg8 = %c0 to %c128 step %c64 iter_args(%arg9 = %arg7) -> (tensor<128x128xf32>) {
          %extracted_slice_4 = tensor.extract_slice %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
          %extracted_slice_5 = tensor.extract_slice %extracted_slice_2[%arg6, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
          %extracted_slice_6 = tensor.extract_slice %extracted_slice_3[0, %arg8] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
          %4 = linalg.matmul ins(%extracted_slice_5, %extracted_slice_6 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %insert_slice = tensor.insert_slice %4 into %arg9[%arg6, %arg8] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
          scf.yield %insert_slice : tensor<128x128xf32>
        }
        scf.yield %3 : tensor<128x128xf32>
      }
      scf.forall.in_parallel {
         tensor.parallel_insert_slice %2 into %arg5[%iv0, %iv1] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
      }
    }
    %5 = linalg.add ins(%1, %arg2 : tensor<256x256xf32>, tensor<256x256xf32>) outs(%dest0 : tensor<256x256xf32>) -> tensor<256x256xf32>
    return %5 : tensor<256x256xf32>
  }
}

What's New in this PR:

  1. Support nested loop structures, including both scf.for and scf.forall.
  2. Support multi-level insert_slice and parallel_insert_slice.

NOTE: this PR does NOT deal with the refactor of getTiledImplementation we discussed before; it focuses only on the functionality enhancement. BTW, in the example above you can also see a similar issue, where the semantics of the tiled operand do not match the assumptions of the current getTiledImplementation, even for dpsInits. To unblock this necessary patch, I temporarily follow the method @MaheshRavishankar suggested, using a dummy insert_slice to bridge that gap.

The resulting IR after fusion looks like this:

#map = affine_map<(d0) -> (d0 * 128)>
#map1 = affine_map<(d0, d1) -> (d0 + d1 * 128)>
module {
  module {
    func.func @fuse_tilable_consumer_nested_scf_loop(%arg0: tensor<256x512xf32>, %arg1: tensor<512x256xf32>, %arg2: tensor<256x256xf32>) -> tensor<256x256xf32> {
      %c0 = arith.constant 0 : index
      %c64 = arith.constant 64 : index
      %c128 = arith.constant 128 : index
      %cst = arith.constant 0.000000e+00 : f32
      %0 = tensor.empty() : tensor<256x256xf32>
      %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor<256x256xf32>) -> tensor<256x256xf32>
      %2:2 = scf.forall (%arg3, %arg4) in (2, 2) shared_outs(%arg5 = %1, %arg6 = %0) -> (tensor<256x256xf32>, tensor<256x256xf32>) {
        %3 = affine.apply #map(%arg3)
        %4 = affine.apply #map(%arg4)
        %extracted_slice = tensor.extract_slice %arg5[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %extracted_slice_0 = tensor.extract_slice %arg0[%3, 0] [128, 512] [1, 1] : tensor<256x512xf32> to tensor<128x512xf32>
        %extracted_slice_1 = tensor.extract_slice %arg1[0, %4] [512, 128] [1, 1] : tensor<512x256xf32> to tensor<512x128xf32>
        %extracted_slice_2 = tensor.extract_slice %arg6[%3, %4] [128, 128] [1, 1] : tensor<256x256xf32> to tensor<128x128xf32>
        %5:2 = scf.for %arg7 = %c0 to %c128 step %c64 iter_args(%arg8 = %extracted_slice, %arg9 = %extracted_slice_2) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
          %6:2 = scf.for %arg10 = %c0 to %c128 step %c64 iter_args(%arg11 = %arg8, %arg12 = %arg9) -> (tensor<128x128xf32>, tensor<128x128xf32>) {
            %extracted_slice_3 = tensor.extract_slice %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %extracted_slice_4 = tensor.extract_slice %extracted_slice_0[%arg7, 0] [64, 512] [1, 1] : tensor<128x512xf32> to tensor<64x512xf32>
            %extracted_slice_5 = tensor.extract_slice %extracted_slice_1[0, %arg10] [512, 64] [1, 1] : tensor<512x128xf32> to tensor<512x64xf32>
            %7 = linalg.matmul ins(%extracted_slice_4, %extracted_slice_5 : tensor<64x512xf32>, tensor<512x64xf32>) outs(%extracted_slice_3 : tensor<64x64xf32>) -> tensor<64x64xf32>
            %8 = affine.apply #map1(%arg7, %arg3)
            %9 = affine.apply #map1(%arg10, %arg4)
            %extracted_slice_6 = tensor.extract_slice %arg2[%8, %9] [64, 64] [1, 1] : tensor<256x256xf32> to tensor<64x64xf32>
            %extracted_slice_7 = tensor.extract_slice %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<128x128xf32> to tensor<64x64xf32>
            %10 = linalg.add ins(%7, %extracted_slice_6 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%extracted_slice_7 : tensor<64x64xf32>) -> tensor<64x64xf32>
            %inserted_slice = tensor.insert_slice %7 into %arg11[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            %inserted_slice_8 = tensor.insert_slice %10 into %arg12[%arg7, %arg10] [64, 64] [1, 1] : tensor<64x64xf32> into tensor<128x128xf32>
            scf.yield %inserted_slice, %inserted_slice_8 : tensor<128x128xf32>, tensor<128x128xf32>
          }
          scf.yield %6#0, %6#1 : tensor<128x128xf32>, tensor<128x128xf32>
        }
        scf.forall.in_parallel {
          tensor.parallel_insert_slice %5#1 into %arg6[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
          tensor.parallel_insert_slice %5#0 into %arg5[%3, %4] [128, 128] [1, 1] : tensor<128x128xf32> into tensor<256x256xf32>
        }
      }
      return %2#1 : tensor<256x256xf32>
    }
  }
}

Looking forward to your suggestions and review, thanks.


@llvmbot
Member

llvmbot commented Jun 3, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-scf

Author: None (Yun-Fly)

Changes

Patch is 46.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/94190.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+511-225)
  • (modified) mlir/test/Interfaces/TilingInterface/tile-and-fuse-consumer.mlir (+96)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a530117..9dd730e64a030 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -13,6 +13,7 @@
 #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
 
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Affine/Utils.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -1103,98 +1104,6 @@ mlir::scf::tileConsumerAndFuseProducersUsingSCF(
 // tileAndFuseConsumerUsingSCF implementation.
 //===----------------------------------------------------------------------===//
 
-/// A utility function that checks whether the only use of the result of a
-/// tensor.insert_slice op is in a scf.yield op.
-static LogicalResult
-checkAssumptionForFusingConsumer(tensor::InsertSliceOp candidateSliceOp) {
-  Value result = candidateSliceOp.getResult();
-  Value::use_range uses = result.getUses();
-  if (!llvm::hasSingleElement(uses)) {
-    LLVM_DEBUG(llvm::dbgs() << "Too many uses of the candidate slice op\n");
-    return failure();
-  }
-  OpOperand &operandUse = (*uses.begin());
-  Operation *userOp = operandUse.getOwner();
-  if (!isa<scf::YieldOp>(userOp)) {
-    LLVM_DEBUG(llvm::dbgs()
-               << "Expected scf.yield to be the only user, but got -> "
-               << (*userOp));
-    return failure();
-  }
-  if (result.getDefiningOp()->getBlock() != userOp->getBlock()) {
-    LLVM_DEBUG(llvm::dbgs() << "Expected tensor.insert_slice and scf.yield to "
-                               "be in the same block\n");
-    return failure();
-  }
-  return success();
-}
-
-/// Fetches the OpOperand of the only user (and use) of the value `val` which
-/// implements `TilingInterface` and `DestinationStyleOpInterface`. Returns
-/// failure otherwise.
-static FailureOr<OpOperand *> getConsumerFromUses(Value val,
-                                                  Block *containingOpBlock) {
-  // Step 1. Check that the value has exactly one use.
-  if (!llvm::hasSingleElement(val.getUses()))
-    return failure();
-  // Step 2. Get uses.
-  OpOperand &operand = (*val.getUses().begin());
-  Operation *consumerOp = operand.getOwner();
-  // TODO: We have to init result of consumer before scf.for, use
-  //       DestinationStyleOpInterface to get result shape from init for now.
-  //       Add support for other op such as op has InferTypeOpInterface.
-  if (!isa<TilingInterface>(consumerOp) ||
-      !isa<DestinationStyleOpInterface>(consumerOp))
-    return failure();
-  if (containingOpBlock != consumerOp->getBlock())
-    return failure();
-  return &operand;
-}
-
-/// Fetch the untiled consumer of a scf.for's result which is yielded by a
-/// tensor.insert_slice. This function makes the following assumptions :
-/// 1.  tensor.insert_slice has scf.yield as its only user.
-/// 2.  scf.for's corresponding result has only one use.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::InsertSliceOp candidateSliceOp) {
-  if (failed(checkAssumptionForFusingConsumer(candidateSliceOp)))
-    return failure();
-  Value sliceResult = candidateSliceOp.getResult();
-  // Step 1. Fetch the corresponding output.
-  OpOperand &yieldOpOperand = (*sliceResult.getUses().begin());
-  unsigned resultNumber = yieldOpOperand.getOperandNumber();
-  // Step 2. Check containing op is scf.for.
-  Operation *containingOp = candidateSliceOp->getParentOp();
-  auto forOp = dyn_cast<scf::ForOp>(containingOp);
-  if (!forOp)
-    return failure();
-  Value resultingValue = forOp->getResult(resultNumber);
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
-/// Fetch the first untiled consumer of a scf.forall's result which is yielded
-/// by a tensor.parallel_insert_slice.
-static FailureOr<OpOperand *>
-getUntiledConsumerFromSlice(tensor::ParallelInsertSliceOp candidateSliceOp) {
-  // Step 1. Fetch the corresponding output
-  Value sliceDest = candidateSliceOp.getDest();
-  auto iterArg = dyn_cast<BlockArgument>(sliceDest);
-  if (!iterArg)
-    return failure();
-  Operation *containingOp = iterArg.getOwner()->getParentOp();
-  if (containingOp != candidateSliceOp->getParentOp()->getParentOp())
-    return failure();
-  // Step 2. Check that the containing op is scf.forall.
-  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
-  if (!forallOp)
-    return failure();
-  Value resultingValue =
-      forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
-
-  return getConsumerFromUses(resultingValue, containingOp->getBlock());
-}
-
 /// This utility currently checks whether the loop either :-
 /// 1. Yields exactly one result.
 /// 2. Has consumer op as its first user and other users to be in the same
@@ -1220,31 +1129,116 @@ static LogicalResult checkAssumptionForLoop(Operation *loopOp,
   return success();
 }
 
-/// A utility to fetch an untiled consumer of
-/// tensor.insert_slice/tensor.parallel_insert_slice.
-static FailureOr<OpOperand *> getUntiledConsumerFromSlice(Operation *sliceOp) {
-  if (auto insertSlice = dyn_cast<tensor::InsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(insertSlice);
-  } else if (auto parallelInsertSlice =
-                 dyn_cast<tensor::ParallelInsertSliceOp>(sliceOp)) {
-    return getUntiledConsumerFromSlice(parallelInsertSlice);
-  } else {
+// Traverse and collect all outer loops of given sliceOp, sorted by
+// outer-to-inner. If `untilLoop` found, stop walk through in advance.
+static SmallVector<LoopLikeOpInterface> getOuterLoopsOfSliceOp(
+    OffsetSizeAndStrideOpInterface sliceOp,
+    std::optional<LoopLikeOpInterface> untilLoop = std::nullopt) {
+  SmallVector<LoopLikeOpInterface> outerLoops;
+  auto forOp = sliceOp->getParentOfType<LoopLikeOpInterface>();
+  while (forOp) {
+    outerLoops.push_back(forOp);
+    if (untilLoop.has_value() && *untilLoop == forOp)
+      break;
+    forOp = forOp->getParentOfType<LoopLikeOpInterface>();
+  }
+  return {outerLoops.rbegin(), outerLoops.rend()};
+}
+
+// Get the Result of top-level Loop which yield the target InsertSliceOp. E.g
+// ```
+// %1 = scf.for
+//  %2 = scf.for
+//   %3 = scf.for
+//      ...
+//      %4 = insert
+//      yield %4
+//   %5 = insert %3
+//   yield %5
+//  yield %2
+// ```
+// @param targetSliceOp: %4 = insert
+// @return Result Value: %1
+//         Collected insertSliceOp List during walk including targetSliceOp:
+//                %4 = insert and %5 = insert %3
+static FailureOr<std::pair<Value, SmallVector<OffsetSizeAndStrideOpInterface>>>
+getResultOfTopLevelLoopYieldInsertSliceOp(
+    OffsetSizeAndStrideOpInterface targetSliceOp, int curDepth = 0,
+    int maxDepth = 5) {
+  // control recursive time in avoid of stack overflow
+  if (curDepth > maxDepth)
+    return failure();
+
+  SmallVector<OffsetSizeAndStrideOpInterface> candidateSliceOpList;
+  candidateSliceOpList.push_back(targetSliceOp);
+  Value resultOfLoop;
+  if (auto sliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(
+          targetSliceOp.getOperation())) {
+    Value destValue = sliceOp.getDest();
+    auto iterArg = cast<BlockArgument>(destValue);
+    auto forallOp = dyn_cast<scf::ForallOp>(iterArg.getOwner()->getParentOp());
+    if (!forallOp)
+      return failure();
+    resultOfLoop = forallOp.getTiedOpResult(forallOp.getTiedOpOperand(iterArg));
+  } else if (auto sliceOp = dyn_cast<tensor::InsertSliceOp>(
+                 targetSliceOp.getOperation())) {
+    Value resultValue = sliceOp.getResult();
+    for (auto &useOperand : resultValue.getUses()) {
+      if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        if (llvm::detail::isPresent(resultOfLoop))
+          return failure();
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+      }
+    }
+  }
+
+  if (!llvm::detail::isPresent(resultOfLoop))
     return failure();
+
+  while (true) {
+    bool walkThroughOuterLoop = false;
+    for (auto &useOperand : resultOfLoop.getUses()) {
+      if (auto sliceOp =
+              dyn_cast<OffsetSizeAndStrideOpInterface>(useOperand.getOwner())) {
+        auto resultAndSliceOpsPair =
+            getResultOfTopLevelLoopYieldInsertSliceOp(sliceOp, curDepth + 1);
+        if (failed(resultAndSliceOpsPair))
+          return failure();
+        candidateSliceOpList.append((*resultAndSliceOpsPair).second.begin(),
+                                    (*resultAndSliceOpsPair).second.end());
+        return std::make_pair((*resultAndSliceOpsPair).first,
+                              candidateSliceOpList);
+      } else if (auto yieldOp = dyn_cast<scf::YieldOp>(useOperand.getOwner())) {
+        // walk through outer loop
+        auto forOp = dyn_cast<LoopLikeOpInterface>(yieldOp->getParentOp());
+        if (!forOp)
+          return failure();
+        resultOfLoop = forOp->getResult(useOperand.getOperandNumber());
+        walkThroughOuterLoop = true;
+        break;
+      }
+    }
+    if (!walkThroughOuterLoop)
+      break;
   }
+  return std::make_pair(resultOfLoop, candidateSliceOpList);
 }
 
 /// After fusing consumer into scf.for we want to modify the scf.yield operation
 /// to reflect the same by returning the values yielded by the tiled consumer.
 static void
 fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
-                      TilingResult &tilingResult,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                      ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
+                      ResultRange tilingResult,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
                       ArrayRef<BlockArgument> bbArgs) {
   scf::YieldOp oldTerminatorOp =
       cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
   unsigned totalOldResults = oldTerminatorOp->getNumResults();
-  unsigned totalTiledResults = tilingResult.tiledOps[0]->getNumResults();
+  unsigned totalTiledResults = tilingResult.size();
   SmallVector<Value> newYieldOperands;
   newYieldOperands.reserve(totalOldResults + totalTiledResults);
   for (auto oldResult : oldTerminatorOp.getResults()) {
@@ -1253,8 +1247,7 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
   rewriter.setInsertionPointAfter(oldTerminatorOp);
   Location loc = newForOp.getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tilingResult.tiledOps[0]->getResults(), bbArgs,
-                       resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     Value newInsertSliceOp = rewriter.create<tensor::InsertSliceOp>(
@@ -1267,18 +1260,17 @@ fixTerminatorSCFYield(RewriterBase &rewriter, scf::ForOp newForOp,
 
 /// After fusing consumer into scf.forall we want to yield each of the resulting
 /// values by the tiled consumer within scf.forall.in_parallel region.
-static void
-fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
-                           SmallVector<Value> tiledResults,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultOffsets,
-                           ArrayRef<SmallVector<OpFoldResult>> &resultSizes,
-                           ArrayRef<BlockArgument> bbArgs) {
+static void fixTerminatorSCFInParallel(
+    RewriterBase &rewriter, scf::ForallOp newForallOp, ResultRange tilingResult,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+    ArrayRef<BlockArgument> bbArgs) {
   scf::InParallelOp newTerminatorOp = newForallOp.getTerminator();
   rewriter.setInsertionPointToStart(newTerminatorOp.getBody());
   Location firstYieldOpLoc =
       (*(newTerminatorOp.getYieldingOps().begin())).getLoc();
   for (auto [tiledResult, bbArg, resultOffset, resultSize] :
-       llvm::zip_equal(tiledResults, bbArgs, resultOffsets, resultSizes)) {
+       llvm::zip_equal(tilingResult, bbArgs, resultOffsets, resultSizes)) {
     SmallVector<OpFoldResult> strides(resultOffset.size(),
                                       rewriter.getIndexAttr(1));
     rewriter.create<tensor::ParallelInsertSliceOp>(
@@ -1286,6 +1278,180 @@ fixTerminatorSCFInParallel(RewriterBase &rewriter, scf::ForallOp newForallOp,
   }
 }
 
+// If the top level loop of nested loop structure is scf.forall, need to create
+// additional tensor.extract_slice for its new appended `shared_outs` in order
+// to pass correct local memory for inner loops. E.g.
+//
+// scf.forall shared_outs(%o1=..., %o2=...) {
+//     %local_o1 = extract_slice %o1
+//     // fix new appended `shared_out` %o2
+//     %local_o2 = extract_slice %o2
+//     scf.for init_args(%init1=%local_o1, %init2=%local_o2) {
+//        ...
+//     }
+//     ...
+// }
+static void
+fixSharedOutSCFForall(RewriterBase &rewriter, scf::ForallOp outerLoop,
+                      LoopLikeOpInterface innerLoop,
+                      SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+                      SmallVector<SmallVector<OpFoldResult>> &resultSizes,
+                      unsigned newInitSize,
+                      SmallVector<tensor::ExtractSliceOp> &newExtractOps) {
+  rewriter.setInsertionPoint(innerLoop);
+  Location Loc = outerLoop.getLoc();
+  MutableArrayRef<BlockArgument> bbArgs = outerLoop.getBody()->getArguments();
+
+  SmallVector<tensor::ExtractSliceOp> newOps;
+  newOps.reserve(resultOffsets.size());
+  for (auto [bbArg, offset, sizes] : llvm::zip_equal(
+           bbArgs.take_back(newInitSize), resultOffsets, resultSizes)) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    newOps.push_back(newExtractOp);
+  }
+  newExtractOps = newOps;
+}
+
+// If outerMost loop of nested loop structure is `scf.forall`, need to deal with
+// DpsInit of tiled consumer
+static void fixDpsInitsOfTiledConsumer(
+    RewriterBase &rewriter, Operation *tiledConsumer,
+    ArrayRef<BlockArgument> bbArgs,
+    SmallVector<SmallVector<OpFoldResult>> &resultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &resultSizes) {
+  rewriter.setInsertionPoint(tiledConsumer);
+  Location Loc = tiledConsumer->getLoc();
+  for (auto &&[bbArg, offset, sizes, dpsInit] :
+       llvm::zip_equal(bbArgs, resultOffsets, resultSizes,
+                       cast<DestinationStyleOpInterface>(tiledConsumer)
+                           .getDpsInitsMutable())) {
+    SmallVector<OpFoldResult> strides(offset.size(), rewriter.getIndexAttr(1));
+    auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
+        Loc, bbArg, offset, sizes, strides);
+    dpsInit.set(newExtractOp.getResult());
+  }
+}
+
+// compute all results tile by given SliceOp along operand
+static LogicalResult computeAllResultTileForOpGivenOperandSliceOp(
+    RewriterBase &rewriter, TilingInterface tilableOp, unsigned operandNumber,
+    OffsetSizeAndStrideOpInterface ossSliceOp,
+    SmallVector<SmallVector<OpFoldResult>> &allResultOffsets,
+    SmallVector<SmallVector<OpFoldResult>> &allResultSizes) {
+  // 1. check all stride all 1
+  if (llvm::any_of(ossSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(ossSliceOp, "ossSliceOp has stride");
+  }
+  // 2. compute iteration domain Tile from input position
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+  if (failed(tilableOp.getIterationDomainTileFromOperandTile(
+          rewriter, operandNumber, ossSliceOp.getMixedOffsets(),
+          ossSliceOp.getMixedSizes(), iterDomainOffsets, iterDomainSizes))) {
+    return rewriter.notifyMatchFailure(
+        tilableOp, "can't get iter domain position from input position");
+  }
+  unsigned totalNumResultsOfConsumer = tilableOp->getNumResults();
+  SmallVector<SmallVector<OpFoldResult>> resultOffsets(
+      totalNumResultsOfConsumer);
+  SmallVector<SmallVector<OpFoldResult>> resultSizes(totalNumResultsOfConsumer);
+  // 3. compute result Tile by resultNumber
+  for (auto [idx, v] : llvm::enumerate(tilableOp->getResults())) {
+    if (failed(tilableOp.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultOffsets[idx], resultSizes[idx]))) {
+      return rewriter.notifyMatchFailure(
+          tilableOp,
+          "can't get result domain position from iter domain position");
+    }
+  }
+  allResultOffsets = resultOffsets;
+  allResultSizes = resultSizes;
+  return success();
+}
+
+// Considering multi-level tensor.*SliceOp maybe based on different
+// coordination, this utility computes the real OFFSET coordinated on ROOT
+// SliceOp. E.g
+//             %0 = insert_slice %1 into %2[OFFSET1] [SIZE1]
+//         %3 = insert_slice %4 into %5[OFFSET2] [SIZE2]
+//
+// where the coordination can be illustrated as follow:
+//
+//  %3 ----------------------------------
+//  |         |         |
+//  | OFFSET2 | OFFSET1 |
+//  | ------ %0         |
+//  |                   |
+//  |                   |
+//  |------------------ %1 ------ |
+//  |                   |  SIZE1  |
+//  |                   |         |
+//  |                   |         |
+//  |                   | ------- |
+//  |
+//
+// The real OFFSET of %1 coordinated on %3 is actually `OFFSET1` + `OFFSET2`
+static FailureOr<SmallVector<OpFoldResult>>
+computeRealOffsetsCoordinatedRootSliceOp(
+    RewriterBase &rewriter, Location loc,
+    OffsetSizeAndStrideOpInterface candidateSliceOp,
+    MutableArrayRef<OffsetSizeAndStrideOpInterface> candidateSliceOpList) {
+  if (llvm::any_of(candidateSliceOp.getMixedStrides(), [](OpFoldResult stride) {
+        return !isConstantIntValue(stride, 1);
+      })) {
+    return rewriter.notifyMatchFailure(candidateSliceOp,
+                                       "candidateSliceOp has stride");
+  }
+  SmallVector<OpFoldResult> realOffsets = candidateSliceOp.getMixedOffsets();
+  // real offsets equals to accumulative offsets of outer candidates
+  for (auto iter = candidateSliceOpList.rbegin(); *iter != candidateSliceOp;
+       iter++) {
+    // assert each outer candidate slice has no stride
+    if (llvm::any_of(iter->getMixedStrides(), [](OpFoldResult stride) {
+          return !isConstantIntValue(stride, 1);
+        })) {
+      return failure();
+    }
+    for (auto &&[ofr1, ofr2] :
+         llvm::zip_equal(realOffsets, iter->getMixedOffsets())) {
+      using AVE = affine::AffineValueExpr;
+      affine::AffineBuilder ab(rewriter, loc);
+      AffineExpr dim0, dim1, sym;
+      bindDims(rewriter.getContext(), dim0, dim1);
+      bindSymbols(rewriter.getContext(), sym);
+      auto aveOffset1 = AVE(dim0).bind(ofr1), aveOffset2 = AVE(dim1).bind(ofr2);
+      ofr1 = ab.add(aveOffset1, aveOffset2);
+    }
+  }
+  return realOffsets;
+}
+
+// Get the first tilable user of given Value and check its domination at the
+// same time
+static FailureOr<OpOperand *>
+getTilableConsumerOperandFirstUseVal(Value val, Operation *loopOp) {
+  for (auto &useOfval : val.getUses()) {
+    Operation *consumerOp = useOfval.getOwner();
+    // 1. Check whether consumerOp is tilable
+    if (!isa<TilingInterface>(consumerOp) ||
+        !isa<DestinationStyleOpInterface>(consumerOp))
+      continue;
+    // 2. check stay in same block with loopOp
+    if (loopOp->getBlock() != consumerOp->getBlock())
+      continue;
+    // 3. check no other user before it
+    if (failed(checkAssumptionForLoop(loopOp, consumerOp))) {
+      continue;
+    }
+    return &useOfval;
+  }
+  return failure();
+}
+
 /// Implementation of fusing consumer of a single slice by computing the
 /// slice of the consumer in-place for scf loop.
 FailureOr<scf::SCFFuseConsumerOfSliceResult>
@@ -1297,10 +1463,29 @@ mlir::scf::tileAndFuseConsumerOfSlice(RewriterBase &rewriter,
 
   bool is...
[truncated]

@llvmbot llvmbot added the mlir label Jun 3, 2024
Contributor

@Abhishek-Varma Abhishek-Varma left a comment

Hi @Yun-Fly - thanks for starting on this!

A few starter nit comments from my end.

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

I need to look deeper. This is changing things much more than I would expect.

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch 3 times, most recently from 1ab45c1 to 74e3119 on June 5, 2024 07:56
@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from 74e3119 to 9c04ad4 on June 5, 2024 13:29
Contributor Author

Yun-Fly commented Jun 5, 2024

CI issue has been solved, ready for review :)

Contributor

@nicolasvasilache nicolasvasilache left a comment

Seems to me like this should be done by multiple applications of existing transformations and not by creating a new custom transformation that unrolls both in C++

@MaheshRavishankar
Contributor

Seems to me like this should be done by multiple applications of existing transformations and not by creating a new custom transformation that unrolls both in C++

I agree with Nicolas' comment here. This is the way tile and fuse is supposed to work. You start with

%0 = <producer>
%1 = <consumer>(...%0...)

First tile the consumer

%0 = <producer>
%1 = scf.for ... shared_outs/init(%arg0 =...) {
   %2 = scf.for ... shared_outs/init(%arg1 = %arg0) {
      %3 = <consumer>(...%0...)
       %4 = ... insert_slice %3 into %arg1 ...
       scf.yield %4 
   }
   scf.yield %2
}

Then you fuse %0 into the scf.for nest that is created during tiling of the consumer to get

%1 = scf.for ... shared_outs/init(%arg0 =...) {
   %2 = scf.for ... shared_outs/init(%arg1 = %arg0) {
     %0 = <producer>
      %3 = <consumer>(...%0...)
       %4 = ... insert_slice %3 into %arg1 ...
       scf.yield %4 
   }
   scf.yield %2
}

So you are fusing the operation into an "immediately created" tiled loop nest. The more general case you are looking for can be done through repeated application.

Contributor Author

Yun-Fly commented Jun 6, 2024

Hi @nicolasvasilache, @MaheshRavishankar, let me reply to both of you in one thread.

this should be done by multiple applications of existing transformations

Could you give more detail, perhaps with an example, on how to apply multiple existing transformations?

First tile the consumer
....
Then you fuse %0 into the scf.for nest that is created during tiling of the consumer to get

  1. The difference is the fusion direction: consumer-to-producer versus producer-to-consumer. IMO these are two different but equally feasible approaches to the fusion transform. In general, both should be functionally enabled, with an option for users to select case by case. I guess what you mean here is tileConsumerAndFuseProducersUsingSCF, which uses tileAndFuseProducerOfSlice. This patch, as its counterpart, targets the other technical path, tileAndFuseConsumerOfSlice, just like the previously merged PR, which currently does not support nested loop structures.
  2. From a tiling perspective, the major difference between consumer-to-producer and producer-to-consumer is which op gets priority in deciding how the iteration domain is partitioned into tiles. For instance, if we tile the consumer first and then fuse the producer as you illustrated:
    a. the tile sizes of the producer are derived from the tiled consumer by tile-size propagation based on the AffineMap.
    b. the producer has to fit the iteration domain already generated for the consumer, which may introduce redundant iteration loops.
  3. Based on 2, a typical use case where producer-to-consumer may be more suitable than consumer-to-producer is matmul + post-op fusion. As you know, matmul is compute sensitive, and many developers want hand-written, user-defined templates with complex nested loops and multi-level tile sizes to reach peak performance, on both GPU and CPU. If we start fusion by tiling the post-op (like relu), the matmul computation has to conform to the tiling of an elementwise operation.

Again, this patch is an extension of the already merged PR, which likewise involves producer-to-consumer fusion.

CC: @ZhennanQin.

@MaheshRavishankar
Contributor

Let's start with your example above. I think your input is

%0 = linalg.fill .. outs(%empty)
%1 = linalg.matmul ... outs(%0)
%2 = linalg.add (..., %1)

You can first tile the linalg.matmul

%0 = linalg.fill ... outs(%empty)
%1 = scf.forall ... shared_outs(%arg0 = %0) {
   %2 = tensor.extract_slice %arg0[...]
   %3 = linalg.matmul ... outs(%2)
   scf.forall.in_parallel {
       tensor.insert_in_parallel %3 into %arg0
   }
}
%2 = linalg.add (.., %1)

You can fuse in the fill and the add to get

%0 = scf.forall ... shared_outs(%arg0 = %empty) {
    %1 = tensor.extract_slice %arg0
    %2 = linalg.fill ... outs(%1)
    %3 = linalg.matmul ... outs(%2)
    %4 = linalg.add
    scf.forall.in_parallel {
       tensor.insert_in_parallel %4 into %arg0
   }
}

Now you can apply the same two steps again for the second level of tiling, using scf.for instead. Doesn't that give you what you are looking for?
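
To make the repeated application concrete, here is a rough, hand-written sketch in the same informal style as the snippets above; the tile shapes, the choice of scf.for for the second level, and which ops end up fused at which level are assumptions for illustration, not output of the actual transforms:

%0 = scf.forall ... shared_outs(%arg0 = %empty) {
  // First-level tile of the shared output.
  %1 = tensor.extract_slice %arg0[...]
  // Second level of tiling, now with scf.for, with fill/matmul/add fused in.
  %2 = scf.for ... iter_args(%arg1 = %1) {
    %3 = tensor.extract_slice %arg1[...]
    %4 = linalg.fill ... outs(%3)
    %5 = linalg.matmul ... outs(%4)
    %6 = linalg.add ins(%5, ...) outs(...)
    %7 = tensor.insert_slice %6 into %arg1[...]
    scf.yield %7
  }
  scf.forall.in_parallel {
    tensor.parallel_insert_slice %2 into %arg0[...]
  }
}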

Contributor Author

Yun-Fly commented Jun 7, 2024

@MaheshRavishankar Thanks for your explanation! I get both of your points now.

Lets start with your example above. I think your input is

%0 = linalg.fill .. outs(%empty)
%1 = linalg.matmul ... outs(%0)
%2 = linalg.add (..., %1)

I agree with you if the input starts out that way, coupling tiling and fusion step by step; recursively calling tileAndFuse is another good option in that case. However, what if the matmul has already been completely tiled into a nested loop before fusion? For example, developers may have a hand-written, optimized matmul template or kernel specific to a GPU or CPU, decoupled from the fusion stage, and then want to fuse the post add op. That is in fact the initial motivation of this patch, as illustrated in the description at the top.

%0 = linalg.fill ... outs(%empty)
%1 = scf.forall ... shared_outs(%arg0 = %0) {
   %2 = tensor.extract_slice %arg0[...]
   %3 = scf.for ... iter_args(%arg1=%2) {
      %4 = scf.for ... iter_args(%arg2=%3) {
         %5 = tensor.extract_slice %arg2[...]
         %6 = linalg.matmul ... outs(%5)
         %7 = tensor.insert_slice %6 
         scf.yield %7
      }
      scf.yield %4
   }
   scf.forall.in_parallel {
       tensor.insert_in_parallel %3 into %arg0
   }
}
%2 = linalg.add (.., %1)

With the current implementation, although it is possible to fuse the add at the outermost scf.forall (%1), it seems hard to recursively fuse it into the next-level loop scf.for (%3) without any tensor.insert_slice.

This is changing things much more than I would expect.

Compared with fusing the consumer from outer to inner step by step through multiple applications, the overall logic of this patch can be summarized in three steps for your review:

  1. Enhance getUntiledConsumer to get the real consumer of a given candidate slice.
  2. Use what [MLIR][SCF] Add an API to fuse consumer to a producer within scf loop #88712 has done to fuse the real consumer into the parent loop of the candidate slice.
  3. Restore the OUTER LOOPs in a way similar to the existing addInitOperandsToLoopNest method.

As you can see, only 1 and 3 are newly added. The other changes just involve some code refactoring for better reuse.

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from 9c04ad4 to ec9640c on June 25, 2024 08:59
Contributor Author

Yun-Fly commented Jun 25, 2024

Hi, @MaheshRavishankar @nicolasvasilache.

I have refactored the overall implementation as you advised, using multiple applications of existing transforms. To solve the problem I mentioned in the thread above, some of the previous code has to stay, e.g. getResultOfTopLevelLoopYieldInsertSliceOp to collect a chain of candidate sliceOps.

With the current implementation, although it is possible to fuse the add at the outermost scf.forall (%1), it seems hard to recursively fuse it into the next-level loop scf.for (%3) without any tensor.insert_slice.

In this PR, the original tileAndFuseConsumerOfSlice was renamed to tileAndFuseConsumerOfSliceImpl to represent the basic process for a single scf loop (also enhanced to support perfectly nested outer loops with only an scf.yield terminator), while tileAndFuseConsumerOfSlice now provides the more powerful functionality of handling complex nested loop structures.

This version may be much friendlier to you, with fewer changes. This PR is quite important for further development. Please continue the review. Thanks!

Contributor Author

Yun-Fly commented Jul 2, 2024

I have refactored the overall implementation as you advised, using multiple applications of existing transforms.

Hi, @MaheshRavishankar @Abhishek-Varma @nicolasvasilache. Sincerely looking forward to your new comments!

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from ec9640c to 2ffce48 on July 5, 2024 04:14
@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch 3 times, most recently from 74437c4 to e11eae4 on July 8, 2024 08:47
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Thanks for the changes. I really request that we decouple the "start with any loop nest and tile and fuse into it" part of the changes here and go more incrementally. The end goal seems to be mixed in with all the changes, which adds complexity from the get-go. Could we start with just adding support for consumer fusion with a single nested scf.for before generalizing it? That itself has enough complexity.

return failure();
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
// or `ParallelInsertSliceOp`.
OpOperand *operand = nullptr;
Contributor

Ah that is annoying... this is just dead code. Maybe we should figure out how to remove those extract_slice and insert_slice.

return failure();
// Step 1. Check that the value has exactly one use excluding `insertSliceOp`
// or `ParallelInsertSliceOp`.
OpOperand *operand = nullptr;
Contributor

Would this be easier if we just erase the candidateSliceOp on the first fusion (which probably has some uses which are extract_slices) so that we dont have to make this more complicated?

Contributor Author

Yun-Fly commented Sep 6, 2024

Could we start with just adding support for consumer fusion with a single nested scf.for before generalizing it? That itself has enough complexity.

I see. I will try to further decouple the current changes to support a single nested scf.for.

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from 0884a18 to 0b9355b on September 9, 2024 04:59
Contributor Author

Yun-Fly commented Sep 9, 2024

Hi @MaheshRavishankar, I have further cleaned up the irrelevant code so that it only supports a single nested scf.for, as a starting point for complex nested loop structures. As a result, the test .mlir has been simplified as well.

IMO, this patch falls back to focusing on how we reconstruct the nested scf.for with new inits before the consumer. This is much easier to review than the previous versions.

@Yun-Fly Yun-Fly changed the title from "[mlir][scf] Extend consumer fuse to nested loop structure" to "[mlir][scf] Extend consumer fuse to single nested scf.for" on Sep 9, 2024
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Ok, I just have one last minor comment after which this can land.

A follow-up to this would be to change the consumer fusion code to use the addInitOperandsToLoopNest method, because that already accounts for a lot of the complexity here of adding new inits to the tiled loop nest.
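
For context, a hand-written schematic (not produced by addInitOperandsToLoopNest itself; the tensor types and value names are illustrative assumptions) of what adding a new init operand to a tiled loop nest means at the IR level:

// Before: the nest carries a single init/iter_arg chain.
%r = scf.for ... iter_args(%a = %init0) -> (tensor<128x128xf32>) {
  %inner = scf.for ... iter_args(%b = %a) -> (tensor<128x128xf32>) {
    ...
    scf.yield %new_b : tensor<128x128xf32>
  }
  scf.yield %inner : tensor<128x128xf32>
}

// After: an extra init (%init1) is threaded through every loop level, so a
// consumer fused into the innermost body can write its result into the new
// iter_arg and have it yielded all the way out.
%r:2 = scf.for ... iter_args(%a = %init0, %c = %init1)
    -> (tensor<128x128xf32>, tensor<128x128xf32>) {
  %inner:2 = scf.for ... iter_args(%b = %a, %d = %c)
      -> (tensor<128x128xf32>, tensor<128x128xf32>) {
    ...
    scf.yield %new_b, %new_d : tensor<128x128xf32>, tensor<128x128xf32>
  }
  scf.yield %inner#0, %inner#1 : tensor<128x128xf32>, tensor<128x128xf32>
}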

@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from 0b9355b to 9ff4abc on September 10, 2024 13:00
@Yun-Fly Yun-Fly force-pushed the yunfei/fuse_consumer_nested_loop branch from 9ff4abc to 38956c2 on September 10, 2024 13:05
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Nice! Thanks for cleaning this up! I left one more comment on the use of rewriter. Please address before landing, but this looks good to me.

@Yun-Fly Yun-Fly merged commit 2d4bdfb into llvm:main Sep 12, 2024
8 checks passed

Collaborator

llvm-ci commented Sep 12, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia running on mlir-nvidia while building mlir at step 5 "build-check-mlir-build-only".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/3524

Here is the relevant piece of the build log for reference:
Step 5 (build-check-mlir-build-only) failure: build (failure)
...
25.519 [170/7/4794] Creating library symlink lib/libMLIRToLLVMIRTranslationRegistration.so
25.520 [169/7/4795] Creating library symlink lib/libMLIRTestTransformDialect.so
25.523 [168/7/4796] Creating library symlink lib/libLLVMOrcDebugging.so
25.524 [168/6/4797] Creating library symlink lib/libMyExtensionCh4.so
25.605 [168/5/4798] Linking CXX executable tools/mlir/unittests/Dialect/Transform/MLIRTransformDialectTests
25.605 [168/4/4799] Linking CXX shared library lib/libMLIRROCDLTarget.so.20.0git
25.613 [167/4/4800] Creating library symlink lib/libMLIRROCDLTarget.so
25.624 [167/3/4801] Linking CXX shared library lib/libMLIRCAPITarget.so.20.0git
25.630 [166/3/4802] Creating library symlink lib/libMLIRCAPITarget.so
25.751 [166/2/4803] Building CXX object tools/mlir/lib/Dialect/SCF/Transforms/CMakeFiles/obj.MLIRSCFTransforms.dir/TileUsingInterface.cpp.o
FAILED: tools/mlir/lib/Dialect/SCF/Transforms/CMakeFiles/obj.MLIRSCFTransforms.dir/TileUsingInterface.cpp.o 
CCACHE_CPP2=yes CCACHE_HASHDIR=yes /usr/bin/ccache /usr/bin/clang++ -DGTEST_HAS_RTTI=0 -D_DEBUG -D_GLIBCXX_ASSERTIONS -D_GNU_SOURCE -D__STDC_CONSTANT_MACROS -D__STDC_FORMAT_MACROS -D__STDC_LIMIT_MACROS -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/tools/mlir/lib/Dialect/SCF/Transforms -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/lib/Dialect/SCF/Transforms -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/include -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/llvm/include -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/include -I/vol/worker/mlir-nvidia/mlir-nvidia/llvm.obj/tools/mlir/include -fPIC -fno-semantic-interposition -fvisibility-inlines-hidden -Werror=date-time -Werror=unguarded-availability-new -Wall -Wextra -Wno-unused-parameter -Wwrite-strings -Wcast-qual -Wmissing-field-initializers -pedantic -Wno-long-long -Wc++98-compat-extra-semi -Wimplicit-fallthrough -Wcovered-switch-default -Wno-noexcept-type -Wnon-virtual-dtor -Wdelete-non-virtual-dtor -Wsuggest-override -Wstring-conversion -Wmisleading-indentation -Wctad-maybe-unsupported -fdiagnostics-color -ffunction-sections -fdata-sections -Wundef -Werror=mismatched-tags -Werror=global-constructors -O3 -DNDEBUG  -fno-exceptions -funwind-tables -fno-rtti -UNDEBUG -std=c++17 -MD -MT tools/mlir/lib/Dialect/SCF/Transforms/CMakeFiles/obj.MLIRSCFTransforms.dir/TileUsingInterface.cpp.o -MF tools/mlir/lib/Dialect/SCF/Transforms/CMakeFiles/obj.MLIRSCFTransforms.dir/TileUsingInterface.cpp.o.d -o tools/mlir/lib/Dialect/SCF/Transforms/CMakeFiles/obj.MLIRSCFTransforms.dir/TileUsingInterface.cpp.o -c /vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
/vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp:1794:49: error: reference to local binding 'index' declared in enclosing lambda expression
          tiledDestStyleOp.getDpsInitsMutable()[index].set(destSlice);
                                                ^
/vol/worker/mlir-nvidia/mlir-nvidia/llvm.src/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp:1787:26: note: 'index' declared here
      for (const auto &&[index, newRegionArg] :
                         ^
1 error generated.
29.697 [166/1/4804] Building CXX object lib/CodeGen/AsmPrinter/CMakeFiles/LLVMAsmPrinter.dir/AsmPrinter.cpp.o
ninja: build stopped: subcommand failed.

@kazutakahirata
Contributor

I've reverted your patch because of a build failure reported at:

https://lab.llvm.org/buildbot/#/builders/138/builds/3524

Meanwhile, in my environment with clang-16.0.6 as the host compiler, I see:

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp:1794:49: error: captured structured bindings are a C++20 extension [-Werror,-Wc++20-extensions]                              
          tiledDestStyleOp.getDpsInitsMutable()[index].set(destSlice);

Contributor Author

Yun-Fly commented Sep 12, 2024

I've reverted your patch because of a build failure reported at:

https://lab.llvm.org/buildbot/#/builders/138/builds/3524

Meanwhile, in my environment with clang-16.0.6 as the host compiler, I see:

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp:1794:49: error: captured structured bindings are a C++20 extension [-Werror,-Wc++20-extensions]                              
          tiledDestStyleOp.getDpsInitsMutable()[index].set(destSlice);

Sorry for that. I have opened a new mirror PR (#108318) with the build fix.

Yun-Fly added a commit that referenced this pull request Sep 12, 2024
Refactor current consumer fusion based on `addInitOperandsToLoopNest` to support single nested `scf.for`, E.g.

```
%0 = scf.for() {
  %1 = scf.for() {
     tiledProducer
  }
  yield %1
}
%2 = consumer ins(%0)
```

Compared with #94190, this PR fixes the build failure by making C++17 happy.