Skip to content

Commit 76cc040

Browse files
committed
Refactor LoopFuseSiblingOp and support parallel fusion
1 parent 68f4e46 commit 76cc040

File tree

5 files changed

+304
-230
lines changed

5 files changed

+304
-230
lines changed

mlir/include/mlir/Dialect/SCF/Utils/Utils.h

+16
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,12 @@ Loops tilePerfectlyNested(scf::ForOp rootForOp, ArrayRef<Value> sizes);
156156
void getPerfectlyNestedLoops(SmallVectorImpl<scf::ForOp> &nestedLoops,
157157
scf::ForOp root);
158158

159+
/// Prepends operations of firstPloop's body into secondPloop's body.
160+
/// On success, updates secondPloop to the new fused loop; no-op if illegal.
161+
void fuseIfLegal(scf::ParallelOp firstPloop, scf::ParallelOp &secondPloop,
162+
OpBuilder builder,
163+
llvm::function_ref<bool(Value, Value)> mayAlias);
164+
159165
/// Given two scf.forall loops, `target` and `source`, fuses `target` into
160166
/// `source`. Assumes that the given loops are siblings and are independent of
161167
/// each other.
@@ -177,6 +183,16 @@ scf::ForallOp fuseIndependentSiblingForallLoops(scf::ForallOp target,
177183
scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source,
178184
RewriterBase &rewriter);
179185

186+
/// Given two scf.parallel loops, `target` and `source`, fuses `target` into
187+
/// `source`. Assumes that the given loops are siblings and are independent of
188+
/// each other.
189+
///
190+
/// This function does not perform any legality checks and simply fuses the
191+
/// loops. The caller is responsible for ensuring that the loops are legal to
192+
/// fuse.
193+
scf::ParallelOp fuseIndependentSiblingParallelLoops(scf::ParallelOp target,
194+
scf::ParallelOp source,
195+
RewriterBase &rewriter);
180196
} // namespace mlir
181197

182198
#endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

+25-28
Original file line numberDiff line numberDiff line change
@@ -442,39 +442,32 @@ static DiagnosedSilenceableFailure isOpSibling(Operation *target,
442442
return DiagnosedSilenceableFailure::success();
443443
}
444444

445-
/// Check if `target` scf.forall can be fused into `source` scf.forall.
445+
/// Check if `target` scf loop can be fused into `source` scf loop.
446+
/// Applies for scf.for, scf.forall, and scf.parallel.
446447
///
447448
/// This simply checks if both loops have the same bounds and steps (and, for scf.forall, the same mapping).
448449
/// No attempt is made at checking that the side effects of `target` and
449450
/// `source` are independent of each other.
450-
static bool isForallWithIdenticalConfiguration(Operation *target,
451-
Operation *source) {
452-
auto targetOp = dyn_cast<scf::ForallOp>(target);
453-
auto sourceOp = dyn_cast<scf::ForallOp>(source);
454-
if (!targetOp || !sourceOp)
455-
return false;
456-
457-
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
458-
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
459-
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
460-
targetOp.getMapping() == sourceOp.getMapping();
461-
}
462-
463-
/// Check if `target` scf.for can be fused into `source` scf.for.
464-
///
465-
/// This simply checks if both loops have the same bounds and steps. No attempt
466-
/// is made at checking that the side effects of `target` and `source` are
467-
/// independent of each other.
468-
static bool isForWithIdenticalConfiguration(Operation *target,
469-
Operation *source) {
470-
auto targetOp = dyn_cast<scf::ForOp>(target);
471-
auto sourceOp = dyn_cast<scf::ForOp>(source);
451+
template <typename LoopTy>
452+
static bool isLoopWithIdenticalConfiguration(Operation *target,
453+
Operation *source) {
454+
static_assert(llvm::is_one_of<LoopTy, scf::ForallOp, scf::ForOp,
455+
scf::ParallelOp>::value,
456+
"applies to only `forall`, `for` and `parallel`");
457+
auto targetOp = dyn_cast<LoopTy>(target);
458+
auto sourceOp = dyn_cast<LoopTy>(source);
472459
if (!targetOp || !sourceOp)
473460
return false;
474461

475-
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
476-
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
477-
targetOp.getStep() == sourceOp.getStep();
462+
if constexpr (std::is_same_v<LoopTy, scf::ForallOp>)
463+
return targetOp.getMixedLowerBound() == sourceOp.getMixedLowerBound() &&
464+
targetOp.getMixedUpperBound() == sourceOp.getMixedUpperBound() &&
465+
targetOp.getMixedStep() == sourceOp.getMixedStep() &&
466+
targetOp.getMapping() == sourceOp.getMapping();
467+
else
468+
return targetOp.getLowerBound() == sourceOp.getLowerBound() &&
469+
targetOp.getUpperBound() == sourceOp.getUpperBound() &&
470+
targetOp.getStep() == sourceOp.getStep();
478471
}
479472

480473
DiagnosedSilenceableFailure
@@ -502,12 +495,16 @@ transform::LoopFuseSiblingOp::apply(transform::TransformRewriter &rewriter,
502495

503496
Operation *fusedLoop;
504497
/// TODO: Support fusion for loop-like ops besides scf.for, scf.forall, and scf.parallel.
505-
if (isForWithIdenticalConfiguration(target, source)) {
498+
if (isLoopWithIdenticalConfiguration<scf::ForOp>(target, source)) {
506499
fusedLoop = fuseIndependentSiblingForLoops(
507500
cast<scf::ForOp>(target), cast<scf::ForOp>(source), rewriter);
508-
} else if (isForallWithIdenticalConfiguration(target, source)) {
501+
} else if (isLoopWithIdenticalConfiguration<scf::ForallOp>(target, source)) {
509502
fusedLoop = fuseIndependentSiblingForallLoops(
510503
cast<scf::ForallOp>(target), cast<scf::ForallOp>(source), rewriter);
504+
} else if (isLoopWithIdenticalConfiguration<scf::ParallelOp>(target,
505+
source)) {
506+
fusedLoop = fuseIndependentSiblingParallelLoops(
507+
cast<scf::ParallelOp>(target), cast<scf::ParallelOp>(source), rewriter);
511508
} else
512509
return emitSilenceableFailure(target->getLoc())
513510
<< "operations cannot be fused";

mlir/lib/Dialect/SCF/Transforms/ParallelLoopFusion.cpp

+2-202
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1717
#include "mlir/Dialect/SCF/IR/SCF.h"
1818
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
19+
#include "mlir/Dialect/SCF/Utils/Utils.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/IRMapping.h"
2122
#include "mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
3031
using namespace mlir;
3132
using namespace mlir::scf;
3233

33-
/// Verify there are no nested ParallelOps.
34-
static bool hasNestedParallelOp(ParallelOp ploop) {
35-
auto walkResult =
36-
ploop.getBody()->walk([](ParallelOp) { return WalkResult::interrupt(); });
37-
return walkResult.wasInterrupted();
38-
}
39-
40-
/// Verify equal iteration spaces.
41-
static bool equalIterationSpaces(ParallelOp firstPloop,
42-
ParallelOp secondPloop) {
43-
if (firstPloop.getNumLoops() != secondPloop.getNumLoops())
44-
return false;
45-
46-
auto matchOperands = [&](const OperandRange &lhs,
47-
const OperandRange &rhs) -> bool {
48-
// TODO: Extend this to support aliases and equal constants.
49-
return std::equal(lhs.begin(), lhs.end(), rhs.begin());
50-
};
51-
return matchOperands(firstPloop.getLowerBound(),
52-
secondPloop.getLowerBound()) &&
53-
matchOperands(firstPloop.getUpperBound(),
54-
secondPloop.getUpperBound()) &&
55-
matchOperands(firstPloop.getStep(), secondPloop.getStep());
56-
}
57-
58-
/// Checks if the parallel loops have mixed access to the same buffers. Returns
59-
/// `true` if the first parallel loop writes to the same indices that the second
60-
/// loop reads.
61-
static bool haveNoReadsAfterWriteExceptSameIndex(
62-
ParallelOp firstPloop, ParallelOp secondPloop,
63-
const IRMapping &firstToSecondPloopIndices,
64-
llvm::function_ref<bool(Value, Value)> mayAlias) {
65-
DenseMap<Value, SmallVector<ValueRange, 1>> bufferStores;
66-
SmallVector<Value> bufferStoresVec;
67-
firstPloop.getBody()->walk([&](memref::StoreOp store) {
68-
bufferStores[store.getMemRef()].push_back(store.getIndices());
69-
bufferStoresVec.emplace_back(store.getMemRef());
70-
});
71-
auto walkResult = secondPloop.getBody()->walk([&](memref::LoadOp load) {
72-
Value loadMem = load.getMemRef();
73-
// Stop if the memref is defined in secondPloop body. Careful alias analysis
74-
// is needed.
75-
auto *memrefDef = loadMem.getDefiningOp();
76-
if (memrefDef && memrefDef->getBlock() == load->getBlock())
77-
return WalkResult::interrupt();
78-
79-
for (Value store : bufferStoresVec)
80-
if (store != loadMem && mayAlias(store, loadMem))
81-
return WalkResult::interrupt();
82-
83-
auto write = bufferStores.find(loadMem);
84-
if (write == bufferStores.end())
85-
return WalkResult::advance();
86-
87-
// Check that at least one store was retrieved
88-
if (!write->second.size())
89-
return WalkResult::interrupt();
90-
91-
auto storeIndices = write->second.front();
92-
93-
// Multiple writes to the same memref are allowed only on the same indices
94-
for (const auto &othStoreIndices : write->second) {
95-
if (othStoreIndices != storeIndices)
96-
return WalkResult::interrupt();
97-
}
98-
99-
// Check that the load indices of secondPloop coincide with store indices of
100-
// firstPloop for the same memrefs.
101-
auto loadIndices = load.getIndices();
102-
if (storeIndices.size() != loadIndices.size())
103-
return WalkResult::interrupt();
104-
for (int i = 0, e = storeIndices.size(); i < e; ++i) {
105-
if (firstToSecondPloopIndices.lookupOrDefault(storeIndices[i]) !=
106-
loadIndices[i]) {
107-
auto *storeIndexDefOp = storeIndices[i].getDefiningOp();
108-
auto *loadIndexDefOp = loadIndices[i].getDefiningOp();
109-
if (storeIndexDefOp && loadIndexDefOp) {
110-
if (!isMemoryEffectFree(storeIndexDefOp))
111-
return WalkResult::interrupt();
112-
if (!isMemoryEffectFree(loadIndexDefOp))
113-
return WalkResult::interrupt();
114-
if (!OperationEquivalence::isEquivalentTo(
115-
storeIndexDefOp, loadIndexDefOp,
116-
[&](Value storeIndex, Value loadIndex) {
117-
if (firstToSecondPloopIndices.lookupOrDefault(storeIndex) !=
118-
firstToSecondPloopIndices.lookupOrDefault(loadIndex))
119-
return failure();
120-
else
121-
return success();
122-
},
123-
/*markEquivalent=*/nullptr,
124-
OperationEquivalence::Flags::IgnoreLocations)) {
125-
return WalkResult::interrupt();
126-
}
127-
} else
128-
return WalkResult::interrupt();
129-
}
130-
}
131-
return WalkResult::advance();
132-
});
133-
return !walkResult.wasInterrupted();
134-
}
135-
136-
/// Analyzes dependencies in the most primitive way by checking simple read and
137-
/// write patterns.
138-
static LogicalResult
139-
verifyDependencies(ParallelOp firstPloop, ParallelOp secondPloop,
140-
const IRMapping &firstToSecondPloopIndices,
141-
llvm::function_ref<bool(Value, Value)> mayAlias) {
142-
if (!haveNoReadsAfterWriteExceptSameIndex(
143-
firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144-
return failure();
145-
146-
IRMapping secondToFirstPloopIndices;
147-
secondToFirstPloopIndices.map(secondPloop.getBody()->getArguments(),
148-
firstPloop.getBody()->getArguments());
149-
return success(haveNoReadsAfterWriteExceptSameIndex(
150-
secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151-
}
152-
153-
static bool isFusionLegal(ParallelOp firstPloop, ParallelOp secondPloop,
154-
const IRMapping &firstToSecondPloopIndices,
155-
llvm::function_ref<bool(Value, Value)> mayAlias) {
156-
return !hasNestedParallelOp(firstPloop) &&
157-
!hasNestedParallelOp(secondPloop) &&
158-
equalIterationSpaces(firstPloop, secondPloop) &&
159-
succeeded(verifyDependencies(firstPloop, secondPloop,
160-
firstToSecondPloopIndices, mayAlias));
161-
}
162-
163-
/// Prepends operations of firstPloop's body into secondPloop's body.
164-
/// Updates secondPloop with new loop.
165-
static void fuseIfLegal(ParallelOp firstPloop, ParallelOp &secondPloop,
166-
OpBuilder builder,
167-
llvm::function_ref<bool(Value, Value)> mayAlias) {
168-
Block *block1 = firstPloop.getBody();
169-
Block *block2 = secondPloop.getBody();
170-
IRMapping firstToSecondPloopIndices;
171-
firstToSecondPloopIndices.map(block1->getArguments(), block2->getArguments());
172-
173-
if (!isFusionLegal(firstPloop, secondPloop, firstToSecondPloopIndices,
174-
mayAlias))
175-
return;
176-
177-
DominanceInfo dom;
178-
// We are fusing first loop into second, make sure there are no users of the
179-
// first loop results between loops.
180-
for (Operation *user : firstPloop->getUsers())
181-
if (!dom.properlyDominates(secondPloop, user, /*enclosingOpOk*/ false))
182-
return;
183-
184-
ValueRange inits1 = firstPloop.getInitVals();
185-
ValueRange inits2 = secondPloop.getInitVals();
186-
187-
SmallVector<Value> newInitVars(inits1.begin(), inits1.end());
188-
newInitVars.append(inits2.begin(), inits2.end());
189-
190-
IRRewriter b(builder);
191-
b.setInsertionPoint(secondPloop);
192-
auto newSecondPloop = b.create<ParallelOp>(
193-
secondPloop.getLoc(), secondPloop.getLowerBound(),
194-
secondPloop.getUpperBound(), secondPloop.getStep(), newInitVars);
195-
196-
Block *newBlock = newSecondPloop.getBody();
197-
auto term1 = cast<ReduceOp>(block1->getTerminator());
198-
auto term2 = cast<ReduceOp>(block2->getTerminator());
199-
200-
b.inlineBlockBefore(block2, newBlock, newBlock->begin(),
201-
newBlock->getArguments());
202-
b.inlineBlockBefore(block1, newBlock, newBlock->begin(),
203-
newBlock->getArguments());
204-
205-
ValueRange results = newSecondPloop.getResults();
206-
if (!results.empty()) {
207-
b.setInsertionPointToEnd(newBlock);
208-
209-
ValueRange reduceArgs1 = term1.getOperands();
210-
ValueRange reduceArgs2 = term2.getOperands();
211-
SmallVector<Value> newReduceArgs(reduceArgs1.begin(), reduceArgs1.end());
212-
newReduceArgs.append(reduceArgs2.begin(), reduceArgs2.end());
213-
214-
auto newReduceOp = b.create<scf::ReduceOp>(term2.getLoc(), newReduceArgs);
215-
216-
for (auto &&[i, reg] : llvm::enumerate(llvm::concat<Region>(
217-
term1.getReductions(), term2.getReductions()))) {
218-
Block &oldRedBlock = reg.front();
219-
Block &newRedBlock = newReduceOp.getReductions()[i].front();
220-
b.inlineBlockBefore(&oldRedBlock, &newRedBlock, newRedBlock.begin(),
221-
newRedBlock.getArguments());
222-
}
223-
224-
firstPloop.replaceAllUsesWith(results.take_front(inits1.size()));
225-
secondPloop.replaceAllUsesWith(results.take_back(inits2.size()));
226-
}
227-
term1->erase();
228-
term2->erase();
229-
firstPloop.erase();
230-
secondPloop.erase();
231-
secondPloop = newSecondPloop;
232-
}
233-
23434
void mlir::scf::naivelyFuseParallelOps(
23535
Region &region, llvm::function_ref<bool(Value, Value)> mayAlias) {
23636
OpBuilder b(region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
25959
}
26060
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
26161
for (int i = 0, e = ploops.size(); i + 1 < e; ++i)
262-
fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
62+
mlir::fuseIfLegal(ploops[i], ploops[i + 1], b, mayAlias);
26363
}
26464
}
26565
}

0 commit comments

Comments
 (0)