Skip to content

Commit 0087523

Browse files
authored
[MLIR][Affine] Add missing check on fusion compute tolerance on a path (#128454)
When profitability analysis can't be performed, we should still respect the specified compute tolerance. Refactor so that the additional compute fraction computation and the tolerance check are performed inside the profitability analysis. Fixes: #54541
1 parent ea4e19d commit 0087523

File tree

2 files changed

+122
-24
lines changed

2 files changed

+122
-24
lines changed

mlir/lib/Dialect/Affine/Transforms/LoopFusion.cpp

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
1616
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1717
#include "mlir/Dialect/Affine/Analysis/Utils.h"
18-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
1918
#include "mlir/Dialect/Affine/LoopFusionUtils.h"
2019
#include "mlir/Dialect/Affine/LoopUtils.h"
2120
#include "mlir/Dialect/Affine/Utils.h"
@@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
473472
// is lower.
474473
// TODO: Extend profitability analysis to support scenarios with multiple
475474
// stores.
476-
static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
475+
static bool isFusionProfitable(AffineForOp srcForOp,
476+
ArrayRef<Operation *> producerStores,
477477
AffineForOp dstForOp,
478478
ArrayRef<ComputationSliceState> depthSliceUnions,
479479
unsigned maxLegalFusionDepth,
@@ -503,6 +503,35 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
503503
if (!getLoopNestStats(dstForOp, &dstLoopNestStats))
504504
return false;
505505

506+
// We limit profitability analysis to only scenarios with
507+
// a single producer store for now. Note that some multi-store
508+
// producer scenarios will still go through profitability analysis
509+
// if only one of the stores is involved in the producer-consumer
510+
// relationship of the candidate loops.
511+
// TODO: Support multiple producer stores in profitability
512+
// analysis.
513+
if (producerStores.size() > 1) {
514+
LLVM_DEBUG(llvm::dbgs() << "Limited profitability analysis. Not "
515+
"supported for multiple producer store case.\n");
516+
int64_t sliceCost;
517+
int64_t fusedLoopNestComputeCost;
518+
// We will still fuse if fusion obeys the specified compute
519+
// tolerance at the max legal depth.
520+
auto fraction = getAdditionalComputeFraction(
521+
srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
522+
fusedLoopNestComputeCost);
523+
if (!fraction || fraction > computeToleranceThreshold) {
524+
LLVM_DEBUG(llvm::dbgs() << "Additional computation exceeds "
525+
"compute tolerance. Not fusing.\n");
526+
return false;
527+
}
528+
LLVM_DEBUG(llvm::dbgs()
529+
<< "Considering fusion profitable at max legal depth.\n");
530+
return true;
531+
}
532+
533+
Operation *srcStoreOp = producerStores.front();
534+
506535
// Search for min cost value for 'dstLoopDepth'. At each value of
507536
// 'dstLoopDepth' from 'maxLegalLoopDepth' to '1', compute computation slice
508537
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
@@ -516,12 +545,9 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
516545
// The best loop depth at which to materialize the slice.
517546
std::optional<unsigned> bestDstLoopDepth;
518547

519-
// Compute op instance count for the src loop nest without iteration slicing.
520-
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
521-
522548
// Compute src loop nest write region size.
523-
MemRefRegion srcWriteRegion(srcStoreOpInst->getLoc());
524-
if (failed(srcWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0))) {
549+
MemRefRegion srcWriteRegion(srcStoreOp->getLoc());
550+
if (failed(srcWriteRegion.compute(srcStoreOp, /*loopDepth=*/0))) {
525551
LLVM_DEBUG(llvm::dbgs()
526552
<< "Unable to compute MemRefRegion for source operation\n");
527553
return false;
@@ -533,7 +559,10 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
533559
return false;
534560
int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
535561

536-
// Compute op instance count for the src loop nest.
562+
// Compute op instance count for the src loop nest without iteration slicing.
563+
uint64_t srcLoopNestCost = getComputeCost(srcForOp, srcLoopNestStats);
564+
565+
// Compute op instance count for the destination loop nest.
537566
uint64_t dstLoopNestCost = getComputeCost(dstForOp, dstLoopNestStats);
538567

539568
// Evaluate all depth choices for materializing the slice in the destination
@@ -563,9 +592,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
563592
// Determine what the slice write MemRefRegion would be, if the src loop
564593
// nest slice 'slice' were to be inserted into the dst loop nest at loop
565594
// depth 'i'.
566-
MemRefRegion sliceWriteRegion(srcStoreOpInst->getLoc());
567-
if (failed(sliceWriteRegion.compute(srcStoreOpInst, /*loopDepth=*/0,
568-
&slice))) {
595+
MemRefRegion sliceWriteRegion(srcStoreOp->getLoc());
596+
if (failed(sliceWriteRegion.compute(srcStoreOp, /*loopDepth=*/0, &slice))) {
569597
LLVM_DEBUG(llvm::dbgs()
570598
<< "Failed to compute slice write region at loopDepth: " << i
571599
<< "\n");
@@ -1025,21 +1053,13 @@ struct GreedyFusion {
10251053
cast<AffineWriteOpInterface>(op).getMemRef()))
10261054
producerStores.push_back(op);
10271055

1028-
// TODO: Suppport multiple producer stores in profitability
1029-
// analysis. We limit profitability analysis to only scenarios with
1030-
// a single producer store for now. Note that some multi-store
1031-
// producer scenarios will still go through profitability analysis
1032-
// if only one of the stores is involved the producer-consumer
1033-
// relationship of the candidate loops.
10341056
assert(!producerStores.empty() && "Expected producer store");
1035-
if (producerStores.size() > 1)
1036-
LLVM_DEBUG(llvm::dbgs() << "Skipping profitability analysis. Not "
1037-
"supported for this case\n");
1038-
else if (!isFusionProfitable(srcAffineForOp, producerStores[0],
1039-
dstAffineForOp, depthSliceUnions,
1040-
maxLegalFusionDepth, &bestDstLoopDepth,
1041-
computeToleranceThresholdToUse))
1057+
if (!isFusionProfitable(srcAffineForOp, producerStores,
1058+
dstAffineForOp, depthSliceUnions,
1059+
maxLegalFusionDepth, &bestDstLoopDepth,
1060+
computeToleranceThresholdToUse)) {
10421061
continue;
1062+
}
10431063
}
10441064

10451065
assert(bestDstLoopDepth > 0 && "Unexpected loop fusion depth");

mlir/test/Dialect/Affine/loop-fusion-4.mlir

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER
2+
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{compute-tolerance=0.0}))' -split-input-file | FileCheck %s --check-prefix=ZERO-TOLERANCE
23
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{mode=producer maximal}))' -split-input-file | FileCheck %s --check-prefix=PRODUCER-CONSUMER-MAXIMAL
34
// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(affine-loop-fusion{maximal mode=sibling}))' -split-input-file | FileCheck %s --check-prefix=SIBLING-MAXIMAL
45
// All fusion: producer-consumer and sibling.
@@ -544,3 +545,80 @@ func.func @sibling_reduction(%input : memref<10xf32>, %output : memref<10xf32>,
544545
// SIBLING-MAXIMAL-NEXT: affine.store
545546
return
546547
}
548+
549+
// -----
550+
551+
// From https://github.com/llvm/llvm-project/issues/54541
552+
553+
#map = affine_map<(d0) -> (d0 mod 65536)>
554+
// ZERO-TOLERANCE-LABEL: func @zero_tolerance
555+
func.func @zero_tolerance(%arg0: memref<65536xcomplex<f64>>, %arg1: memref<30x131072xi64>,
556+
%3 : memref<30xi64>,
557+
%4 : memref<30xi64>,
558+
%5 : memref<30xi64>,
559+
%6 : memref<30xi64>
560+
) {
561+
%c65536 = arith.constant 65536 : index
562+
%cst = arith.constant 0.000000e+00 : f64
563+
%cst_0 = arith.constant 0x4320000000380004 : f64
564+
%cst_1 = arith.constant 5.000000e-01 : f64
565+
%0 = memref.alloc() {alignment = 128 : i64} : memref<30x131072xi64>
566+
%1 = memref.alloc() {alignment = 128 : i64} : memref<131072xi1>
567+
%2 = memref.alloc() {alignment = 128 : i64} : memref<131072xi128>
568+
// This nest shouldn't be fused in when a zero tolerance is specified.
569+
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 131072
570+
affine.for %arg2 = 0 to 131072 {
571+
%7 = affine.apply #map(%arg2)
572+
%8 = affine.load %arg0[%7] : memref<65536xcomplex<f64>>
573+
%9 = arith.cmpi ult, %arg2, %c65536 : index
574+
%10 = complex.im %8 : complex<f64>
575+
%11 = complex.re %8 : complex<f64>
576+
%12 = arith.select %9, %11, %10 : f64
577+
%13 = arith.cmpf olt, %12, %cst : f64
578+
%14 = arith.negf %12 : f64
579+
%15 = arith.select %13, %14, %12 : f64
580+
%16 = arith.mulf %15, %cst_0 : f64
581+
%17 = arith.addf %16, %cst_1 : f64
582+
%18 = arith.fptosi %17 : f64 to i128
583+
affine.store %18, %2[%arg2] : memref<131072xi128>
584+
affine.store %13, %1[%arg2] : memref<131072xi1>
585+
}
586+
// The next two nests are fused.
587+
// ZERO-TOLERANCE: affine.for %{{.*}} = 0 to 30
588+
// ZERO-TOLERANCE-NEXT: affine.for %{{.*}} = 0 to 131072
589+
// ZERO-TOLERANCE: func.call @__external_reduce_barrett
590+
// ZERO-TOLERANCE: affine.store
591+
// ZERO-TOLERANCE: affine.load
592+
// ZERO-TOLERANCE-NEXT: affine.store
593+
affine.for %arg2 = 0 to 30 {
594+
affine.for %arg3 = 0 to 131072 {
595+
%7 = affine.load %6[%arg2] : memref<30xi64>
596+
%8 = affine.load %3[%arg2] : memref<30xi64>
597+
%9 = affine.load %5[%arg2] : memref<30xi64>
598+
%10 = affine.load %4[%arg2] : memref<30xi64>
599+
%11 = affine.load %2[%arg3] : memref<131072xi128>
600+
%12 = affine.load %1[%arg3] : memref<131072xi1>
601+
%13 = func.call @__external_reduce_barrett(%7, %8, %9, %10, %11) {outputModFac = 1 : i64} : (i64, i64, i64, i64, i128) -> i64
602+
%14 = arith.subi %7, %13 : i64
603+
%15 = arith.select %12, %14, %13 : i64
604+
affine.store %15, %0[%arg2, %arg3] : memref<30x131072xi64>
605+
}
606+
}
607+
func.call @__external_levelwise_forward_ntt(%0) : (memref<30x131072xi64>) -> ()
608+
affine.for %arg2 = 0 to 30 {
609+
affine.for %arg3 = 0 to 131072 {
610+
%7 = affine.load %0[%arg2, %arg3] : memref<30x131072xi64>
611+
affine.store %7, %arg1[%arg2, %arg3] : memref<30x131072xi64>
612+
}
613+
}
614+
// Under maximal fusion, just one nest.
615+
// PRODUCER-CONSUMER-MAXIMAL: affine.for %{{.*}} = 0 to 30
616+
// PRODUCER-CONSUMER-MAXIMAL-NEXT: affine.for %{{.*}} = 0 to 131072
617+
// PRODUCER-CONSUMER-MAXIMAL-NOT: affine.for %{{.*}}
618+
memref.dealloc %2 : memref<131072xi128>
619+
memref.dealloc %1 : memref<131072xi1>
620+
memref.dealloc %0 : memref<30x131072xi64>
621+
return
622+
}
623+
func.func private @__external_levelwise_forward_ntt(memref<30x131072xi64>)
624+
func.func private @__external_reduce_barrett(i64, i64, i64, i64, i128) -> i64

0 commit comments

Comments
 (0)