Skip to content

[MLIR] Fix arbitrary checks in affine LICM #116469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 50 additions & 92 deletions mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,9 @@

#include "mlir/Dialect/Affine/Passes.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/AffineStructures.h"
#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
#include "mlir/Dialect/Affine/Analysis/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/LoopUtils.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"

Expand All @@ -41,17 +25,21 @@ namespace affine {
} // namespace affine
} // namespace mlir

#define DEBUG_TYPE "licm"
#define DEBUG_TYPE "affine-licm"

using namespace mlir;
using namespace mlir::affine;

namespace {

/// Affine loop invariant code motion (LICM) pass.
/// TODO: The pass is missing zero-trip tests.
/// TODO: This code should be removed once the new LICM pass can handle its
/// uses.
/// TODO: The pass is missing zero tripcount tests.
/// TODO: When compared to the other standard LICM pass, this pass
/// has some special handling for affine read/write ops but such handling
/// requires aliasing to be sound, and as such this pass is unsound. In
/// addition, this handling is nothing particular to affine memory ops but would
/// apply to any memory read/write effect ops. Either aliasing should be handled
/// or this pass can be removed and the standard LICM can be used.
struct LoopInvariantCodeMotion
: public affine::impl::AffineLoopInvariantCodeMotionBase<
LoopInvariantCodeMotion> {
Expand All @@ -61,100 +49,80 @@ struct LoopInvariantCodeMotion
} // namespace

static bool
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist);
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist);

static bool
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
ValueRange iterArgs,
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist);

// Returns true if the individual op is loop invariant.
static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
/// Returns true if `op` is invariant on `loop`.
static bool isOpLoopInvariant(Operation &op, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist) {
LLVM_DEBUG(llvm::dbgs() << "iterating on op: " << op;);
Value iv = loop.getInductionVar();

if (auto ifOp = dyn_cast<AffineIfOp>(op)) {
if (!checkInvarianceOfNestedIfOps(ifOp, indVar, iterArgs, opsWithUsers,
opsToHoist))
if (!checkInvarianceOfNestedIfOps(ifOp, loop, opsWithUsers, opsToHoist))
return false;
} else if (auto forOp = dyn_cast<AffineForOp>(op)) {
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
opsWithUsers, opsToHoist))
if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), loop, opsWithUsers,
opsToHoist))
return false;
} else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
opsWithUsers, opsToHoist))
if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), loop, opsWithUsers,
opsToHoist))
return false;
} else if (!isMemoryEffectFree(&op) &&
!isa<AffineReadOpInterface, AffineWriteOpInterface,
AffinePrefetchOp>(&op)) {
!isa<AffineReadOpInterface, AffineWriteOpInterface>(&op)) {
// Check for side-effecting ops. Affine read/write ops are handled
// separately below.
return false;
} else if (!matchPattern(&op, m_Constant())) {
} else if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
// Register op in the set of ops that have users.
opsWithUsers.insert(&op);
if (isa<AffineReadOpInterface, AffineWriteOpInterface>(op)) {
auto read = dyn_cast<AffineReadOpInterface>(op);
Value memref = read ? read.getMemRef()
: cast<AffineWriteOpInterface>(op).getMemRef();
for (auto *user : memref.getUsers()) {
// If this memref has a user that is a DMA, give up because these
// operations write to this memref.
if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user))
SmallVector<AffineForOp, 8> userIVs;
auto read = dyn_cast<AffineReadOpInterface>(op);
Value memref =
read ? read.getMemRef() : cast<AffineWriteOpInterface>(op).getMemRef();
for (auto *user : memref.getUsers()) {
// If the memref used by the load/store is used in a store elsewhere in
// the loop nest, we do not hoist. Similarly, if the memref used in a
// load is also being stored too, we do not hoist the load.
// FIXME: This is missing checking aliases.
if (&op == user)
continue;
if (hasEffect<MemoryEffects::Write>(user, memref) ||
(hasEffect<MemoryEffects::Read>(user, memref) &&
isa<AffineWriteOpInterface>(op))) {
userIVs.clear();
getAffineForIVs(*user, &userIVs);
// Check that userIVs don't contain the for loop around the op.
if (llvm::is_contained(userIVs, loop))
return false;
// If the memref used by the load/store is used in a store elsewhere in
// the loop nest, we do not hoist. Similarly, if the memref used in a
// load is also being stored too, we do not hoist the load.
if (isa<AffineWriteOpInterface>(user) ||
(isa<AffineReadOpInterface>(user) &&
isa<AffineWriteOpInterface>(op))) {
if (&op != user) {
SmallVector<AffineForOp, 8> userIVs;
getAffineForIVs(*user, &userIVs);
// Check that userIVs don't contain the for loop around the op.
if (llvm::is_contained(userIVs, getForInductionVarOwner(indVar)))
return false;
}
}
}
}

if (op.getNumOperands() == 0 && !isa<AffineYieldOp>(op)) {
LLVM_DEBUG(llvm::dbgs() << "Non-constant op with 0 operands\n");
return false;
}
}

// Check operands.
ValueRange iterArgs = loop.getRegionIterArgs();
for (unsigned int i = 0; i < op.getNumOperands(); ++i) {
auto *operandSrc = op.getOperand(i).getDefiningOp();

LLVM_DEBUG(
op.getOperand(i).print(llvm::dbgs() << "Iterating on operand\n"));

// If the loop IV is the operand, this op isn't loop invariant.
if (indVar == op.getOperand(i)) {
LLVM_DEBUG(llvm::dbgs() << "Loop IV is the operand\n");
if (iv == op.getOperand(i))
return false;
}

// If the one of the iter_args is the operand, this op isn't loop invariant.
if (llvm::is_contained(iterArgs, op.getOperand(i))) {
LLVM_DEBUG(llvm::dbgs() << "One of the iter_args is the operand\n");
if (llvm::is_contained(iterArgs, op.getOperand(i)))
return false;
}

if (operandSrc) {
LLVM_DEBUG(llvm::dbgs() << *operandSrc << "Iterating on operand src\n");

// If the value was defined in the loop (outside of the if/else region),
// and that operation itself wasn't meant to be hoisted, then mark this
// operation loop dependent.
Expand All @@ -170,14 +138,13 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,

// Checks if all ops in a region (i.e. list of blocks) are loop invariant.
static bool
areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,
ValueRange iterArgs,
areAllOpsInTheBlockListInvariant(Region &blockList, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist) {

for (auto &b : blockList) {
for (auto &op : b) {
if (!isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist))
if (!isOpLoopInvariant(op, loop, opsWithUsers, opsToHoist))
return false;
}
}
Expand All @@ -187,40 +154,36 @@ areAllOpsInTheBlockListInvariant(Region &blockList, Value indVar,

// Returns true if the affine.if op can be hoisted.
static bool
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, Value indVar, ValueRange iterArgs,
checkInvarianceOfNestedIfOps(AffineIfOp ifOp, AffineForOp loop,
SmallPtrSetImpl<Operation *> &opsWithUsers,
SmallPtrSetImpl<Operation *> &opsToHoist) {
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), indVar, iterArgs,
if (!areAllOpsInTheBlockListInvariant(ifOp.getThenRegion(), loop,
opsWithUsers, opsToHoist))
return false;

if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), indVar, iterArgs,
if (!areAllOpsInTheBlockListInvariant(ifOp.getElseRegion(), loop,
opsWithUsers, opsToHoist))
return false;

return true;
}

void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
auto *loopBody = forOp.getBody();
auto indVar = forOp.getInductionVar();
ValueRange iterArgs = forOp.getRegionIterArgs();

// This is the place where hoisted instructions would reside.
OpBuilder b(forOp.getOperation());

SmallPtrSet<Operation *, 8> opsToHoist;
SmallVector<Operation *, 8> opsToMove;
SmallPtrSet<Operation *, 8> opsWithUsers;

for (auto &op : *loopBody) {
for (Operation &op : *forOp.getBody()) {
// Register op in the set of ops that have users. This set is used
// to prevent hoisting ops that depend on these ops that are
// not being hoisted.
if (!op.use_empty())
opsWithUsers.insert(&op);
if (!isa<AffineYieldOp>(op)) {
if (isOpLoopInvariant(op, indVar, iterArgs, opsWithUsers, opsToHoist)) {
if (isOpLoopInvariant(op, forOp, opsWithUsers, opsToHoist)) {
opsToMove.push_back(&op);
}
}
Expand All @@ -231,18 +194,13 @@ void LoopInvariantCodeMotion::runOnAffineForOp(AffineForOp forOp) {
for (auto *op : opsToMove) {
op->moveBefore(forOp);
}

LLVM_DEBUG(forOp->print(llvm::dbgs() << "Modified loop\n"));
}

void LoopInvariantCodeMotion::runOnOperation() {
// Walk through all loops in a function in innermost-loop-first order. This
// way, we first LICM from the inner loop, and place the ops in
// the outer loop, which in turn can be further LICM'ed.
getOperation().walk([&](AffineForOp op) {
LLVM_DEBUG(op->print(llvm::dbgs() << "\nOriginal loop\n"));
runOnAffineForOp(op);
});
getOperation().walk([&](AffineForOp op) { runOnAffineForOp(op); });
}

std::unique_ptr<OperationPass<func::FuncOp>>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -855,15 +855,16 @@ func.func @affine_prefetch_invariant() {
affine.for %i0 = 0 to 10 {
affine.for %i1 = 0 to 10 {
%1 = affine.load %0[%i0, %i1] : memref<10x10xf32>
// A prefetch shouldn't be hoisted.
affine.prefetch %0[%i0, %i0], write, locality<0>, data : memref<10x10xf32>
}
}

// CHECK: memref.alloc() : memref<10x10xf32>
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: affine.prefetch
// CHECK-NEXT: affine.for %{{.*}} = 0 to 10 {
// CHECK-NEXT: %{{.*}} = affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
// CHECK-NEXT: affine.load %{{.*}}[%{{.*}} : memref<10x10xf32>
// CHECK-NEXT: affine.prefetch
// CHECK-NEXT: }
// CHECK-NEXT: }
return
Expand Down
Loading