Skip to content

[flang] Fixed LoopVersioning for array slices. #65703

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 176 additions & 41 deletions flang/lib/Optimizer/Transforms/LoopVersioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,72 @@ class LoopVersioningPass
void runOnOperation() override;
};

/// @struct ArgInfo
/// A structure to hold an argument, the size of the argument and dimension
/// information.
struct ArgInfo {
mlir::Value arg;
size_t size;
unsigned rank;
fir::BoxDimsOp dims[CFI_MAX_RANK];
};

/// @struct ArgsUsageInLoop
/// A structure providing information about the function arguments
/// usage by the instructions immediately nested in a loop.
struct ArgsUsageInLoop {
/// Mapping between the memref operand of an array indexing
/// operation (e.g. fir.coordinate_of) and the argument information.
llvm::DenseMap<mlir::Value, ArgInfo> usageInfo;
/// Some array indexing operations inside a loop cannot be transformed.
/// This vector holds the memref operands of such operations.
/// The vector is used to make sure that we do not try to transform
/// any outer loop, since this will imply the operation rewrite
/// in this loop.
llvm::SetVector<mlir::Value> cannotTransform;

// Debug dump of the structure members assuming that
// the information has been collected for the given loop.
void dump(fir::DoLoopOp loop) const {
// clang-format off
LLVM_DEBUG(
mlir::OpPrintingFlags printFlags;
printFlags.skipRegions();
llvm::dbgs() << "Arguments usage info for loop:\n";
loop.print(llvm::dbgs(), printFlags);
llvm::dbgs() << "\nUsed args:\n";
for (auto &use : usageInfo) {
mlir::Value v = use.first;
v.print(llvm::dbgs(), printFlags);
llvm::dbgs() << "\n";
}
llvm::dbgs() << "\nCannot transform args:\n";
for (mlir::Value arg : cannotTransform) {
arg.print(llvm::dbgs(), printFlags);
llvm::dbgs() << "\n";
}
llvm::dbgs() << "====\n"
);
// clang-format on
}

// Erase usageInfo and cannotTransform entries for a set
// of given arguments.
void eraseUsage(const llvm::SetVector<mlir::Value> &args) {
for (auto &arg : args)
usageInfo.erase(arg);
cannotTransform.set_subtract(args);
}

// Erase usageInfo and cannotTransform entries for a set
// of given arguments provided in the form of usageInfo map.
void eraseUsage(const llvm::DenseMap<mlir::Value, ArgInfo> &args) {
for (auto &arg : args) {
usageInfo.erase(arg.first);
cannotTransform.remove(arg.first);
}
}
};
} // namespace

/// @c replaceOuterUses - replace uses outside of @c op with result of @c
Expand Down Expand Up @@ -179,16 +245,6 @@ void LoopVersioningPass::runOnOperation() {
LLVM_DEBUG(llvm::dbgs() << "=== Begin " DEBUG_TYPE " ===\n");
mlir::func::FuncOp func = getOperation();

/// @c ArgInfo
/// A structure to hold an argument, the size of the argument and dimension
/// information.
struct ArgInfo {
mlir::Value arg;
size_t size;
unsigned rank;
fir::BoxDimsOp dims[CFI_MAX_RANK];
};

// First look for arguments with assumed shape = unknown extent in the lowest
// dimension.
LLVM_DEBUG(llvm::dbgs() << "Func-name:" << func.getSymName() << "\n");
Expand Down Expand Up @@ -224,58 +280,137 @@ void LoopVersioningPass::runOnOperation() {
}
}

if (argsOfInterest.empty())
if (argsOfInterest.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "No suitable arguments.\n=== End " DEBUG_TYPE " ===\n");
return;
}

struct OpsWithArgs {
mlir::Operation *op;
mlir::SmallVector<ArgInfo, 4> argsAndDims;
};
// Now see if those arguments are used inside any loop.
mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
// A list of all loops in the function in post-order.
mlir::SmallVector<fir::DoLoopOp> originalLoops;
// Information about the arguments usage by the instructions
// immediately nested in a loop.
llvm::DenseMap<fir::DoLoopOp, ArgsUsageInLoop> argsInLoops;

// Traverse the loops in post-order and see
// if those arguments are used inside any loop.
func.walk([&](fir::DoLoopOp loop) {
mlir::Block &body = *loop.getBody();
mlir::SmallVector<ArgInfo, 4> argsInLoop;
auto &argsInLoop = argsInLoops[loop];
originalLoops.push_back(loop);
body.walk([&](mlir::Operation *op) {
// support either fir.array_coor or fir.coordinate_of
if (auto arrayCoor = mlir::dyn_cast<fir::ArrayCoorOp>(op)) {
// no support currently for sliced arrays
if (arrayCoor.getSlice())
return;
} else if (!mlir::isa<fir::CoordinateOp>(op)) {
// Support either fir.array_coor or fir.coordinate_of.
if (!mlir::isa<fir::ArrayCoorOp, fir::CoordinateOp>(op))
return;
}

// The current operation could be inside another loop than
// the one we're currently processing. Skip it, we'll get
// to it later.
// Process only operations immediately nested in the current loop.
if (op->getParentOfType<fir::DoLoopOp>() != loop)
return;
mlir::Value operand = op->getOperand(0);
for (auto a : argsOfInterest) {
if (a.arg == normaliseVal(operand)) {
// use the reboxed value, not the block arg when re-creating the loop:
// Use the reboxed value, not the block arg when re-creating the loop.
// TODO: should we check that the operand dominates the loop?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did I understand this correctly?

operand is an operand to the array indexing operation so it must dominate it. In the normal case it will also dominate the loop (the declare and reboxing are generated in the function entry block). But one could write valid IR where the reboxing does not dominate the outer loop so we ought to handle that case. Nice spot!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is the operand of array_coor/coordinate_of that would normally also dominate the loop, but this is not guaranteed. I am glad that we are on the same page. I will try to create a reproducer in FIR and fix it in a separate check-in.

Thank you for the review!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vzakhari is this issue fixed? If not, do you have a test that you could share?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vzakhari I'm looking into the "check that the operand dominates the loop", but I'm struggling to come up with some FIR that is valid and shows this issue. Did you get somewhere on a reproducer?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this may still be a problem. I just hand-modified the first case from loop-versioning.fir:

module {
  func.func @sum1d(%arg0: !fir.box<!fir.array<?xf64>> {fir.bindc_name = "a"}, %arg1: !fir.ref<i32> {fir.bindc_name = "n"}) {
    %decl = fir.declare %arg0 {uniq_name = "a"} : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
    %0 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QMmoduleFsum1dEi"}
    %1 = fir.alloca f64 {bindc_name = "sum", uniq_name = "_QMmoduleFsum1dEsum"}
    %cst = arith.constant 0.000000e+00 : f64
    fir.store %cst to %1 : !fir.ref<f64>
    %c1_i32 = arith.constant 1 : i32
    %2 = fir.convert %c1_i32 : (i32) -> index
    %3 = fir.load %arg1 : !fir.ref<i32>
    %4 = fir.convert %3 : (i32) -> index
    %c1 = arith.constant 1 : index
    %5 = fir.convert %2 : (index) -> i32
    %6:2 = fir.do_loop %arg2 = %2 to %4 step %c1 iter_args(%arg3 = %5) -> (index, i32) {
      %rebox = fir.rebox %decl : (!fir.box<!fir.array<?xf64>>) -> !fir.box<!fir.array<?xf64>>
      fir.store %arg3 to %0 : !fir.ref<i32>
      %7 = fir.load %1 : !fir.ref<f64>
      %8 = fir.load %0 : !fir.ref<i32>
      %9 = fir.convert %8 : (i32) -> i64
      %c1_i64 = arith.constant 1 : i64
      %10 = arith.subi %9, %c1_i64 : i64
      %11 = fir.coordinate_of %rebox, %10 : (!fir.box<!fir.array<?xf64>>, i64) -> !fir.ref<f64>
      %12 = fir.load %11 : !fir.ref<f64>
      %13 = arith.addf %7, %12 fastmath<contract> : f64
      fir.store %13 to %1 : !fir.ref<f64>
      %14 = arith.addi %arg2, %c1 : index
      %15 = fir.convert %c1 : (index) -> i32
      %16 = fir.load %0 : !fir.ref<i32>
      %17 = arith.addi %16, %15 : i32
      fir.result %14, %17 : index, i32
    }
    fir.store %6#1 to %0 : !fir.ref<i32>
    return
  }
}

It fails with an assertion: Assertion changed && "Expected operations to have changed"' failed.`, but I am not sure what it means exactly.

If I understand it correctly, as soon as the loop is proven to be safe to multiversion, we will transform it and we will use %rebox for generating the contiguity check before the loop, while it is defined inside the loop.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I'll do some debugging. The assert basically means "we didn't find anything in the loop to update", which of course means versioning didn't achieve anything.

We should identify this case and just say "nope, not doing that".

// If this might be a case, we should record such operands in
// argsInLoop.cannotTransform, so that they disable the transformation
// for the parent loops as well.
a.arg = operand;
// Only add if it's not already in the list.
if (std::find_if(argsInLoop.begin(), argsInLoop.end(), [&](auto it) {
return it.arg == a.arg;
}) == argsInLoop.end()) {

argsInLoop.push_back(a);
// No support currently for sliced arrays.
// This means that we cannot transform properly
// instructions referencing a.arg in the whole loop
// nest this loop is located in.
if (auto arrayCoor = mlir::dyn_cast<fir::ArrayCoorOp>(op))
if (arrayCoor.getSlice())
argsInLoop.cannotTransform.insert(a.arg);

if (argsInLoop.cannotTransform.contains(a.arg)) {
// Remove any previously recorded usage, if any.
argsInLoop.usageInfo.erase(a.arg);
break;
}

// Record the a.arg usage, if not recorded yet.
argsInLoop.usageInfo.try_emplace(a.arg, a);
break;
}
}
});

if (!argsInLoop.empty()) {
OpsWithArgs ops = {loop, argsInLoop};
loopsOfInterest.push_back(ops);
}
});
if (loopsOfInterest.empty())

// Dump loops info after initial collection.
// clang-format off
LLVM_DEBUG(
llvm::dbgs() << "Initial usage info:\n";
for (fir::DoLoopOp loop : originalLoops) {
auto &argsInLoop = argsInLoops[loop];
argsInLoop.dump(loop);
}
);
// clang-format on

// Clear argument usage for parent loops if an inner loop
// contains a non-transformable usage.
for (fir::DoLoopOp loop : originalLoops) {
auto &argsInLoop = argsInLoops[loop];
if (argsInLoop.cannotTransform.empty())
continue;

fir::DoLoopOp parent = loop;
while ((parent = parent->getParentOfType<fir::DoLoopOp>()))
argsInLoops[parent].eraseUsage(argsInLoop.cannotTransform);
}

// If an argument access can be optimized in a loop and
// its descendant loop, then it does not make sense to
// generate the contiguity check for the descendant loop.
// The check will be produced as part of the ancestor
// loop's transformation. So we can clear the argument
// usage for all descendant loops.
for (fir::DoLoopOp loop : originalLoops) {
auto &argsInLoop = argsInLoops[loop];
if (argsInLoop.usageInfo.empty())
continue;

loop.getBody()->walk([&](fir::DoLoopOp dloop) {
argsInLoops[dloop].eraseUsage(argsInLoop.usageInfo);
});
}

// clang-format off
LLVM_DEBUG(
llvm::dbgs() << "Final usage info:\n";
for (fir::DoLoopOp loop : originalLoops) {
auto &argsInLoop = argsInLoops[loop];
argsInLoop.dump(loop);
}
);
// clang-format on

// Reduce the collected information to a list of loops
// with attached arguments usage information.
// The list must hold the loops in post order, so that
// the inner loops are transformed before the outer loops.
struct OpsWithArgs {
mlir::Operation *op;
mlir::SmallVector<ArgInfo, 4> argsAndDims;
};
mlir::SmallVector<OpsWithArgs, 4> loopsOfInterest;
for (fir::DoLoopOp loop : originalLoops) {
auto &argsInLoop = argsInLoops[loop];
if (argsInLoop.usageInfo.empty())
continue;
OpsWithArgs info;
info.op = loop;
for (auto &arg : argsInLoop.usageInfo)
info.argsAndDims.push_back(arg.second);
loopsOfInterest.emplace_back(std::move(info));
}

if (loopsOfInterest.empty()) {
LLVM_DEBUG(llvm::dbgs()
<< "No loops to transform.\n=== End " DEBUG_TYPE " ===\n");
return;
}

// If we get here, there are loops to process.
fir::FirOpBuilder builder{module, std::move(kindMap)};
Expand Down
Loading