Skip to content

Commit d201507

Browse files
committed
[MLIR] Fix arbitrary checks in affine LICM
Fix arbitrary checks in affine LICM and drop excessive debug logging. The pass is still unsound because it does not handle aliasing; that will have to be addressed later. Also add/update comments.
1 parent 70b9440 commit d201507

File tree

2 files changed

+52
-84
lines changed

2 files changed

+52
-84
lines changed

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp

Lines changed: 49 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,9 @@
1212

1313
#include "mlir/Dialect/Affine/Passes.h"
1414

15-
#include "mlir/Analysis/SliceAnalysis.h"
16-
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
17-
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
18-
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
1915
#include "mlir/Dialect/Affine/Analysis/Utils.h"
20-
#include "mlir/Dialect/Affine/IR/AffineOps.h"
21-
#include "mlir/Dialect/Affine/LoopUtils.h"
22-
#include "mlir/Dialect/Affine/Utils.h"
23-
#include "mlir/Dialect/Arith/IR/Arith.h"
2416
#include "mlir/Dialect/Func/IR/FuncOps.h"
25-
#include "mlir/IR/AffineExpr.h"
26-
#include "mlir/IR/AffineMap.h"
27-
#include "mlir/IR/Builders.h"
28-
#include "mlir/IR/Matchers.h"
2917
#include "mlir/Interfaces/SideEffectInterfaces.h"
30-
#include "llvm/ADT/DenseMap.h"
31-
#include "llvm/ADT/DenseSet.h"
32-
#include "llvm/ADT/SmallPtrSet.h"
33-
#include "llvm/Support/CommandLine.h"
3418
#include "llvm/Support/Debug.h"
3519
#include "llvm/Support/raw_ostream.h"
3620

@@ -41,17 +25,21 @@ namespace affine {
4125
} // namespace affine
4226
} // namespace mlir
4327

44-
#define DEBUG_TYPE "licm"
28+
#define DEBUG_TYPE "affine-licm"
4529

4630
using namespace mlir;
4731
using namespace mlir::affine;
4832

4933
namespace {
5034

5135
/// Affine loop invariant code motion (LICM) pass.
52-
/// TODO: The pass is missing zero-trip tests.
53-
/// TODO: This code should be removed once the new LICM pass can handle its
54-
/// uses.
36+
/// TODO: The pass is missing zero tripcount tests.
37+
/// TODO: When compared to the other standard LICM pass, this pass
38+
/// has some special handling for affine read/write ops but such handling
39+
/// requires alias analysis to be sound, and as such this pass is unsound. In
40+
/// addition, this handling is not particular to affine memory ops but would
41+
/// apply to any memory read/write effect ops. Either aliasing should be handled
42+
/// or this pass can be removed and the standard LICM can be used.
5543
struct LoopInvariantCodeMotion
5644
: public affine::impl::AffineLoopInvariantCodeMotionBase<
5745
LoopInvariantCodeMotion> {
@@ -61,100 +49,84 @@ struct LoopInvariantCodeMotion
6149
} // namespace
6250

6351
static bool
64-
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
52+
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
53+
ValueRange iterArgs,
6554
SmallPtrSetImpl<Operation *> &opsWithUsers,
6655
SmallPtrSetImpl<Operation *> &opsToHoist);
67-
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
56+
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
57+
ValueRange iterArgs,
6858
SmallPtrSetImpl<Operation *> &opsWithUsers,
6959
SmallPtrSetImpl<Operation *> &opsToHoist);
7060

7161
static bool
72-
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
62+
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
7363
ValueRange iterArgs,
7464
SmallPtrSetImpl<Operation *> &opsWithUsers,
7565
SmallPtrSetImpl<Operation *> &opsToHoist);
7666

7767
// Returns true if the individual op is loop invariant.
78-
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
68+
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
69+
ValueRange iterArgs,
7970
SmallPtrSetImpl<Operation *> &opsWithUsers,
8071
SmallPtrSetImpl<Operation *> &opsToHoist) {
81-
LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
72+
Value iv = loop.getInductionVar();
8273

8374
if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
84-
if (!checkInvarianceOfNestedIfOps(ifOp, indVar, iterArgs, opsWithUsers,
75+
if (!checkInvarianceOfNestedIfOps(ifOp, loop, iterArgs, opsWithUsers,
8576
opsToHoist))
8677
return false;
8778
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
88-
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
79+
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), loop, iterArgs,
8980
opsWithUsers, opsToHoist))
9081
return false;
9182
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
92-
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
83+
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), loop, iterArgs,
9384
opsWithUsers, opsToHoist))
9485
return false;
9586
} else if (!isMemoryEffectFree(&op) &&
96-
!isa<AffineReadOpInterface, AffineWriteOpInterface,
97-
AffinePrefetchOp>(&op)) {
87+
!isa<AffineReadOpInterface, AffineWriteOpInterface>(&op)) {
9888
// Check for side-effecting ops. Affine read/write ops are handled
9989
// separately below.
10090
return false;
101-
} else if (!matchPattern(&op, m_Constant())) {
91+
} else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
10292
// Register op in the set of ops that have users.
10393
opsWithUsers.insert(&op);
104-
if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
105-
auto read = dyn_cast<AffineReadOpInterface>(op);
106-
Value memref = read ? read.getMemRef()
107-
: cast<AffineWriteOpInterface>(op).getMemRef();
108-
for (auto *user : memref.getUsers()) {
109-
// If this memref has a user that is a DMA, give up because these
110-
// operations write to this memref.
111-
if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user))
94+
SmallVector<AffineForOp, 8> userIVs;
95+
auto read = dyn_cast<AffineReadOpInterface>(op);
96+
Value memref =
97+
read ? read.getMemRef() : cast<AffineWriteOpInterface>(op).getMemRef();
98+
for (auto *user : memref.getUsers()) {
99+
// If the memref used by the load/store is used in a store elsewhere in
100+
// the loop nest, we do not hoist. Similarly, if the memref used in a
101+
// load is also being stored to, we do not hoist the load.
102+
// FIXME: This is missing checking aliases.
103+
if (&op == user)
104+
continue;
105+
if (hasEffect<MemoryEffects::Write>(user, memref) ||
106+
(hasEffect<MemoryEffects::Read>(user, memref) &&
107+
isa<AffineWriteOpInterface>(op))) {
108+
userIVs.clear();
109+
getAffineForIVs(*user, &userIVs);
110+
// Check that userIVs don't contain the for loop around the op.
111+
if (llvm::is_contained(userIVs, loop))
112112
return false;
113-
// If the memref used by the load/store is used in a store elsewhere in
114-
// the loop nest, we do not hoist. Similarly, if the memref used in a
115-
// load is also being stored too, we do not hoist the load.
116-
if (isa<AffineWriteOpInterface>(user) ||
117-
(isa<AffineReadOpInterface>(user) &&
118-
isa<AffineWriteOpInterface>(op))) {
119-
if (&op != user) {
120-
SmallVector<AffineForOp, 8> userIVs;
121-
getAffineForIVs(*user, &userIVs);
122-
// Check that userIVs don't contain the for loop around the op.
123-
if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar)))
124-
return false;
125-
}
126-
}
127113
}
128114
}
129-
130-
if (op.getNumOperands() == 0 && !isa<AffineYieldOp>(op)) {
131-
LLVM_DEBUG(llvm::dbgs() << "Non-constant op with 0 operands\n");
132-
return false;
133-
}
134115
}
135116

136117
// Check operands.
137118
for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
138119
auto *operandSrc = op.getOperand(i).getDefiningOp();
139120

140-
LLVM_DEBUG(
141-
op.getOperand(i).print(llvm::dbgs() << "Iterating on operand\n"));
142-
143121
// If the loop IV is the operand, this op isn't loop invariant.
144-
if (indVar == op.getOperand(i)) {
145-
LLVM_DEBUG(llvm::dbgs() << "Loop IV is the operand\n");
122+
if (iv == op.getOperand(i))
146123
return false;
147-
}
148124

149125
// If one of the iter_args is the operand, this op isn't loop invariant.
150-
if (llvm::is_contained(iterArgs, op.getOperand(i))) {
151-
LLVM_DEBUG(llvm::dbgs() << "One of the iter_args is the operand\n");
126+
if (llvm::is_contained(iterArgs, op.getOperand(i)))
152127
return false;
153-
}
154128

155129
if (operandSrc) {
156-
LLVM_DEBUG(llvm::dbgs() << *operandSrc << "Iterating on operand src\n");
157-
158130
// If the value was defined in the loop (outside of the if/else region),
159131
// and that operation itself wasn't meant to be hoisted, then mark this
160132
// operation loop dependent.
@@ -170,14 +142,14 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
170142

171143
// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
172144
static bool
173-
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
145+
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
174146
ValueRange iterArgs,
175147
SmallPtrSetImpl<Operation *> &opsWithUsers,
176148
SmallPtrSetImpl<Operation *> &opsToHoist) {
177149

178150
for (auto &b : blockList) {
179151
for (auto &op : b) {
180-
if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist))
152+
if (!isOpLoopInvariant(op, loop, iterArgs, opsWithUsers, opsToHoist))
181153
return false;
182154
}
183155
}
@@ -187,14 +159,15 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
187159

188160
// Returns true if the affine.if op can be hoisted.
189161
static bool
190-
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
162+
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
163+
ValueRange iterArgs,
191164
SmallPtrSetImpl<Operation *> &opsWithUsers,
192165
SmallPtrSetImpl<Operation *> &opsToHoist) {
193-
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), indVar, iterArgs,
166+
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), loop, iterArgs,
194167
opsWithUsers, opsToHoist))
195168
return false;
196169

197-
if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), indVar, iterArgs,
170+
if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), loop, iterArgs,
198171
opsWithUsers, opsToHoist))
199172
return false;
200173

@@ -203,7 +176,6 @@ checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
203176

204177
void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
205178
auto *loopBody = forOp.getBody();
206-
auto indVar = forOp.getInductionVar();
207179
ValueRange iterArgs = forOp.getRegionIterArgs();
208180

209181
// This is the place where hoisted instructions would reside.
@@ -220,7 +192,7 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
220192
if (!op.use_empty())
221193
opsWithUsers.insert(&op);
222194
if (!isa<AffineYieldOp>(op)) {
223-
if (isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
195+
if (isOpLoopInvariant(op, forOp, iterArgs, opsWithUsers, opsToHoist)) {
224196
opsToMove.push_back(&op);
225197
}
226198
}
@@ -231,18 +203,13 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
231203
for (auto *op : opsToMove) {
232204
op->moveBefore(forOp);
233205
}
234-
235-
LLVM_DEBUG(forOp->print(llvm::dbgs() << "Modified loop\n"));
236206
}
237207

238208
void LoopInvariantCodeMotion::runOnOperation() {
239209
// Walk through all loops in a function in innermost-loop-first order. This
240210
// way, we first LICM from the inner loop, and place the ops in
241211
// the outer loop, which in turn can be further LICM'ed.
242-
getOperation().walk([&](AffineForOp op) {
243-
LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
244-
runOnAffineForOp(op);
245-
});
212+
getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); });
246213
}
247214

248215
std::unique_ptr<OperationPass<func::FuncOp>>

mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -855,15 +855,16 @@ func.func @affine_prefetch_invariant() {
855855
affine.for %i0 = 0 to 10 {
856856
affine.for %i1 = 0 to 10 {
857857
%1 = affine.load %0[%i0, %i1] : memref<10x10xf32>
858+
// A prefetch shouldn't be hoisted.
858859
affine.prefetch %0[%i0, %i0], write, locality<0>, data : memref<10x10xf32>
859860
}
860861
}
861862

862863
// CHECK: memref.alloc() : memref<10x10xf32>
863864
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
864-
// CHECK-NEXT: affine.prefetch
865865
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
866-
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
866+
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
867+
// CHECK-NEXT: affine.prefetch
867868
// CHECK-NEXT: }
868869
// CHECK-NEXT: }
869870
return

0 commit comments

Comments
 (0)