Skip to content

Commit 78f656e

Browse files
committed
Merge branch 'scanMlirSupport' into scan-IRBuilder-Support
2 parents 6d59b7e + 4834618 commit 78f656e

File tree

3 files changed

+331
-90
lines changed

3 files changed

+331
-90
lines changed

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

+212-60
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@
4747

4848
using namespace mlir;
4949

50+
// Maps the LLVM value of each private reduction variable to the LLVM type
// converted from its reduction declaration. Entries are written by
// initReductionVars and collectReductionInfo and read by convertOmpScan to
// determine the types of the scan variables.
llvm::SmallDenseMap<llvm::Value *, llvm::Type *> ReductionVarToType;
51+
// Alloca insertion point recorded by convertOmpParallel and reused by
// convertOmpScan when it calls createScan.
llvm::OpenMPIRBuilder::InsertPointTy
52+
parallelAllocaIP; // TODO: change this alloca IP to point to originalvar
53+
// allocaIP. ReductionDecl needs to be linked to the scan var.
5054
namespace {
5155
static llvm::omp::ScheduleKind
5256
convertToScheduleKind(std::optional<omp::ClauseScheduleKind> schedKind) {
@@ -86,7 +90,9 @@ class OpenMPLoopInfoStackFrame
8690
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
8791
public:
8892
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OpenMPLoopInfoStackFrame)
89-
llvm::CanonicalLoopInfo *loopInfo = nullptr;
93+
// For constructs like scan, one Loop info frame can contain multiple
94+
// Canonical Loops
95+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
9096
};
9197

9298
/// Custom error class to signal translation errors that don't need reporting,
@@ -169,6 +175,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
169175
if (op.getDistScheduleChunkSize())
170176
result = todo("dist_schedule with chunk_size");
171177
};
178+
auto checkExclusive = [&todo](auto op, LogicalResult &result) {
179+
if (!op.getExclusiveVars().empty())
180+
result = todo("exclusive");
181+
};
172182
auto checkHint = [](auto op, LogicalResult &) {
173183
if (op.getHint())
174184
op.emitWarning("hint clause discarded");
@@ -232,8 +242,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
232242
op.getReductionSyms())
233243
result = todo("reduction");
234244
if (op.getReductionMod() &&
235-
op.getReductionMod().value() != omp::ReductionModifier::defaultmod)
236-
result = todo("reduction with modifier");
245+
op.getReductionMod().value() == omp::ReductionModifier::task)
246+
result = todo("reduction with task modifier");
237247
};
238248
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
239249
if (!op.getTaskReductionVars().empty() || op.getTaskReductionByref() ||
@@ -253,6 +263,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
253263
checkOrder(op, result);
254264
})
255265
.Case([&](omp::OrderedRegionOp op) { checkParLevelSimd(op, result); })
266+
.Case([&](omp::ScanOp op) { checkExclusive(op, result); })
256267
.Case([&](omp::SectionsOp op) {
257268
checkAllocate(op, result);
258269
checkPrivate(op, result);
@@ -382,15 +393,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
382393
/// Find the loop information structures for the loop nest being translated. It
383394
/// will return an empty vector unless called from the translation function for
384395
/// a loop wrapper operation after successfully translating its body.
385-
static llvm::CanonicalLoopInfo *
386-
findCurrentLoopInfo(LLVM::ModuleTranslation &moduleTranslation) {
387-
llvm::CanonicalLoopInfo *loopInfo = nullptr;
396+
static SmallVector<llvm::CanonicalLoopInfo *>
397+
findCurrentLoopInfos(LLVM::ModuleTranslation &moduleTranslation) {
398+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
388399
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
389400
[&](OpenMPLoopInfoStackFrame &frame) {
390-
loopInfo = frame.loopInfo;
401+
loopInfos = frame.loopInfos;
391402
return WalkResult::interrupt();
392403
});
393-
return loopInfo;
404+
return loopInfos;
394405
}
395406

396407
/// Converts the given region that appears within an OpenMP dialect operation to
@@ -1133,6 +1144,11 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
11331144
// variables. Although this could be done after allocas, we don't want to mess
11341145
// up with the alloca insertion point.
11351146
for (unsigned i = 0; i < op.getNumReductionVars(); ++i) {
1147+
1148+
llvm::Type *reductionType =
1149+
moduleTranslation.convertType(reductionDecls[i].getType());
1150+
ReductionVarToType[privateReductionVariables[i]] = reductionType;
1151+
11361152
SmallVector<llvm::Value *, 1> phis;
11371153

11381154
// map block argument to initializer region
@@ -1206,9 +1222,11 @@ static void collectReductionInfo(
12061222
atomicGen = owningAtomicReductionGens[i];
12071223
llvm::Value *variable =
12081224
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
1209-
reductionInfos.push_back(
1210-
{moduleTranslation.convertType(reductionDecls[i].getType()), variable,
1211-
privateReductionVariables[i],
1225+
llvm::Type *reductionType =
1226+
moduleTranslation.convertType(reductionDecls[i].getType());
1227+
ReductionVarToType[privateReductionVariables[i]] = reductionType;
1228+
reductionInfos.push_back(
1229+
{reductionType, variable, privateReductionVariables[i],
12121230
/*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
12131231
owningReductionGens[i],
12141232
/*ReductionGenClang=*/nullptr, atomicGen});
@@ -2342,27 +2360,60 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23422360
if (failed(handleError(regionBlock, opInst)))
23432361
return failure();
23442362

2345-
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2346-
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2347-
2348-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2349-
ompBuilder->applyWorkshareLoop(
2350-
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2351-
convertToScheduleKind(schedule), chunk, isSimd,
2352-
scheduleMod == omp::ScheduleModifier::monotonic,
2353-
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2354-
workshareLoopType);
2355-
2356-
if (failed(handleError(wsloopIP, opInst)))
2357-
return failure();
2358-
2359-
// Process the reductions if required.
2360-
if (failed(createReductionsAndCleanup(
2361-
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2362-
privateReductionVariables, isByRef, wsloopOp.getNowait(),
2363-
/*isTeamsReduction=*/false)))
2364-
return failure();
2363+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2364+
findCurrentLoopInfos(moduleTranslation);
2365+
auto inputLoopFinishIp = loopInfos.front()->getAfterIP();
2366+
bool isInScanRegion =
2367+
wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
2368+
mlir::omp::ReductionModifier::inscan);
2369+
if (isInScanRegion) {
2370+
builder.restoreIP(inputLoopFinishIp);
2371+
SmallVector<OwningReductionGen> owningReductionGens;
2372+
SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
2373+
SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
2374+
collectReductionInfo(wsloopOp, builder, moduleTranslation, reductionDecls,
2375+
owningReductionGens, owningAtomicReductionGens,
2376+
privateReductionVariables, reductionInfos);
2377+
llvm::BasicBlock *cont = splitBB(builder, false, "omp.scan.loop.cont");
2378+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
2379+
ompBuilder->emitScanReduction(builder.saveIP(), reductionInfos);
2380+
if (failed(handleError(redIP, opInst)))
2381+
return failure();
23652382

2383+
builder.restoreIP(*redIP);
2384+
builder.CreateBr(cont);
2385+
}
2386+
for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2387+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2388+
ompBuilder->applyWorkshareLoop(
2389+
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
2390+
convertToScheduleKind(schedule), chunk, isSimd,
2391+
scheduleMod == omp::ScheduleModifier::monotonic,
2392+
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2393+
workshareLoopType);
2394+
2395+
if (failed(handleError(wsloopIP, opInst)))
2396+
return failure();
2397+
}
2398+
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2399+
if (isInScanRegion) {
2400+
SmallVector<Region *> reductionRegions;
2401+
llvm::transform(reductionDecls, std::back_inserter(reductionRegions),
2402+
[](omp::DeclareReductionOp reductionDecl) {
2403+
return &reductionDecl.getCleanupRegion();
2404+
});
2405+
if (failed(inlineOmpRegionCleanup(
2406+
reductionRegions, privateReductionVariables, moduleTranslation,
2407+
builder, "omp.reduction.cleanup")))
2408+
return failure();
2409+
} else {
2410+
// Process the reductions if required.
2411+
if (failed(createReductionsAndCleanup(
2412+
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2413+
privateReductionVariables, isByRef, wsloopOp.getNowait(),
2414+
/*isTeamsReduction=*/false)))
2415+
return failure();
2416+
}
23662417
return cleanupPrivateVars(builder, moduleTranslation, wsloopOp.getLoc(),
23672418
privateVarsInfo.llvmVars,
23682419
privateVarsInfo.privatizers);
@@ -2528,6 +2579,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
25282579

25292580
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
25302581
findAllocaInsertPoint(builder, moduleTranslation);
2582+
parallelAllocaIP = allocaIP;
25312583
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
25322584

25332585
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
@@ -2553,6 +2605,64 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
25532605
llvm_unreachable("Unknown ClauseOrderKind kind");
25542606
}
25552607

2608+
/// Converts an OpenMP scan directive into LLVM IR using OpenMPIRBuilder.
///
/// The scan variables (the inclusive list if present, otherwise the exclusive
/// list) are looked up in the module translation and paired with the LLVM
/// types recorded in `ReductionVarToType`. They are then passed to
/// `OpenMPIRBuilder::createScan` together with the alloca insertion point
/// saved in `parallelAllocaIP`. On success the builder is positioned at the
/// insertion point returned by `createScan`.
///
/// Returns failure if `checkImplementationStatus` rejects the op, if a scan
/// variable has no recorded reduction type, if the op is not nested inside an
/// `omp.parallel`, or if `createScan` reports an error.
static LogicalResult
convertOmpScan(Operation &opInst, llvm::IRBuilderBase &builder,
               LLVM::ModuleTranslation &moduleTranslation) {
  if (failed(checkImplementationStatus(opInst)))
    return failure();

  auto scanOp = cast<omp::ScanOp>(opInst);
  bool isInclusive = scanOp.hasInclusiveVars();
  mlir::OperandRange mlirScanVars =
      isInclusive ? scanOp.getInclusiveVars() : scanOp.getExclusiveVars();

  SmallVector<llvm::Value *> llvmScanVars;
  SmallVector<llvm::Type *> llvmScanVarsType;
  for (mlir::Value val : mlirScanVars) {
    llvm::Value *llvmVal = moduleTranslation.lookupValue(val);
    // Use lookup() instead of operator[] so that an unknown key does not
    // silently insert a null entry into the global map.
    llvm::Type *llvmType = ReductionVarToType.lookup(llvmVal);
    if (!llvmType)
      return opInst.emitError(
          "scan variable is not associated with a known reduction type");
    llvmScanVars.push_back(llvmVal);
    llvmScanVarsType.push_back(llvmType);
  }

  // `parallelAllocaIP` is only assigned while translating an omp.parallel, so
  // an enclosing parallel region is required for it to be meaningful.
  if (!scanOp->getParentOfType<omp::ParallelOp>())
    return opInst.emitError(
        "omp.scan translation requires an enclosing omp.parallel operation");

  llvm::OpenMPIRBuilder::InsertPointTy allocaIP = parallelAllocaIP;
  llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
  llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
      moduleTranslation.getOpenMPBuilder()->createScan(
          ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive);
  if (failed(handleError(afterIP, opInst)))
    return failure();

  builder.restoreIP(*afterIP);

  // TODO: The LoopNestOp argument is stored into the index variable, and that
  // variable is then used on both sides of the scan operation. This makes the
  // MLIR invalid with respect to the OpenMP rule: `Intra-iteration dependences
  // from a statement in the structured block sequence that precede a scan
  // directive to a statement in the structured block sequence that follows a
  // scan directive must not exist, except for dependences for the list items
  // specified in an inclusive or exclusive clause.` The LoopNestOp argument
  // needs to be reloaded after the ScanOp so that the generated MLIR is valid.
  //
  // Until then: if the first operation of the enclosing region's entry block
  // is a store of the loop nest's first region argument, emit the same store
  // again after the scan.
  auto loopOp = dyn_cast<omp::LoopNestOp>(scanOp->getParentOp());
  if (!loopOp)
    return success();

  Block &firstBlock = scanOp->getParentRegion()->front();
  if (firstBlock.empty())
    return success();

  auto storeOp = dyn_cast<LLVM::StoreOp>(&firstBlock.front());
  if (!storeOp)
    return success();

  llvm::Value *src = moduleTranslation.lookupValue(storeOp->getOperand(0));
  llvm::Value *loopArg =
      moduleTranslation.lookupValue(loopOp.getRegion().getArguments()[0]);
  if (src == loopArg) {
    llvm::Value *dest = moduleTranslation.lookupValue(storeOp->getOperand(1));
    builder.CreateStore(src, dest);
  }
  return success();
}
2665+
25562666
/// Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
25572667
static LogicalResult
25582668
convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2626,13 +2736,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
26262736
return failure();
26272737

26282738
builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin());
2629-
llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation);
2630-
ompBuilder->applySimd(loopInfo, alignedVars,
2631-
simdOp.getIfExpr()
2632-
? moduleTranslation.lookupValue(simdOp.getIfExpr())
2633-
: nullptr,
2634-
order, simdlen, safelen);
2635-
2739+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2740+
findCurrentLoopInfos(moduleTranslation);
2741+
for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2742+
ompBuilder->applySimd(
2743+
loopInfo, alignedVars,
2744+
simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr())
2745+
: nullptr,
2746+
order, simdlen, safelen);
2747+
}
26362748
return cleanupPrivateVars(builder, moduleTranslation, simdOp.getLoc(),
26372749
privateVarsInfo.llvmVars,
26382750
privateVarsInfo.privatizers);
@@ -2698,16 +2810,51 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
26982810
ompLoc.DL);
26992811
computeIP = loopInfos.front()->getPreheaderIP();
27002812
}
2813+
if (auto wsloopOp = loopOp->getParentOfType<omp::WsloopOp>()) {
2814+
bool isInScanRegion =
2815+
wsloopOp.getReductionMod() && (wsloopOp.getReductionMod().value() ==
2816+
mlir::omp::ReductionModifier::inscan);
2817+
if (isInScanRegion) {
2818+
// TODO: Handle nesting when the scan loop is nested inside another loop.
2819+
assert(loopOp.getNumLoops() == 1);
2820+
llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
2821+
ompBuilder->createCanonicalScanLoops(
2822+
loc, bodyGen, lowerBound, upperBound, step,
2823+
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP,
2824+
"loop");
2825+
2826+
if (failed(handleError(loopResults, *loopOp)))
2827+
return failure();
2828+
auto inputLoop = loopResults->front();
2829+
auto scanLoop = loopResults->back();
2830+
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
2831+
[&](OpenMPLoopInfoStackFrame &frame) {
2832+
frame.loopInfos.push_back(inputLoop);
2833+
frame.loopInfos.push_back(scanLoop);
2834+
return WalkResult::interrupt();
2835+
});
2836+
builder.restoreIP(scanLoop->getAfterIP());
2837+
return success();
2838+
} else {
2839+
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2840+
ompBuilder->createCanonicalLoop(
2841+
loc, bodyGen, lowerBound, upperBound, step,
2842+
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
2843+
if (failed(handleError(loopResult, *loopOp)))
2844+
return failure();
2845+
loopInfos.push_back(*loopResult);
2846+
}
2847+
} else {
2848+
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2849+
ompBuilder->createCanonicalLoop(
2850+
loc, bodyGen, lowerBound, upperBound, step,
2851+
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
27012852

2702-
llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2703-
ompBuilder->createCanonicalLoop(
2704-
loc, bodyGen, lowerBound, upperBound, step,
2705-
/*IsSigned=*/true, loopOp.getLoopInclusive(), computeIP);
2706-
2707-
if (failed(handleError(loopResult, *loopOp)))
2708-
return failure();
2853+
if (failed(handleError(loopResult, *loopOp)))
2854+
return failure();
27092855

2710-
loopInfos.push_back(*loopResult);
2856+
loopInfos.push_back(*loopResult);
2857+
}
27112858
}
27122859

27132860
// Collapse loops. Store the insertion point because LoopInfos may get
@@ -2719,7 +2866,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
27192866
// after applying transformations.
27202867
moduleTranslation.stackWalk<OpenMPLoopInfoStackFrame>(
27212868
[&](OpenMPLoopInfoStackFrame &frame) {
2722-
frame.loopInfo = ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {});
2869+
frame.loopInfos.push_back(
2870+
ompBuilder->collapseLoops(ompLoc.DL, loopInfos, {}));
27232871
return WalkResult::interrupt();
27242872
});
27252873

@@ -4328,19 +4476,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
43284476
llvm::omp::WorksharingLoopType::DistributeStaticLoop;
43294477
bool loopNeedsBarrier = false;
43304478
llvm::Value *chunk = nullptr;
4331-
4332-
llvm::CanonicalLoopInfo *loopInfo =
4333-
findCurrentLoopInfo(moduleTranslation);
4334-
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4335-
ompBuilder->applyWorkshareLoop(
4336-
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4337-
convertToScheduleKind(schedule), chunk, isSimd,
4338-
scheduleMod == omp::ScheduleModifier::monotonic,
4339-
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4340-
workshareLoopType);
4341-
4342-
if (!wsloopIP)
4343-
return wsloopIP.takeError();
4479+
SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
4480+
findCurrentLoopInfos(moduleTranslation);
4481+
for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
4482+
llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4483+
ompBuilder->applyWorkshareLoop(
4484+
ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier,
4485+
convertToScheduleKind(schedule), chunk, isSimd,
4486+
scheduleMod == omp::ScheduleModifier::monotonic,
4487+
scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4488+
workshareLoopType);
4489+
4490+
if (!wsloopIP)
4491+
return wsloopIP.takeError();
4492+
}
43444493
}
43454494

43464495
if (failed(cleanupPrivateVars(builder, moduleTranslation,
@@ -5373,6 +5522,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
53735522
.Case([&](omp::SimdOp) {
53745523
return convertOmpSimd(*op, builder, moduleTranslation);
53755524
})
5525+
.Case([&](omp::ScanOp) {
5526+
return convertOmpScan(*op, builder, moduleTranslation);
5527+
})
53765528
.Case([&](omp::AtomicReadOp) {
53775529
return convertOmpAtomicRead(*op, builder, moduleTranslation);
53785530
})

0 commit comments

Comments
 (0)