Skip to content

Commit 5079194

Browse files
committed
[MLIR] Fix arbitrary checks in affine LICM
Fix arbitrary checks in affine LICM. Drop excessive debug logging. This pass is still unsound because it does not handle aliasing; that will have to be addressed later. Add/update comments.
1 parent 70b9440 commit 5079194

File tree

2 files changed

+53
-94
lines changed

2 files changed

+53
-94
lines changed

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp

Lines changed: 50 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,9 @@
1212

1313
#include "mlir/Dialect/Affine/Passes.h"
1414

15-
#include "mlir/Analysis/SliceAnalysis.h"
16-
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17-
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
18-
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1915
#include "mlir/Dialect/Affine/Analysis/Utils.h"
20-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
21-
#include "mlir/Dialect/Affine/LoopUtils.h"
22-
#include "mlir/Dialect/Affine/Utils.h"
23-
#include "mlir/Dialect/Arith/IR/Arith.h"
2416
#include "mlir/Dialect/Func/IR/FuncOps.h"
25-
#include "mlir/IR/AffineExpr.h"
26-
#include "mlir/IR/AffineMap.h"
27-
#include "mlir/IR/Builders.h"
28-
#include "mlir/IR/Matchers.h"
2917
#include "mlir/Interfaces/SideEffectInterfaces.h"
30-
#include "llvm/ADT/DenseMap.h"
31-
#include "llvm/ADT/DenseSet.h"
32-
#include "llvm/ADT/SmallPtrSet.h"
33-
#include "llvm/Support/CommandLine.h"
3418
#include "llvm/Support/Debug.h"
3519
#include "llvm/Support/raw_ostream.h"
3620

@@ -41,17 +25,21 @@ namespace affine {
4125
} // namespace affine
4226
} // namespace mlir
4327

44-
#define DEBUG_TYPE "licm"
28+
#define DEBUG_TYPE "affine-licm"
4529

4630
using namespace mlir;
4731
using namespace mlir::affine;
4832

4933
namespace {
5034

5135
/// Affine loop invariant code motion (LICM) pass.
52-
/// TODO: The pass is missing zero-trip tests.
53-
/// TODO: This code should be removed once the new LICM pass can handle its
54-
/// uses.
36+
/// TODO: The pass is missing zero tripcount tests.
37+
/// TODO: When compared to the other standard LICM pass, this pass
38+
/// has some special handling for affine read/write ops but such handling
39+
/// requires alias analysis to be sound, and as such this pass is unsound. In
40+
/// addition, this handling is nothing particular to affine memory ops but would
41+
/// apply to any memory read/write effect ops. Either aliasing should be handled
42+
/// or this pass can be removed and the standard LICM can be used.
5543
struct LoopInvariantCodeMotion
5644
: public affine::impl::AffineLoopInvariantCodeMotionBase<
5745
LoopInvariantCodeMotion> {
@@ -61,100 +49,80 @@ struct LoopInvariantCodeMotion
6149
} // namespace
6250

6351
static bool
64-
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
52+
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
6553
SmallPtrSetImpl<Operation *> &opsWithUsers,
6654
SmallPtrSetImpl<Operation *> &opsToHoist);
67-
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
55+
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
6856
SmallPtrSetImpl<Operation *> &opsWithUsers,
6957
SmallPtrSetImpl<Operation *> &opsToHoist);
7058

7159
static bool
72-
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
73-
ValueRange iterArgs,
60+
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
7461
SmallPtrSetImpl<Operation *> &opsWithUsers,
7562
SmallPtrSetImpl<Operation *> &opsToHoist);
7663

77-
// Returns true if the individual op is loop invariant.
78-
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
64+
/// Returns true if `op` is invariant with respect to `loop`.
65+
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
7966
SmallPtrSetImpl<Operation *> &opsWithUsers,
8067
SmallPtrSetImpl<Operation *> &opsToHoist) {
81-
LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
68+
Value iv = loop.getInductionVar();
8269

8370
if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
84-
if (!checkInvarianceOfNestedIfOps(ifOp, indVar, iterArgs, opsWithUsers,
85-
opsToHoist))
71+
if (!checkInvarianceOfNestedIfOps(ifOp, loop, opsWithUsers, opsToHoist))
8672
return false;
8773
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
88-
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
89-
opsWithUsers, opsToHoist))
74+
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), loop, opsWithUsers,
75+
opsToHoist))
9076
return false;
9177
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
92-
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
93-
opsWithUsers, opsToHoist))
78+
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), loop, opsWithUsers,
79+
opsToHoist))
9480
return false;
9581
} else if (!isMemoryEffectFree(&op) &&
96-
!isa<AffineReadOpInterface, AffineWriteOpInterface,
97-
AffinePrefetchOp>(&op)) {
82+
!isa<AffineReadOpInterface, AffineWriteOpInterface>(&op)) {
9883
// Check for side-effecting ops. Affine read/write ops are handled
9984
// separately below.
10085
return false;
101-
} else if (!matchPattern(&op, m_Constant())) {
86+
} else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
10287
// Register op in the set of ops that have users.
10388
opsWithUsers.insert(&op);
104-
if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
105-
auto read = dyn_cast<AffineReadOpInterface>(op);
106-
Value memref = read ? read.getMemRef()
107-
: cast<AffineWriteOpInterface>(op).getMemRef();
108-
for (auto *user : memref.getUsers()) {
109-
// If this memref has a user that is a DMA, give up because these
110-
// operations write to this memref.
111-
if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user))
89+
SmallVector<AffineForOp, 8> userIVs;
90+
auto read = dyn_cast<AffineReadOpInterface>(op);
91+
Value memref =
92+
read ? read.getMemRef() : cast<AffineWriteOpInterface>(op).getMemRef();
93+
for (auto *user : memref.getUsers()) {
94+
// If the memref used by the load/store is used in a store elsewhere in
95+
// the loop nest, we do not hoist. Similarly, if the memref used in a
96+
// load is also being stored to, we do not hoist the load.
97+
// FIXME: This is missing checking aliases.
98+
if (&op == user)
99+
continue;
100+
if (hasEffect<MemoryEffects::Write>(user, memref) ||
101+
(hasEffect<MemoryEffects::Read>(user, memref) &&
102+
isa<AffineWriteOpInterface>(op))) {
103+
userIVs.clear();
104+
getAffineForIVs(*user, &userIVs);
105+
// Check that userIVs don't contain the for loop around the op.
106+
if (llvm::is_contained(userIVs, loop))
112107
return false;
113-
// If the memref used by the load/store is used in a store elsewhere in
114-
// the loop nest, we do not hoist. Similarly, if the memref used in a
115-
// load is also being stored too, we do not hoist the load.
116-
if (isa<AffineWriteOpInterface>(user) ||
117-
(isa<AffineReadOpInterface>(user) &&
118-
isa<AffineWriteOpInterface>(op))) {
119-
if (&op != user) {
120-
SmallVector<AffineForOp, 8> userIVs;
121-
getAffineForIVs(*user, &userIVs);
122-
// Check that userIVs don't contain the for loop around the op.
123-
if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar)))
124-
return false;
125-
}
126-
}
127108
}
128109
}
129-
130-
if (op.getNumOperands() == 0 && !isa<AffineYieldOp>(op)) {
131-
LLVM_DEBUG(llvm::dbgs() << "Non-constant op with 0 operands\n");
132-
return false;
133-
}
134110
}
135111

136112
// Check operands.
113+
ValueRange iterArgs = loop.getRegionIterArgs();
137114
for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
138115
auto *operandSrc = op.getOperand(i).getDefiningOp();
139116

140-
LLVM_DEBUG(
141-
op.getOperand(i).print(llvm::dbgs() << "Iterating on operand\n"));
142-
143117
// If the loop IV is the operand, this op isn't loop invariant.
144-
if (indVar == op.getOperand(i)) {
145-
LLVM_DEBUG(llvm::dbgs() << "Loop IV is the operand\n");
118+
if (iv == op.getOperand(i))
146119
return false;
147-
}
148120

149121
// If one of the iter_args is the operand, this op isn't loop invariant.
150-
if (llvm::is_contained(iterArgs, op.getOperand(i))) {
151-
LLVM_DEBUG(llvm::dbgs() << "One of the iter_args is the operand\n");
122+
if (llvm::is_contained(iterArgs, op.getOperand(i)))
152123
return false;
153-
}
154124

155125
if (operandSrc) {
156-
LLVM_DEBUG(llvm::dbgs() << *operandSrc << "Iterating on operand src\n");
157-
158126
// If the value was defined in the loop (outside of the if/else region),
159127
// and that operation itself wasn't meant to be hoisted, then mark this
160128
// operation loop dependent.
@@ -170,14 +138,13 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
170138

171139
// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
172140
static bool
173-
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
174-
ValueRange iterArgs,
141+
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
175142
SmallPtrSetImpl<Operation *> &opsWithUsers,
176143
SmallPtrSetImpl<Operation *> &opsToHoist) {
177144

178145
for (auto &b : blockList) {
179146
for (auto &op : b) {
180-
if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist))
147+
if (!isOpLoopInvariant(op, loop, opsWithUsers, opsToHoist))
181148
return false;
182149
}
183150
}
@@ -187,40 +154,36 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
187154

188155
// Returns true if the affine.if op can be hoisted.
189156
static bool
190-
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
157+
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
191158
SmallPtrSetImpl<Operation *> &opsWithUsers,
192159
SmallPtrSetImpl<Operation *> &opsToHoist) {
193-
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), indVar, iterArgs,
160+
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), loop,
194161
opsWithUsers, opsToHoist))
195162
return false;
196163

197-
if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), indVar, iterArgs,
164+
if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), loop,
198165
opsWithUsers, opsToHoist))
199166
return false;
200167

201168
return true;
202169
}
203170

204171
void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
205-
auto *loopBody = forOp.getBody();
206-
auto indVar = forOp.getInductionVar();
207-
ValueRange iterArgs = forOp.getRegionIterArgs();
208-
209172
// This is the place where hoisted instructions would reside.
210173
OpBuilder b(forOp.getOperation());
211174

212175
SmallPtrSet<Operation *, 8> opsToHoist;
213176
SmallVector<Operation *, 8> opsToMove;
214177
SmallPtrSet<Operation *, 8> opsWithUsers;
215178

216-
for (auto &op : *loopBody) {
179+
for (Operation &op : *forOp.getBody()) {
217180
// Register op in the set of ops that have users. This set is used
218181
// to prevent hoisting ops that depend on these ops that are
219182
// not being hoisted.
220183
if (!op.use_empty())
221184
opsWithUsers.insert(&op);
222185
if (!isa<AffineYieldOp>(op)) {
223-
if (isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
186+
if (isOpLoopInvariant(op, forOp, opsWithUsers, opsToHoist)) {
224187
opsToMove.push_back(&op);
225188
}
226189
}
@@ -231,18 +194,13 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
231194
for (auto *op : opsToMove) {
232195
op->moveBefore(forOp);
233196
}
234-
235-
LLVM_DEBUG(forOp->print(llvm::dbgs() << "Modified loop\n"));
236197
}
237198

238199
void LoopInvariantCodeMotion::runOnOperation() {
239200
// Walk through all loops in a function in innermost-loop-first order. This
240201
// way, we first LICM from the inner loop, and place the ops in
241202
// the outer loop, which in turn can be further LICM'ed.
242-
getOperation().walk([&](AffineForOp op) {
243-
LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
244-
runOnAffineForOp(op);
245-
});
203+
getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); });
246204
}
247205

248206
std::unique_ptr<OperationPass<func::FuncOp>>

mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,15 +855,16 @@ func.func @affine_prefetch_invariant() {
855855
affine.for %i0 = 0 to 10 {
856856
affine.for %i1 = 0 to 10 {
857857
%1 = affine.load %0[%i0, %i1] : memref<10x10xf32>
858+
// A prefetch shouldn't be hoisted.
858859
affine.prefetch %0[%i0, %i0], write, locality<0>, data : memref<10x10xf32>
859860
}
860861
}
861862

862863
// CHECK: memref.alloc() : memref<10x10xf32>
863864
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
864-
// CHECK-NEXT: affine.prefetch
865865
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
866-
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
866+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
867+
// CHECK-NEXT: affine.prefetch
867868
// CHECK-NEXT: }
868869
// CHECK-NEXT: }
869870
return

0 commit comments

Comments
 (0)