16
16
#include " mlir/Dialect/MemRef/IR/MemRef.h"
17
17
#include " mlir/Dialect/SCF/IR/SCF.h"
18
18
#include " mlir/Dialect/SCF/Transforms/Transforms.h"
19
+ #include " mlir/Dialect/SCF/Utils/Utils.h"
19
20
#include " mlir/IR/Builders.h"
20
21
#include " mlir/IR/IRMapping.h"
21
22
#include " mlir/IR/OpDefinition.h"
@@ -30,207 +31,6 @@ namespace mlir {
30
31
using namespace mlir ;
31
32
using namespace mlir ::scf;
32
33
33
- // / Verify there are no nested ParallelOps.
34
- static bool hasNestedParallelOp (ParallelOp ploop) {
35
- auto walkResult =
36
- ploop.getBody ()->walk ([](ParallelOp) { return WalkResult::interrupt (); });
37
- return walkResult.wasInterrupted ();
38
- }
39
-
40
- // / Verify equal iteration spaces.
41
- static bool equalIterationSpaces (ParallelOp firstPloop,
42
- ParallelOp secondPloop) {
43
- if (firstPloop.getNumLoops () != secondPloop.getNumLoops ())
44
- return false ;
45
-
46
- auto matchOperands = [&](const OperandRange &lhs,
47
- const OperandRange &rhs) -> bool {
48
- // TODO: Extend this to support aliases and equal constants.
49
- return std::equal (lhs.begin (), lhs.end (), rhs.begin ());
50
- };
51
- return matchOperands (firstPloop.getLowerBound (),
52
- secondPloop.getLowerBound ()) &&
53
- matchOperands (firstPloop.getUpperBound (),
54
- secondPloop.getUpperBound ()) &&
55
- matchOperands (firstPloop.getStep (), secondPloop.getStep ());
56
- }
57
-
58
- // / Checks if the parallel loops have mixed access to the same buffers. Returns
59
- // / `true` if the first parallel loop writes to the same indices that the second
60
- // / loop reads.
61
- static bool haveNoReadsAfterWriteExceptSameIndex (
62
- ParallelOp firstPloop, ParallelOp secondPloop,
63
- const IRMapping &firstToSecondPloopIndices,
64
- llvm::function_ref<bool (Value, Value)> mayAlias) {
65
- DenseMap<Value, SmallVector<ValueRange, 1 >> bufferStores;
66
- SmallVector<Value> bufferStoresVec;
67
- firstPloop.getBody ()->walk ([&](memref::StoreOp store) {
68
- bufferStores[store.getMemRef ()].push_back (store.getIndices ());
69
- bufferStoresVec.emplace_back (store.getMemRef ());
70
- });
71
- auto walkResult = secondPloop.getBody ()->walk ([&](memref::LoadOp load) {
72
- Value loadMem = load.getMemRef ();
73
- // Stop if the memref is defined in secondPloop body. Careful alias analysis
74
- // is needed.
75
- auto *memrefDef = loadMem.getDefiningOp ();
76
- if (memrefDef && memrefDef->getBlock () == load->getBlock ())
77
- return WalkResult::interrupt ();
78
-
79
- for (Value store : bufferStoresVec)
80
- if (store != loadMem && mayAlias (store, loadMem))
81
- return WalkResult::interrupt ();
82
-
83
- auto write = bufferStores.find (loadMem);
84
- if (write == bufferStores.end ())
85
- return WalkResult::advance ();
86
-
87
- // Check that at last one store was retrieved
88
- if (!write ->second .size ())
89
- return WalkResult::interrupt ();
90
-
91
- auto storeIndices = write ->second .front ();
92
-
93
- // Multiple writes to the same memref are allowed only on the same indices
94
- for (const auto &othStoreIndices : write ->second ) {
95
- if (othStoreIndices != storeIndices)
96
- return WalkResult::interrupt ();
97
- }
98
-
99
- // Check that the load indices of secondPloop coincide with store indices of
100
- // firstPloop for the same memrefs.
101
- auto loadIndices = load.getIndices ();
102
- if (storeIndices.size () != loadIndices.size ())
103
- return WalkResult::interrupt ();
104
- for (int i = 0 , e = storeIndices.size (); i < e; ++i) {
105
- if (firstToSecondPloopIndices.lookupOrDefault (storeIndices[i]) !=
106
- loadIndices[i]) {
107
- auto *storeIndexDefOp = storeIndices[i].getDefiningOp ();
108
- auto *loadIndexDefOp = loadIndices[i].getDefiningOp ();
109
- if (storeIndexDefOp && loadIndexDefOp) {
110
- if (!isMemoryEffectFree (storeIndexDefOp))
111
- return WalkResult::interrupt ();
112
- if (!isMemoryEffectFree (loadIndexDefOp))
113
- return WalkResult::interrupt ();
114
- if (!OperationEquivalence::isEquivalentTo (
115
- storeIndexDefOp, loadIndexDefOp,
116
- [&](Value storeIndex, Value loadIndex) {
117
- if (firstToSecondPloopIndices.lookupOrDefault (storeIndex) !=
118
- firstToSecondPloopIndices.lookupOrDefault (loadIndex))
119
- return failure ();
120
- else
121
- return success ();
122
- },
123
- /* markEquivalent=*/ nullptr ,
124
- OperationEquivalence::Flags::IgnoreLocations)) {
125
- return WalkResult::interrupt ();
126
- }
127
- } else
128
- return WalkResult::interrupt ();
129
- }
130
- }
131
- return WalkResult::advance ();
132
- });
133
- return !walkResult.wasInterrupted ();
134
- }
135
-
136
- // / Analyzes dependencies in the most primitive way by checking simple read and
137
- // / write patterns.
138
- static LogicalResult
139
- verifyDependencies (ParallelOp firstPloop, ParallelOp secondPloop,
140
- const IRMapping &firstToSecondPloopIndices,
141
- llvm::function_ref<bool (Value, Value)> mayAlias) {
142
- if (!haveNoReadsAfterWriteExceptSameIndex (
143
- firstPloop, secondPloop, firstToSecondPloopIndices, mayAlias))
144
- return failure ();
145
-
146
- IRMapping secondToFirstPloopIndices;
147
- secondToFirstPloopIndices.map (secondPloop.getBody ()->getArguments (),
148
- firstPloop.getBody ()->getArguments ());
149
- return success (haveNoReadsAfterWriteExceptSameIndex (
150
- secondPloop, firstPloop, secondToFirstPloopIndices, mayAlias));
151
- }
152
-
153
- static bool isFusionLegal (ParallelOp firstPloop, ParallelOp secondPloop,
154
- const IRMapping &firstToSecondPloopIndices,
155
- llvm::function_ref<bool (Value, Value)> mayAlias) {
156
- return !hasNestedParallelOp (firstPloop) &&
157
- !hasNestedParallelOp (secondPloop) &&
158
- equalIterationSpaces (firstPloop, secondPloop) &&
159
- succeeded (verifyDependencies (firstPloop, secondPloop,
160
- firstToSecondPloopIndices, mayAlias));
161
- }
162
-
163
- // / Prepends operations of firstPloop's body into secondPloop's body.
164
- // / Updates secondPloop with new loop.
165
- static void fuseIfLegal (ParallelOp firstPloop, ParallelOp &secondPloop,
166
- OpBuilder builder,
167
- llvm::function_ref<bool (Value, Value)> mayAlias) {
168
- Block *block1 = firstPloop.getBody ();
169
- Block *block2 = secondPloop.getBody ();
170
- IRMapping firstToSecondPloopIndices;
171
- firstToSecondPloopIndices.map (block1->getArguments (), block2->getArguments ());
172
-
173
- if (!isFusionLegal (firstPloop, secondPloop, firstToSecondPloopIndices,
174
- mayAlias))
175
- return ;
176
-
177
- DominanceInfo dom;
178
- // We are fusing first loop into second, make sure there are no users of the
179
- // first loop results between loops.
180
- for (Operation *user : firstPloop->getUsers ())
181
- if (!dom.properlyDominates (secondPloop, user, /* enclosingOpOk*/ false ))
182
- return ;
183
-
184
- ValueRange inits1 = firstPloop.getInitVals ();
185
- ValueRange inits2 = secondPloop.getInitVals ();
186
-
187
- SmallVector<Value> newInitVars (inits1.begin (), inits1.end ());
188
- newInitVars.append (inits2.begin (), inits2.end ());
189
-
190
- IRRewriter b (builder);
191
- b.setInsertionPoint (secondPloop);
192
- auto newSecondPloop = b.create <ParallelOp>(
193
- secondPloop.getLoc (), secondPloop.getLowerBound (),
194
- secondPloop.getUpperBound (), secondPloop.getStep (), newInitVars);
195
-
196
- Block *newBlock = newSecondPloop.getBody ();
197
- auto term1 = cast<ReduceOp>(block1->getTerminator ());
198
- auto term2 = cast<ReduceOp>(block2->getTerminator ());
199
-
200
- b.inlineBlockBefore (block2, newBlock, newBlock->begin (),
201
- newBlock->getArguments ());
202
- b.inlineBlockBefore (block1, newBlock, newBlock->begin (),
203
- newBlock->getArguments ());
204
-
205
- ValueRange results = newSecondPloop.getResults ();
206
- if (!results.empty ()) {
207
- b.setInsertionPointToEnd (newBlock);
208
-
209
- ValueRange reduceArgs1 = term1.getOperands ();
210
- ValueRange reduceArgs2 = term2.getOperands ();
211
- SmallVector<Value> newReduceArgs (reduceArgs1.begin (), reduceArgs1.end ());
212
- newReduceArgs.append (reduceArgs2.begin (), reduceArgs2.end ());
213
-
214
- auto newReduceOp = b.create <scf::ReduceOp>(term2.getLoc (), newReduceArgs);
215
-
216
- for (auto &&[i, reg] : llvm::enumerate (llvm::concat<Region>(
217
- term1.getReductions (), term2.getReductions ()))) {
218
- Block &oldRedBlock = reg.front ();
219
- Block &newRedBlock = newReduceOp.getReductions ()[i].front ();
220
- b.inlineBlockBefore (&oldRedBlock, &newRedBlock, newRedBlock.begin (),
221
- newRedBlock.getArguments ());
222
- }
223
-
224
- firstPloop.replaceAllUsesWith (results.take_front (inits1.size ()));
225
- secondPloop.replaceAllUsesWith (results.take_back (inits2.size ()));
226
- }
227
- term1->erase ();
228
- term2->erase ();
229
- firstPloop.erase ();
230
- secondPloop.erase ();
231
- secondPloop = newSecondPloop;
232
- }
233
-
234
34
void mlir::scf::naivelyFuseParallelOps (
235
35
Region ®ion, llvm::function_ref<bool (Value, Value)> mayAlias) {
236
36
OpBuilder b (region);
@@ -259,7 +59,7 @@ void mlir::scf::naivelyFuseParallelOps(
259
59
}
260
60
for (MutableArrayRef<ParallelOp> ploops : ploopChains) {
261
61
for (int i = 0 , e = ploops.size (); i + 1 < e; ++i)
262
- fuseIfLegal (ploops[i], ploops[i + 1 ], b, mayAlias);
62
+ mlir:: fuseIfLegal (ploops[i], ploops[i + 1 ], b, mayAlias);
263
63
}
264
64
}
265
65
}
0 commit comments