15
15
#include " mlir/Dialect/Affine/Analysis/AffineStructures.h"
16
16
#include " mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17
17
#include " mlir/Dialect/Affine/Analysis/Utils.h"
18
- #include " mlir/Dialect/Affine/IR/AffineOps.h"
19
18
#include " mlir/Dialect/Affine/LoopFusionUtils.h"
20
19
#include " mlir/Dialect/Affine/LoopUtils.h"
21
20
#include " mlir/Dialect/Affine/Utils.h"
@@ -473,7 +472,8 @@ static Value createPrivateMemRef(AffineForOp forOp,
473
472
// is lower.
474
473
// TODO: Extend profitability analysis to support scenarios with multiple
475
474
// stores.
476
- static bool isFusionProfitable (AffineForOp srcForOp, Operation *srcStoreOpInst,
475
+ static bool isFusionProfitable (AffineForOp srcForOp,
476
+ ArrayRef<Operation *> producerStores,
477
477
AffineForOp dstForOp,
478
478
ArrayRef<ComputationSliceState> depthSliceUnions,
479
479
unsigned maxLegalFusionDepth,
@@ -503,6 +503,35 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
503
503
if (!getLoopNestStats (dstForOp, &dstLoopNestStats))
504
504
return false ;
505
505
506
+ // We limit profitability analysis to only scenarios with
507
+ // a single producer store for now. Note that some multi-store
508
+ // producer scenarios will still go through profitability analysis
509
+ // if only one of the stores is involved in the producer-consumer
510
+ // relationship of the candidate loops.
511
+ // TODO: Support multiple producer stores in profitability
512
+ // analysis.
513
+ if (producerStores.size () > 1 ) {
514
+ LLVM_DEBUG (llvm::dbgs () << " Limited profitability analysis. Not "
515
+ " supported for multiple producer store case.\n " );
516
+ int64_t sliceCost;
517
+ int64_t fusedLoopNestComputeCost;
518
+ // We will still fuse if fusion obeys the specified compute
519
+ // tolerance at the max legal depth.
520
+ auto fraction = getAdditionalComputeFraction (
521
+ srcForOp, dstForOp, maxLegalFusionDepth, depthSliceUnions, sliceCost,
522
+ fusedLoopNestComputeCost);
523
+ if (!fraction || fraction > computeToleranceThreshold) {
524
+ LLVM_DEBUG (llvm::dbgs () << " Additional computation exceeds "
525
+ " compute tolerance. Not fusing.\n " );
526
+ return false ;
527
+ }
528
+ LLVM_DEBUG (llvm::dbgs ()
529
+ << " Considering fusion profitable at max legal depth.\n " );
530
+ return true ;
531
+ }
532
+
533
+ Operation *srcStoreOp = producerStores.front ();
534
+
506
535
// Search for min cost value for 'dstLoopDepth'. At each value of
507
536
// 'dstLoopDepth' from 'maxLegalFusionDepth' to '1', compute computation slice
508
537
// bounds between 'srcOpInst' and each op in 'dstOpinsts' (taking the union
@@ -516,12 +545,9 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
516
545
// The best loop depth at which to materialize the slice.
517
546
std::optional<unsigned > bestDstLoopDepth;
518
547
519
- // Compute op instance count for the src loop nest without iteration slicing.
520
- uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
521
-
522
548
// Compute src loop nest write region size.
523
- MemRefRegion srcWriteRegion (srcStoreOpInst ->getLoc ());
524
- if (failed (srcWriteRegion.compute (srcStoreOpInst , /* loopDepth=*/ 0 ))) {
549
+ MemRefRegion srcWriteRegion (srcStoreOp ->getLoc ());
550
+ if (failed (srcWriteRegion.compute (srcStoreOp , /* loopDepth=*/ 0 ))) {
525
551
LLVM_DEBUG (llvm::dbgs ()
526
552
<< " Unable to compute MemRefRegion for source operation\n " );
527
553
return false ;
@@ -533,7 +559,10 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
533
559
return false ;
534
560
int64_t srcWriteRegionSizeBytes = *maybeSrcWriteRegionSizeBytes;
535
561
536
- // Compute op instance count for the src loop nest.
562
+ // Compute op instance count for the src loop nest without iteration slicing.
563
+ uint64_t srcLoopNestCost = getComputeCost (srcForOp, srcLoopNestStats);
564
+
565
+ // Compute op instance count for the destination loop nest.
537
566
uint64_t dstLoopNestCost = getComputeCost (dstForOp, dstLoopNestStats);
538
567
539
568
// Evaluate all depth choices for materializing the slice in the destination
@@ -563,9 +592,8 @@ static bool isFusionProfitable(AffineForOp srcForOp, Operation *srcStoreOpInst,
563
592
// Determine what the slice write MemRefRegion would be, if the src loop
564
593
// nest slice 'slice' were to be inserted into the dst loop nest at loop
565
594
// depth 'i'.
566
- MemRefRegion sliceWriteRegion (srcStoreOpInst->getLoc ());
567
- if (failed (sliceWriteRegion.compute (srcStoreOpInst, /* loopDepth=*/ 0 ,
568
- &slice))) {
595
+ MemRefRegion sliceWriteRegion (srcStoreOp->getLoc ());
596
+ if (failed (sliceWriteRegion.compute (srcStoreOp, /* loopDepth=*/ 0 , &slice))) {
569
597
LLVM_DEBUG (llvm::dbgs ()
570
598
<< " Failed to compute slice write region at loopDepth: " << i
571
599
<< " \n " );
@@ -1025,21 +1053,13 @@ struct GreedyFusion {
1025
1053
cast<AffineWriteOpInterface>(op).getMemRef ()))
1026
1054
producerStores.push_back (op);
1027
1055
1028
- // TODO: Suppport multiple producer stores in profitability
1029
- // analysis. We limit profitability analysis to only scenarios with
1030
- // a single producer store for now. Note that some multi-store
1031
- // producer scenarios will still go through profitability analysis
1032
- // if only one of the stores is involved the producer-consumer
1033
- // relationship of the candidate loops.
1034
1056
assert (!producerStores.empty () && " Expected producer store" );
1035
- if (producerStores.size () > 1 )
1036
- LLVM_DEBUG (llvm::dbgs () << " Skipping profitability analysis. Not "
1037
- " supported for this case\n " );
1038
- else if (!isFusionProfitable (srcAffineForOp, producerStores[0 ],
1039
- dstAffineForOp, depthSliceUnions,
1040
- maxLegalFusionDepth, &bestDstLoopDepth,
1041
- computeToleranceThresholdToUse))
1057
+ if (!isFusionProfitable (srcAffineForOp, producerStores,
1058
+ dstAffineForOp, depthSliceUnions,
1059
+ maxLegalFusionDepth, &bestDstLoopDepth,
1060
+ computeToleranceThresholdToUse)) {
1042
1061
continue ;
1062
+ }
1043
1063
}
1044
1064
1045
1065
assert (bestDstLoopDepth > 0 && " Unexpected loop fusion depth" );
0 commit comments