Skip to content

[MLIR][Affine] Add missing check on fusion compute tolerance on a path #128454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

bondhugula
Copy link
Contributor

When profitability analysis can't be performed, we should still be respecting the compute tolerance specified. Refactor to pull the additional computation factor computation and check.

Fixes: #54541

@llvmbot
Copy link
Member

llvmbot commented Feb 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-affine

Author: Uday Bondhugula (bondhugula)

Changes

When profitability analysis can't be performed, we should still be respecting the compute tolerance specified. Refactor to pull the additional computation factor computation and check.

Fixes: #54541


Full diff: https://github.com/llvm/llvm-project/pull/128454.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp (+109-46)
  • (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+75)
diff --git a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
index 5add7df849286..7945e156dcd31 100644
--- a/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp
@@ -15,7 +15,6 @@
 #include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
 #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/Analysis/Utils.h"
-#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Affine/LoopFusionUtils.h"
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/Dialect/Affine/Utils.h"
@@ -274,6 +273,58 @@ getDominanceFilterForPrivateMemRefRepl(Block *sliceInsertionBlock,
   return firstAncestor;
 }
 
+/// Returns the amount of additional (redundant) computation that will be done
+/// as a fraction of the total computation if `srcForOp` is fused into
+/// `dstForOp` at depth `depth`. The method returns the compute cost of the
+/// slice and the fused nest's compute cost in the trailing output arguments.
+static std::optional<double> getAdditionalComputeFraction(
+    AffineForOp srcForOp, AffineForOp dstForOp, unsigned depth,
+    ArrayRef<ComputationSliceState> depthSliceUnions, int64_t &sliceCost,
+    int64_t &fusedLoopNestComputeCost) {
+  LLVM_DEBUG(llvm::dbgs() << "Determining additional compute fraction...\n";);
+  // Compute cost of sliced and unsliced src loop nest.
+  // Walk src loop nest and collect stats.
+  LoopNestStats srcLoopNestStats;
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get source loop nest stats.\n");
+    return std::nullopt;
+  }
+
+  // Compute cost of dst loop nest.
+  LoopNestStats dstLoopNestStats;
+  if (!getLoopNestStats(dstForOp, &dstLoopNestStats)) {
+    LLVM_DEBUG(llvm::dbgs() << "Failed to get destination loop nest stats.\n");
+    return std::nullopt;
+  }
+
+  // Compute op instance count for the src loop nest without iteration slicing.
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
+
+  // Compute op cost for the dst loop nest.
+  uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
+
+  const ComputationSliceState &slice = depthSliceUnions[depth - 1];
+  // Skip slice union if it wasn't computed for this depth.
+  if (slice.isEmpty()) {
+    LLVM_DEBUG(llvm::dbgs() << "Slice wasn't computed.\n");
+    return std::nullopt;
+  }
+
+  if (!getFusionComputeCost(srcForOp, srcLoopNestStats, dstForOp,
+                            dstLoopNestStats, slice,
+                            &fusedLoopNestComputeCost)) {
+    LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+    return std::nullopt;
+  }
+
+  double additionalComputeFraction =
+      fusedLoopNestComputeCost /
+          (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
+      1;
+
+  return additionalComputeFraction;
+}
+
 // Creates and returns a private (single-user) memref for fused loop rooted at
 // 'forOp', with (potentially reduced) memref size based on the memref region
 // written to by `storeOps` at depth 'dstLoopDepth'. 'sliceInsertionBlock'
@@ -384,20 +435,19 @@ static Value createPrivateMemRef(AffineForOp forOp,
 }
 
 // Checks the profitability of fusing a backwards slice of the loop nest
-// surrounding 'srcOpInst' into the loop nest surrounding 'dstLoadOpInsts'.
-// The argument 'srcStoreOpInst' is used to calculate the storage reduction on
-// the memref being produced and consumed, which is an input to the cost model.
-// For producer-consumer fusion, 'srcStoreOpInst' will be the same as
-// 'srcOpInst', as we are slicing w.r.t to that producer. For input-reuse
-// fusion, 'srcOpInst' will be the src loop nest LoadOp which reads from the
-// same memref as dst loop nest load ops, and 'srcStoreOpInst' will be the
-// unique store op in the src node, which will be used to check that the write
-// region is the same after input-reuse fusion. Computation slices are provided
-// in 'depthSliceUnions' for each legal fusion depth. The maximal depth at which
-// fusion is legal is provided in 'maxLegalFusionDepth'. Returns true if it is
-// profitable to fuse the candidate loop nests. Returns false otherwise.
-// `dstLoopDepth` is set to the most profitable depth at which to materialize
-// the source loop nest slice.
+// `srcForOp` into the loop nest surrounding 'dstLoadOpInsts'. The argument
+// 'srcStoreOpInst' is used to calculate the storage reduction on the memref
+// being produced and consumed, which is an input to the cost model. For
+// producer-consumer fusion, 'srcStoreOpInst' will be the same as 'srcOpInst',
+// as we are slicing w.r.t to that producer. For input-reuse fusion, 'srcOpInst'
+// will be the src loop nest LoadOp which reads from the same memref as dst loop
+// nest load ops, and 'srcStoreOpInst' will be the unique store op in the src
+// node, which will be used to check that the write region is the same after
+// input-reuse fusion. Computation slices are provided in 'depthSliceUnions' for
+// each legal fusion depth. The maximal depth at which fusion is legal is
+// provided in 'maxLegalFusionDepth'. Returns true if it is profitable to fuse
+// the candidate loop nests. Returns false otherwise. `dstLoopDepth` is set to
+// the most profitable depth at which to materialize the source loop nest slice.
 // The profitability model executes the following steps:
 // *) Computes the backward computation slice at 'srcOpInst'. This
 //    computation slice of the loop nest surrounding 'srcOpInst' is
@@ -422,15 +472,16 @@ static Value createPrivateMemRef(AffineForOp forOp,
 //    is lower.
 // TODO: Extend profitability analysis to support scenarios with multiple
 // stores.
-static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
+static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
                                AffineForOp dstForOp,
                                ArrayRef<ComputationSliceState> depthSliceUnions,
                                unsigned maxLegalFusionDepth,
                                unsigned *dstLoopDepth,
                                double computeToleranceThreshold) {
   LLVM_DEBUG({
-    llvm::dbgs() << "Checking whether fusion is profitable between src op:\n";
-    llvm::dbgs() << ' ' << *srcOpInst << " and destination loop:\n";
+    llvm::dbgs()
+        << "Checking whether fusion is profitable between source nest:\n";
+    llvm::dbgs() << ' ' << srcForOp << " and destination nest:\n";
     llvm::dbgs() << dstForOp << "\n";
   });
 
@@ -440,12 +491,10 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   }
 
   // Compute cost of sliced and unsliced src loop nest.
-  SmallVector<AffineForOp, 4> srcLoopIVs;
-  getAffineForIVs(*srcOpInst, &srcLoopIVs);
 
   // Walk src loop nest and collect stats.
   LoopNestStats srcLoopNestStats;
-  if (!getLoopNestStats(srcLoopIVs[0], &srcLoopNestStats))
+  if (!getLoopNestStats(srcForOp, &srcLoopNestStats))
     return false;
 
   // Compute cost of dst loop nest.
@@ -467,7 +516,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
   std::optional<unsigned> bestDstLoopDepth;
 
   // Compute op instance count for the src loop nest without iteration slicing.
-  uint64_t srcLoopNestCost = getComputeCost(srcLoopIVs[0], srcLoopNestStats);
+  uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
 
   // Compute src loop nest write region size.
   MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
@@ -494,18 +543,21 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     if (slice.isEmpty())
       continue;
 
+    // Compute cost of the slice separately, i.e, the compute cost of the slice
+    // if all outer trip counts are one.
+    int64_t sliceCost;
+
     int64_t fusedLoopNestComputeCost;
-    if (!getFusionComputeCost(srcLoopIVs[0], srcLoopNestStats, dstForOp,
-                              dstLoopNestStats, slice,
-                              &fusedLoopNestComputeCost)) {
-      LLVM_DEBUG(llvm::dbgs() << "Unable to compute fusion compute cost\n");
+
+    auto mayAdditionalComputeFraction =
+        getAdditionalComputeFraction(srcForOp, dstForOp, i, depthSliceUnions,
+                                     sliceCost, fusedLoopNestComputeCost);
+    if (!mayAdditionalComputeFraction) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Can't determine additional compute fraction.\n");
       continue;
     }
-
-    double additionalComputeFraction =
-        fusedLoopNestComputeCost /
-            (static_cast<double>(srcLoopNestCost) + dstLoopNestCost) -
-        1;
+    double additionalComputeFraction = *mayAdditionalComputeFraction;
 
     // Determine what the slice write MemRefRegion would be, if the src loop
     // nest slice 'slice' were to be inserted into the dst loop nest at loop
@@ -530,14 +582,6 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
     }
     int64_t sliceWriteRegionSizeBytes = *maybeSliceWriteRegionSizeBytes;
 
-    // If we are fusing for reuse, check that write regions remain the same.
-    // TODO: Write region check should check sizes and offsets in
-    // each dimension, so that we are sure they are covering the same memref
-    // region. Also, move this out to a isMemRefRegionSuperSet helper function.
-    if (srcOpInst != srcStoreOpInst &&
-        sliceWriteRegionSizeBytes != srcWriteRegionSizeBytes)
-      continue;
-
     double storageReduction = static_cast<double>(srcWriteRegionSizeBytes) /
                               static_cast<double>(sliceWriteRegionSizeBytes);
 
@@ -595,7 +639,7 @@ static bool isFusionProfitable(Operation *srcOpInst, Operation *srcStoreOpInst,
                    << minFusedLoopNestComputeCost << "\n");
 
   auto dstMemSize = getMemoryFootprintBytes(dstForOp);
-  auto srcMemSize = getMemoryFootprintBytes(srcLoopIVs[0]);
+  auto srcMemSize = getMemoryFootprintBytes(srcForOp);
 
   std::optional<double> storageReduction;
 
@@ -840,6 +884,8 @@ struct GreedyFusion {
         LLVM_DEBUG(llvm::dbgs()
                    << "Trying to fuse producer loop nest " << srcId
                    << " with consumer loop nest " << dstId << "\n");
+        LLVM_DEBUG(llvm::dbgs() << "Compute tolerance threshold: "
+                                << computeToleranceThreshold << '\n');
         LLVM_DEBUG(llvm::dbgs()
                    << "Producer loop nest:\n"
                    << *srcNode->op << "\n and consumer loop nest:\n"
@@ -926,6 +972,9 @@ struct GreedyFusion {
           continue;
         }
 
+        LLVM_DEBUG(llvm::dbgs() << "Max legal depth for fusion: "
+                                << maxLegalFusionDepth << '\n');
+
         // Check if fusion would be profitable. We skip profitability analysis
         // for maximal fusion since we already know the maximal legal depth to
         // fuse.
@@ -945,14 +994,28 @@ struct GreedyFusion {
           // if only one of the stores is involved the producer-consumer
           // relationship of the candidate loops.
           assert(!producerStores.empty() && "Expected producer store");
-          if (producerStores.size() > 1)
+          if (producerStores.size() > 1) {
             LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
                                        "supported for this case\n");
-          else if (!isFusionProfitable(producerStores[0], producerStores[0],
-                                       dstAffineForOp, depthSliceUnions,
-                                       maxLegalFusionDepth, &bestDstLoopDepth,
-                                       computeToleranceThreshold))
+            // We will still fuse if fusion obeys the specified compute
+            // tolerance at the max legal depth.
+            int64_t sliceCost;
+            int64_t fusedLoopNestComputeCost;
+            auto fraction = getAdditionalComputeFraction(
+                srcAffineForOp, dstAffineForOp, maxLegalFusionDepth,
+                depthSliceUnions, sliceCost, fusedLoopNestComputeCost);
+            if (!fraction || fraction > computeToleranceThreshold) {
+              LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
+                                         "compute tolerance. Not fusing.\n");
+              continue;
+            }
+          }
+          if (!isFusionProfitable(srcAffineForOp, producerStores[0],
+                                  dstAffineForOp, depthSliceUnions,
+                                  maxLegalFusionDepth, &bestDstLoopDepth,
+                                  computeToleranceThreshold)) {
             continue;
+          }
         }
 
         assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");
@@ -1169,7 +1232,7 @@ struct GreedyFusion {
         // load op is treated as the src "store" op for fusion profitability
         // purposes. The footprint of the load in the slice relative to the
         // unfused source's determines reuse.
-        if (!isFusionProfitable(sibLoadOpInst, sibLoadOpInst, dstAffineForOp,
+        if (!isFusionProfitable(sibAffineForOp, sibLoadOpInst, dstAffineForOp,
                                 depthSliceUnions, maxLegalFusionDepth,
                                 &bestDstLoopDepth, computeToleranceThreshold))
           continue;
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 42d5ce632188e..1fca35836fcc2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
+// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
 // RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
 // All fusion: producer-consumer and sibling.
@@ -495,3 +496,77 @@ func.func @test_add_slice_bounds() {
   }
   return
 }
+
+// -----
+
+// From  https://github.com/llvm/llvm-project/issues/54541
+
+#map = affine_map<(d0) -> (d0 mod 65536)>
+// ZERO-TOLERANCE-LABEL: func @zero_tolerance
+func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
+%3 : memref<30xi64>,
+%4 : memref<30xi64>,
+%5 : memref<30xi64>,
+%6 : memref<30xi64>
+) {
+  %c65536 = arith.constant 65536 : index
+  %cst = arith.constant 0.000000e+00 : f64
+  %cst_0 = arith.constant 0x4320000000380004 : f64
+  %cst_1 = arith.constant 5.000000e-01 : f64
+  %0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
+  %1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
+  %2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
+  // The two nests shouldn't be fused when a zero tolerance is specified.
+  // ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 131072 {
+    %7 = affine.apply #map(%arg2)
+    %8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
+    %9 = arith.cmpi ult, %arg2, %c65536 : index
+    %10 = complex.im %8 : complex<f64>
+    %11 = complex.re %8 : complex<f64>
+    %12 = arith.select %9, %11, %10 : f64
+    %13 = arith.cmpf olt, %12, %cst : f64
+    %14 = arith.negf %12 : f64
+    %15 = arith.select %13, %14, %12 : f64
+    %16 = arith.mulf %15, %cst_0 : f64
+    %17 = arith.addf %16, %cst_1 : f64
+    %18 = arith.fptosi %17 : f64 to i128
+    affine.store %18, %2[%arg2] : memref<131072xi128>
+    affine.store %13, %1[%arg2] : memref<131072xi1>
+  }
+  // ZERO-TOLERANCE:      affine.for %{{.*}} = 0 to 30
+  // ZERO-TOLERANCE-NEXT:   affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %6[%arg2] : memref<30xi64>
+      %8 = affine.load %3[%arg2] : memref<30xi64>
+      %9 = affine.load %5[%arg2] : memref<30xi64>
+      %10 = affine.load %4[%arg2] : memref<30xi64>
+      %11 = affine.load %2[%arg3] : memref<131072xi128>
+      %12 = affine.load %1[%arg3] : memref<131072xi1>
+      %13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
+      %14 = arith.subi %7, %13 : i64
+      %15 = arith.select %12, %14, %13 : i64
+      affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
+  // ZERO-TOLERANCE:      affine.for %{{.*}} = 0 to 30
+  // ZERO-TOLERANCE-NEXT:   affine.for %{{.*}} = 0 to 131072
+  affine.for %arg2 = 0 to 30 {
+    affine.for %arg3 = 0 to 131072 {
+      %7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
+      affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
+    }
+  }
+  // Under maximal fusion, just one nest.
+  // PRODUCER-CONSUMER-MAXIMAL:      affine.for %{{.*}} = 0 to 30
+  // PRODUCER-CONSUMER-MAXIMAL-NEXT:   affine.for %{{.*}} = 0 to 131072
+  // PRODUCER-CONSUMER-MAXIMAL-NOT:  affine.for %{{.*}}
+  memref.dealloc %2 : memref<131072xi128>
+  memref.dealloc %1 : memref<131072xi1>
+  memref.dealloc %0 : memref<30x131072xi64>
+  return
+}
+func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
+func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64

When profitability analysis can't be performed, we should still be
respecting the compute tolerance specified. Refactor to pull the
additional computation factor computation and check.

Fixes: llvm#54541
@bondhugula bondhugula force-pushed the uday/fusion_compute_tolerance_handling branch from 6d31ee5 to 681ffac Compare February 25, 2025 07:11
@bondhugula
Copy link
Contributor Author

Obvious fix/changes that fixes a filed bug. Merging.

@bondhugula bondhugula merged commit 0087523 into llvm:main Feb 25, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

affine-loop-fusion does not respect fusion-compute-tolerance
2 participants