Skip to content

Commit 2e37f28

Browse files
authored
[MLIR][OpenMP] Update omp.wsloop translation to LLVM IR (4/5) (#89214)
This patch introduces minimal changes to the MLIR-to-LLVM IR translation of `omp.wsloop` to support the loop wrapper approach. There is `omp.loop_nest`-related translation code that should be extracted and shared among all loop operations (e.g. `omp.simd`). Doing so would possibly also help with adding support for compound constructs later on. This first approach is only intended to keep things running after the transition to loop wrappers, not to add support for other use cases enabled by that transition. This PR on its own will not pass premerge tests. All patches in the stack are needed before it can be compiled and pass tests.
1 parent 8843d54 commit 2e37f28

9 files changed

+572
-466
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -916,49 +916,50 @@ static LogicalResult inlineReductionCleanup(
916916
static LogicalResult
917917
convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
918918
LLVM::ModuleTranslation &moduleTranslation) {
919-
auto loop = cast<omp::WsloopOp>(opInst);
920-
const bool isByRef = loop.getByref();
919+
auto wsloopOp = cast<omp::WsloopOp>(opInst);
920+
auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());
921+
const bool isByRef = wsloopOp.getByref();
922+
921923
// TODO: this should be in the op verifier instead.
922-
if (loop.getLowerBound().empty())
924+
if (loopOp.getLowerBound().empty())
923925
return failure();
924926

925927
// Static is the default.
926928
auto schedule =
927-
loop.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
929+
wsloopOp.getScheduleVal().value_or(omp::ClauseScheduleKind::Static);
928930

929931
// Find the loop configuration.
930-
llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[0]);
932+
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[0]);
931933
llvm::Type *ivType = step->getType();
932934
llvm::Value *chunk = nullptr;
933-
if (loop.getScheduleChunkVar()) {
935+
if (wsloopOp.getScheduleChunkVar()) {
934936
llvm::Value *chunkVar =
935-
moduleTranslation.lookupValue(loop.getScheduleChunkVar());
937+
moduleTranslation.lookupValue(wsloopOp.getScheduleChunkVar());
936938
chunk = builder.CreateSExtOrTrunc(chunkVar, ivType);
937939
}
938940

939941
SmallVector<omp::DeclareReductionOp> reductionDecls;
940-
collectReductionDecls(loop, reductionDecls);
942+
collectReductionDecls(wsloopOp, reductionDecls);
941943
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
942944
findAllocaInsertPoint(builder, moduleTranslation);
943945

944946
SmallVector<llvm::Value *> privateReductionVariables;
945947
DenseMap<Value, llvm::Value *> reductionVariableMap;
946948
if (!isByRef) {
947-
allocByValReductionVars(loop, builder, moduleTranslation, allocaIP,
949+
allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP,
948950
reductionDecls, privateReductionVariables,
949951
reductionVariableMap);
950952
}
951953

952954
// Before the loop, store the initial values of reductions into reduction
953955
// variables. Although this could be done after allocas, we don't want to mess
954956
// up with the alloca insertion point.
955-
MutableArrayRef<BlockArgument> reductionArgs =
956-
loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
957-
for (unsigned i = 0; i < loop.getNumReductionVars(); ++i) {
957+
ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments();
958+
for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) {
958959
SmallVector<llvm::Value *> phis;
959960

960961
// map block argument to initializer region
961-
mapInitializationArg(loop, moduleTranslation, reductionDecls, i);
962+
mapInitializationArg(wsloopOp, moduleTranslation, reductionDecls, i);
962963

963964
if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
964965
"omp.reduction.neutral", builder,
@@ -977,7 +978,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
977978

978979
privateReductionVariables.push_back(var);
979980
moduleTranslation.mapValue(reductionArgs[i], phis[0]);
980-
reductionVariableMap.try_emplace(loop.getReductionVars()[i], phis[0]);
981+
reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]);
981982
} else {
982983
// for by-ref case the store is inside of the reduction region
983984
builder.CreateStore(phis[0], privateReductionVariables[i]);
@@ -1008,33 +1009,34 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10081009
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
10091010
// Make sure further conversions know about the induction variable.
10101011
moduleTranslation.mapValue(
1011-
loop.getRegion().front().getArgument(loopInfos.size()), iv);
1012+
loopOp.getRegion().front().getArgument(loopInfos.size()), iv);
10121013

10131014
// Capture the body insertion point for use in nested loops. BodyIP of the
10141015
// CanonicalLoopInfo always points to the beginning of the entry block of
10151016
// the body.
10161017
bodyInsertPoints.push_back(ip);
10171018

1018-
if (loopInfos.size() != loop.getNumLoops() - 1)
1019+
if (loopInfos.size() != loopOp.getNumLoops() - 1)
10191020
return;
10201021

10211022
// Convert the body of the loop.
10221023
builder.restoreIP(ip);
1023-
convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder,
1024+
convertOmpOpRegions(loopOp.getRegion(), "omp.wsloop.region", builder,
10241025
moduleTranslation, bodyGenStatus);
10251026
};
10261027

10271028
// Delegate actual loop construction to the OpenMP IRBuilder.
1028-
// TODO: this currently assumes Wsloop is semantically similar to SCF loop,
1029-
// i.e. it has a positive step, uses signed integer semantics. Reconsider
1030-
// this code when Wsloop clearly supports more cases.
1029+
// TODO: this currently assumes omp.loop_nest is semantically similar to SCF
1030+
// loop, i.e. it has a positive step, uses signed integer semantics.
1031+
// Reconsider this code when the nested loop operation clearly supports more
1032+
// cases.
10311033
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
1032-
for (unsigned i = 0, e = loop.getNumLoops(); i < e; ++i) {
1034+
for (unsigned i = 0, e = loopOp.getNumLoops(); i < e; ++i) {
10331035
llvm::Value *lowerBound =
1034-
moduleTranslation.lookupValue(loop.getLowerBound()[i]);
1036+
moduleTranslation.lookupValue(loopOp.getLowerBound()[i]);
10351037
llvm::Value *upperBound =
1036-
moduleTranslation.lookupValue(loop.getUpperBound()[i]);
1037-
llvm::Value *step = moduleTranslation.lookupValue(loop.getStep()[i]);
1038+
moduleTranslation.lookupValue(loopOp.getUpperBound()[i]);
1039+
llvm::Value *step = moduleTranslation.lookupValue(loopOp.getStep()[i]);
10381040

10391041
// Make sure loop trip count are emitted in the preheader of the outermost
10401042
// loop at the latest so that they are all available for the new collapsed
@@ -1047,7 +1049,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10471049
}
10481050
loopInfos.push_back(ompBuilder->createCanonicalLoop(
10491051
loc, bodyGen, lowerBound, upperBound, step,
1050-
/*IsSigned=*/true, loop.getInclusive(), computeIP));
1052+
/*IsSigned=*/true, loopOp.getInclusive(), computeIP));
10511053

10521054
if (failed(bodyGenStatus))
10531055
return failure();
@@ -1062,13 +1064,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10621064
allocaIP = findAllocaInsertPoint(builder, moduleTranslation);
10631065

10641066
// TODO: Handle doacross loops when the ordered clause has a parameter.
1065-
bool isOrdered = loop.getOrderedVal().has_value();
1067+
bool isOrdered = wsloopOp.getOrderedVal().has_value();
10661068
std::optional<omp::ScheduleModifier> scheduleModifier =
1067-
loop.getScheduleModifier();
1068-
bool isSimd = loop.getSimdModifier();
1069+
wsloopOp.getScheduleModifier();
1070+
bool isSimd = wsloopOp.getSimdModifier();
10691071

10701072
ompBuilder->applyWorkshareLoop(
1071-
ompLoc.DL, loopInfo, allocaIP, !loop.getNowait(),
1073+
ompLoc.DL, loopInfo, allocaIP, !wsloopOp.getNowait(),
10721074
convertToScheduleKind(schedule), chunk, isSimd,
10731075
scheduleModifier == omp::ScheduleModifier::monotonic,
10741076
scheduleModifier == omp::ScheduleModifier::nonmonotonic, isOrdered);
@@ -1080,15 +1082,15 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10801082
builder.restoreIP(afterIP);
10811083

10821084
// Process the reductions if required.
1083-
if (loop.getNumReductionVars() == 0)
1085+
if (wsloopOp.getNumReductionVars() == 0)
10841086
return success();
10851087

10861088
// Create the reduction generators. We need to own them here because
10871089
// ReductionInfo only accepts references to the generators.
10881090
SmallVector<OwningReductionGen> owningReductionGens;
10891091
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
10901092
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
1091-
collectReductionInfo(loop, builder, moduleTranslation, reductionDecls,
1093+
collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
10921094
owningReductionGens, owningAtomicReductionGens,
10931095
privateReductionVariables, reductionInfos);
10941096

@@ -1099,9 +1101,9 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
10991101
builder.SetInsertPoint(tempTerminator);
11001102
llvm::OpenMPIRBuilder::InsertPointTy contInsertPoint =
11011103
ompBuilder->createReductions(builder.saveIP(), allocaIP, reductionInfos,
1102-
loop.getNowait(), isByRef);
1104+
wsloopOp.getNowait(), isByRef);
11031105
if (!contInsertPoint.getBlock())
1104-
return loop->emitOpError() << "failed to convert reductions";
1106+
return wsloopOp->emitOpError() << "failed to convert reductions";
11051107
auto nextInsertionPoint =
11061108
ompBuilder->createBarrier(contInsertPoint, llvm::omp::OMPD_for);
11071109
tempTerminator->eraseFromParent();

mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
1212
%loop_ub = llvm.mlir.constant(9 : i32) : i32
1313
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1414
%loop_step = llvm.mlir.constant(1 : i32) : i32
15-
omp.wsloop for (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
16-
%gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
17-
llvm.store %loop_cnt, %gep : i32, !llvm.ptr
18-
omp.yield
15+
omp.wsloop {
16+
omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
17+
%gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
18+
llvm.store %loop_cnt, %gep : i32, !llvm.ptr
19+
omp.yield
20+
}
21+
omp.terminator
1922
}
2023
omp.terminator
2124
}

mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,16 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
88
%loop_ub = llvm.mlir.constant(99 : i32) : i32
99
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1010
%loop_step = llvm.mlir.constant(1 : index) : i32
11-
omp.wsloop for (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
12-
%1 = llvm.add %arg1, %arg2 : i32
13-
%2 = llvm.mul %arg2, %loop_ub overflow<nsw> : i32
14-
%3 = llvm.add %arg1, %2 :i32
15-
%4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
16-
llvm.store %1, %4 : i32, !llvm.ptr
17-
omp.yield
11+
omp.wsloop {
12+
omp.loop_nest (%arg1, %arg2) : i32 = (%loop_lb, %loop_lb) to (%loop_ub, %loop_ub) inclusive step (%loop_step, %loop_step) {
13+
%1 = llvm.add %arg1, %arg2 : i32
14+
%2 = llvm.mul %arg2, %loop_ub overflow<nsw> : i32
15+
%3 = llvm.add %arg1, %2 :i32
16+
%4 = llvm.getelementptr %arg0[%3] : (!llvm.ptr, i32) -> !llvm.ptr, i32
17+
llvm.store %1, %4 : i32, !llvm.ptr
18+
omp.yield
19+
}
20+
omp.terminator
1821
}
1922
llvm.return
2023
}

mlir/test/Target/LLVMIR/omptarget-wsloop.mlir

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
88
%loop_ub = llvm.mlir.constant(9 : i32) : i32
99
%loop_lb = llvm.mlir.constant(0 : i32) : i32
1010
%loop_step = llvm.mlir.constant(1 : i32) : i32
11-
omp.wsloop for (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
12-
%gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
13-
llvm.store %loop_cnt, %gep : i32, !llvm.ptr
14-
omp.yield
11+
omp.wsloop {
12+
omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
13+
%gep = llvm.getelementptr %arg0[0, %loop_cnt] : (!llvm.ptr, i32) -> !llvm.ptr, !llvm.array<10 x i32>
14+
llvm.store %loop_cnt, %gep : i32, !llvm.ptr
15+
omp.yield
16+
}
17+
omp.terminator
1518
}
1619
llvm.return
1720
}
@@ -20,8 +23,11 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
2023
%loop_ub = llvm.mlir.constant(9 : i32) : i32
2124
%loop_lb = llvm.mlir.constant(0 : i32) : i32
2225
%loop_step = llvm.mlir.constant(1 : i32) : i32
23-
omp.wsloop for (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
24-
omp.yield
26+
omp.wsloop {
27+
omp.loop_nest (%loop_cnt) : i32 = (%loop_lb) to (%loop_ub) inclusive step (%loop_step) {
28+
omp.yield
29+
}
30+
omp.terminator
2531
}
2632
llvm.return
2733
}

mlir/test/Target/LLVMIR/openmp-data-target-device.mlir

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,23 @@ module attributes { } {
3131
%18 = llvm.mlir.constant(1 : i64) : i64
3232
%19 = llvm.alloca %18 x i32 {pinned} : (i64) -> !llvm.ptr<5>
3333
%20 = llvm.addrspacecast %19 : !llvm.ptr<5> to !llvm.ptr
34-
omp.wsloop for (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
35-
llvm.store %arg2, %20 : i32, !llvm.ptr
36-
%21 = llvm.load %20 : !llvm.ptr -> i32
37-
%22 = llvm.sext %21 : i32 to i64
38-
%23 = llvm.mlir.constant(1 : i64) : i64
39-
%24 = llvm.mlir.constant(0 : i64) : i64
40-
%25 = llvm.sub %22, %23 overflow<nsw> : i64
41-
%26 = llvm.mul %25, %23 overflow<nsw> : i64
42-
%27 = llvm.mul %26, %23 overflow<nsw> : i64
43-
%28 = llvm.add %27, %24 overflow<nsw> : i64
44-
%29 = llvm.mul %23, %17 overflow<nsw> : i64
45-
%30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
46-
llvm.store %21, %30 : i32, !llvm.ptr
47-
omp.yield
34+
omp.wsloop {
35+
omp.loop_nest (%arg2) : i32 = (%16) to (%15) inclusive step (%16) {
36+
llvm.store %arg2, %20 : i32, !llvm.ptr
37+
%21 = llvm.load %20 : !llvm.ptr -> i32
38+
%22 = llvm.sext %21 : i32 to i64
39+
%23 = llvm.mlir.constant(1 : i64) : i64
40+
%24 = llvm.mlir.constant(0 : i64) : i64
41+
%25 = llvm.sub %22, %23 overflow<nsw> : i64
42+
%26 = llvm.mul %25, %23 overflow<nsw> : i64
43+
%27 = llvm.mul %26, %23 overflow<nsw> : i64
44+
%28 = llvm.add %27, %24 overflow<nsw> : i64
45+
%29 = llvm.mul %23, %17 overflow<nsw> : i64
46+
%30 = llvm.getelementptr %arg0[%28] : (!llvm.ptr, i64) -> !llvm.ptr, i32
47+
llvm.store %21, %30 : i32, !llvm.ptr
48+
omp.yield
49+
}
50+
omp.terminator
4851
}
4952
omp.terminator
5053
}

0 commit comments

Comments
 (0)