47
47
48
48
using namespace mlir ;
49
49
50
+ llvm::SmallDenseMap<llvm::Value *, llvm::Type *> ReductionVarToType;
51
+ llvm::OpenMPIRBuilder::InsertPointTy
52
+ parallelAllocaIP; // TODO: change this alloca IP to point to originalvar
53
+ // allocaIP. ReductionDecl needs to be linked to the scan var.
50
54
namespace {
51
55
static llvm::omp::ScheduleKind
52
56
convertToScheduleKind (std::optional<omp::ClauseScheduleKind> schedKind) {
@@ -86,7 +90,9 @@ class OpenMPLoopInfoStackFrame
86
90
: public LLVM::ModuleTranslation::StackFrameBase<OpenMPLoopInfoStackFrame> {
87
91
public:
88
92
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (OpenMPLoopInfoStackFrame)
89
- llvm::CanonicalLoopInfo *loopInfo = nullptr ;
93
+ // For constructs like scan, one Loop info frame can contain multiple
94
+ // Canonical Loops
95
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
90
96
};
91
97
92
98
// / Custom error class to signal translation errors that don't need reporting,
@@ -169,6 +175,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
169
175
if (op.getDistScheduleChunkSize ())
170
176
result = todo (" dist_schedule with chunk_size" );
171
177
};
178
+ auto checkExclusive = [&todo](auto op, LogicalResult &result) {
179
+ if (!op.getExclusiveVars ().empty ())
180
+ result = todo (" exclusive" );
181
+ };
172
182
auto checkHint = [](auto op, LogicalResult &) {
173
183
if (op.getHint ())
174
184
op.emitWarning (" hint clause discarded" );
@@ -232,8 +242,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
232
242
op.getReductionSyms ())
233
243
result = todo (" reduction" );
234
244
if (op.getReductionMod () &&
235
- op.getReductionMod ().value () != omp::ReductionModifier::defaultmod )
236
- result = todo (" reduction with modifier" );
245
+ op.getReductionMod ().value () == omp::ReductionModifier::task )
246
+ result = todo (" reduction with task modifier" );
237
247
};
238
248
auto checkTaskReduction = [&todo](auto op, LogicalResult &result) {
239
249
if (!op.getTaskReductionVars ().empty () || op.getTaskReductionByref () ||
@@ -253,6 +263,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
253
263
checkOrder (op, result);
254
264
})
255
265
.Case ([&](omp::OrderedRegionOp op) { checkParLevelSimd (op, result); })
266
+ .Case ([&](omp::ScanOp op) { checkExclusive (op, result); })
256
267
.Case ([&](omp::SectionsOp op) {
257
268
checkAllocate (op, result);
258
269
checkPrivate (op, result);
@@ -382,15 +393,15 @@ findAllocaInsertPoint(llvm::IRBuilderBase &builder,
382
393
/// Find the loop information structures for the loop nest being translated.
383
394
/// Returns an empty vector unless called from the translation function for
384
395
/// a loop wrapper operation after successfully translating its body.
385
- static llvm::CanonicalLoopInfo *
386
- findCurrentLoopInfo (LLVM::ModuleTranslation &moduleTranslation) {
387
- llvm::CanonicalLoopInfo *loopInfo = nullptr ;
396
+ static SmallVector< llvm::CanonicalLoopInfo *>
397
+ findCurrentLoopInfos (LLVM::ModuleTranslation &moduleTranslation) {
398
+ SmallVector< llvm::CanonicalLoopInfo *> loopInfos ;
388
399
moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
389
400
[&](OpenMPLoopInfoStackFrame &frame) {
390
- loopInfo = frame.loopInfo ;
401
+ loopInfos = frame.loopInfos ;
391
402
return WalkResult::interrupt ();
392
403
});
393
- return loopInfo ;
404
+ return loopInfos ;
394
405
}
395
406
396
407
// / Converts the given region that appears within an OpenMP dialect operation to
@@ -1133,6 +1144,11 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
1133
1144
// variables. Although this could be done after allocas, we don't want to mess
1134
1145
// up with the alloca insertion point.
1135
1146
for (unsigned i = 0 ; i < op.getNumReductionVars (); ++i) {
1147
+
1148
+ llvm::Type *reductionType =
1149
+ moduleTranslation.convertType (reductionDecls[i].getType ());
1150
+ ReductionVarToType[privateReductionVariables[i]] = reductionType;
1151
+
1136
1152
SmallVector<llvm::Value *, 1 > phis;
1137
1153
1138
1154
// map block argument to initializer region
@@ -1206,9 +1222,11 @@ static void collectReductionInfo(
1206
1222
atomicGen = owningAtomicReductionGens[i];
1207
1223
llvm::Value *variable =
1208
1224
moduleTranslation.lookupValue (loop.getReductionVars ()[i]);
1209
- reductionInfos.push_back (
1210
- {moduleTranslation.convertType (reductionDecls[i].getType ()), variable,
1211
- privateReductionVariables[i],
1225
+ llvm::Type *reductionType =
1226
+ moduleTranslation.convertType (reductionDecls[i].getType ());
1227
+ ReductionVarToType[privateReductionVariables[i]] = reductionType;
1228
+ reductionInfos.push_back (
1229
+ {reductionType, variable, privateReductionVariables[i],
1212
1230
/* EvaluationKind=*/ llvm::OpenMPIRBuilder::EvalKind::Scalar,
1213
1231
owningReductionGens[i],
1214
1232
/* ReductionGenClang=*/ nullptr , atomicGen});
@@ -2342,27 +2360,60 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2342
2360
if (failed (handleError (regionBlock, opInst)))
2343
2361
return failure ();
2344
2362
2345
- builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2346
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2347
-
2348
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2349
- ompBuilder->applyWorkshareLoop (
2350
- ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
2351
- convertToScheduleKind (schedule), chunk, isSimd,
2352
- scheduleMod == omp::ScheduleModifier::monotonic,
2353
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2354
- workshareLoopType);
2355
-
2356
- if (failed (handleError (wsloopIP, opInst)))
2357
- return failure ();
2358
-
2359
- // Process the reductions if required.
2360
- if (failed (createReductionsAndCleanup (
2361
- wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2362
- privateReductionVariables, isByRef, wsloopOp.getNowait (),
2363
- /* isTeamsReduction=*/ false )))
2364
- return failure ();
2363
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2364
+ findCurrentLoopInfos (moduleTranslation);
2365
+ auto inputLoopFinishIp = loopInfos.front ()->getAfterIP ();
2366
+ bool isInScanRegion =
2367
+ wsloopOp.getReductionMod () && (wsloopOp.getReductionMod ().value () ==
2368
+ mlir::omp::ReductionModifier::inscan);
2369
+ if (isInScanRegion) {
2370
+ builder.restoreIP (inputLoopFinishIp);
2371
+ SmallVector<OwningReductionGen> owningReductionGens;
2372
+ SmallVector<OwningAtomicReductionGen> owningAtomicReductionGens;
2373
+ SmallVector<llvm::OpenMPIRBuilder::ReductionInfo> reductionInfos;
2374
+ collectReductionInfo (wsloopOp, builder, moduleTranslation, reductionDecls,
2375
+ owningReductionGens, owningAtomicReductionGens,
2376
+ privateReductionVariables, reductionInfos);
2377
+ llvm::BasicBlock *cont = splitBB (builder, false , " omp.scan.loop.cont" );
2378
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy redIP =
2379
+ ompBuilder->emitScanReduction (builder.saveIP (), reductionInfos);
2380
+ if (failed (handleError (redIP, opInst)))
2381
+ return failure ();
2365
2382
2383
+ builder.restoreIP (*redIP);
2384
+ builder.CreateBr (cont);
2385
+ }
2386
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2387
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
2388
+ ompBuilder->applyWorkshareLoop (
2389
+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
2390
+ convertToScheduleKind (schedule), chunk, isSimd,
2391
+ scheduleMod == omp::ScheduleModifier::monotonic,
2392
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
2393
+ workshareLoopType);
2394
+
2395
+ if (failed (handleError (wsloopIP, opInst)))
2396
+ return failure ();
2397
+ }
2398
+ builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2399
+ if (isInScanRegion) {
2400
+ SmallVector<Region *> reductionRegions;
2401
+ llvm::transform (reductionDecls, std::back_inserter (reductionRegions),
2402
+ [](omp::DeclareReductionOp reductionDecl) {
2403
+ return &reductionDecl.getCleanupRegion ();
2404
+ });
2405
+ if (failed (inlineOmpRegionCleanup (
2406
+ reductionRegions, privateReductionVariables, moduleTranslation,
2407
+ builder, " omp.reduction.cleanup" )))
2408
+ return failure ();
2409
+ } else {
2410
+ // Process the reductions if required.
2411
+ if (failed (createReductionsAndCleanup (
2412
+ wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
2413
+ privateReductionVariables, isByRef, wsloopOp.getNowait (),
2414
+ /* isTeamsReduction=*/ false )))
2415
+ return failure ();
2416
+ }
2366
2417
return cleanupPrivateVars (builder, moduleTranslation, wsloopOp.getLoc (),
2367
2418
privateVarsInfo.llvmVars ,
2368
2419
privateVarsInfo.privatizers );
@@ -2528,6 +2579,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
2528
2579
2529
2580
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
2530
2581
findAllocaInsertPoint (builder, moduleTranslation);
2582
+ parallelAllocaIP = allocaIP;
2531
2583
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2532
2584
2533
2585
llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
@@ -2553,6 +2605,64 @@ convertOrderKind(std::optional<omp::ClauseOrderKind> o) {
2553
2605
llvm_unreachable (" Unknown ClauseOrderKind kind" );
2554
2606
}
2555
2607
2608
+ static LogicalResult
2609
+ convertOmpScan (Operation &opInst, llvm::IRBuilderBase &builder,
2610
+ LLVM::ModuleTranslation &moduleTranslation) {
2611
+ if (failed (checkImplementationStatus (opInst)))
2612
+ return failure ();
2613
+ auto scanOp = cast<omp::ScanOp>(opInst);
2614
+ bool isInclusive = scanOp.hasInclusiveVars ();
2615
+ SmallVector<llvm::Value *> llvmScanVars;
2616
+ SmallVector<llvm::Type *> llvmScanVarsType;
2617
+ mlir::OperandRange mlirScanVars = scanOp.getInclusiveVars ();
2618
+ if (!isInclusive)
2619
+ mlirScanVars = scanOp.getExclusiveVars ();
2620
+ for (auto val : mlirScanVars) {
2621
+ llvm::Value *llvmVal = moduleTranslation.lookupValue (val);
2622
+ llvmScanVars.push_back (llvmVal);
2623
+ llvmScanVarsType.push_back (ReductionVarToType[llvmVal]);
2624
+ val.getDefiningOp ();
2625
+ }
2626
+ auto parallelOp = scanOp->getParentOfType <omp::ParallelOp>();
2627
+ if (!parallelOp) {
2628
+ return failure ();
2629
+ }
2630
+ llvm::OpenMPIRBuilder::InsertPointTy allocaIP = parallelAllocaIP;
2631
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2632
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
2633
+ moduleTranslation.getOpenMPBuilder ()->createScan (
2634
+ ompLoc, allocaIP, llvmScanVars, llvmScanVarsType, isInclusive);
2635
+ if (failed (handleError (afterIP, opInst)))
2636
+ return failure ();
2637
+
2638
+ builder.restoreIP (*afterIP);
2639
+
2640
+ // TODO: The argument of LoopnestOp is stored into the index variable and this
2641
+ // variable is used across scan operation. However that makes the mlir
2642
+ // invalid.(`Intra-iteration dependences from a statement in the structured
2643
+ // block sequence that precede a scan directive to a statement in the
2644
+ // structured block sequence that follows a scan directive must not exist,
2645
+ // except for dependences for the list items specified in an inclusive or
2646
+ // exclusive clause.`). The argument of LoopNestOp need to be loaded again
2647
+ // after ScanOp again so mlir generated is valid.
2648
+ auto parentOp = scanOp->getParentOp ();
2649
+ auto loopOp = cast<omp::LoopNestOp>(parentOp);
2650
+ if (loopOp) {
2651
+ auto &firstBlock = *(scanOp->getParentRegion ()->getBlocks ()).begin ();
2652
+ auto &ins = *(firstBlock.begin ());
2653
+ if (isa<LLVM::StoreOp>(ins)) {
2654
+ LLVM::StoreOp storeOp = dyn_cast<LLVM::StoreOp>(ins);
2655
+ auto src = moduleTranslation.lookupValue (storeOp->getOperand (0 ));
2656
+ if (src == moduleTranslation.lookupValue (
2657
+ (loopOp.getRegion ().getArguments ())[0 ])) {
2658
+ auto dest = moduleTranslation.lookupValue (storeOp->getOperand (1 ));
2659
+ builder.CreateStore (src, dest);
2660
+ }
2661
+ }
2662
+ }
2663
+ return success ();
2664
+ }
2665
+
2556
2666
// / Converts an OpenMP simd loop into LLVM IR using OpenMPIRBuilder.
2557
2667
static LogicalResult
2558
2668
convertOmpSimd (Operation &opInst, llvm::IRBuilderBase &builder,
@@ -2626,13 +2736,15 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
2626
2736
return failure ();
2627
2737
2628
2738
builder.SetInsertPoint (*regionBlock, (*regionBlock)->begin ());
2629
- llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo (moduleTranslation);
2630
- ompBuilder->applySimd (loopInfo, alignedVars,
2631
- simdOp.getIfExpr ()
2632
- ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2633
- : nullptr ,
2634
- order, simdlen, safelen);
2635
-
2739
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
2740
+ findCurrentLoopInfos (moduleTranslation);
2741
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
2742
+ ompBuilder->applySimd (
2743
+ loopInfo, alignedVars,
2744
+ simdOp.getIfExpr () ? moduleTranslation.lookupValue (simdOp.getIfExpr ())
2745
+ : nullptr ,
2746
+ order, simdlen, safelen);
2747
+ }
2636
2748
return cleanupPrivateVars (builder, moduleTranslation, simdOp.getLoc (),
2637
2749
privateVarsInfo.llvmVars ,
2638
2750
privateVarsInfo.privatizers );
@@ -2698,16 +2810,51 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
2698
2810
ompLoc.DL );
2699
2811
computeIP = loopInfos.front ()->getPreheaderIP ();
2700
2812
}
2813
+ if (auto wsloopOp = loopOp->getParentOfType <omp::WsloopOp>()) {
2814
+ bool isInScanRegion =
2815
+ wsloopOp.getReductionMod () && (wsloopOp.getReductionMod ().value () ==
2816
+ mlir::omp::ReductionModifier::inscan);
2817
+ if (isInScanRegion) {
2818
+ // TODO: Handle nesting if Scan loop is nested in a loop
2819
+ assert (loopOp.getNumLoops () == 1 );
2820
+ llvm::Expected<SmallVector<llvm::CanonicalLoopInfo *>> loopResults =
2821
+ ompBuilder->createCanonicalScanLoops (
2822
+ loc, bodyGen, lowerBound, upperBound, step,
2823
+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP,
2824
+ " loop" );
2825
+
2826
+ if (failed (handleError (loopResults, *loopOp)))
2827
+ return failure ();
2828
+ auto inputLoop = loopResults->front ();
2829
+ auto scanLoop = loopResults->back ();
2830
+ moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
2831
+ [&](OpenMPLoopInfoStackFrame &frame) {
2832
+ frame.loopInfos .push_back (inputLoop);
2833
+ frame.loopInfos .push_back (scanLoop);
2834
+ return WalkResult::interrupt ();
2835
+ });
2836
+ builder.restoreIP (scanLoop->getAfterIP ());
2837
+ return success ();
2838
+ } else {
2839
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2840
+ ompBuilder->createCanonicalLoop (
2841
+ loc, bodyGen, lowerBound, upperBound, step,
2842
+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
2843
+ if (failed (handleError (loopResult, *loopOp)))
2844
+ return failure ();
2845
+ loopInfos.push_back (*loopResult);
2846
+ }
2847
+ } else {
2848
+ llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2849
+ ompBuilder->createCanonicalLoop (
2850
+ loc, bodyGen, lowerBound, upperBound, step,
2851
+ /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
2701
2852
2702
- llvm::Expected<llvm::CanonicalLoopInfo *> loopResult =
2703
- ompBuilder->createCanonicalLoop (
2704
- loc, bodyGen, lowerBound, upperBound, step,
2705
- /* IsSigned=*/ true , loopOp.getLoopInclusive (), computeIP);
2706
-
2707
- if (failed (handleError (loopResult, *loopOp)))
2708
- return failure ();
2853
+ if (failed (handleError (loopResult, *loopOp)))
2854
+ return failure ();
2709
2855
2710
- loopInfos.push_back (*loopResult);
2856
+ loopInfos.push_back (*loopResult);
2857
+ }
2711
2858
}
2712
2859
2713
2860
// Collapse loops. Store the insertion point because LoopInfos may get
@@ -2719,7 +2866,8 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
2719
2866
// after applying transformations.
2720
2867
moduleTranslation.stackWalk <OpenMPLoopInfoStackFrame>(
2721
2868
[&](OpenMPLoopInfoStackFrame &frame) {
2722
- frame.loopInfo = ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {});
2869
+ frame.loopInfos .push_back (
2870
+ ompBuilder->collapseLoops (ompLoc.DL , loopInfos, {}));
2723
2871
return WalkResult::interrupt ();
2724
2872
});
2725
2873
@@ -4328,19 +4476,20 @@ convertOmpDistribute(Operation &opInst, llvm::IRBuilderBase &builder,
4328
4476
llvm::omp::WorksharingLoopType::DistributeStaticLoop;
4329
4477
bool loopNeedsBarrier = false ;
4330
4478
llvm::Value *chunk = nullptr ;
4331
-
4332
- llvm::CanonicalLoopInfo *loopInfo =
4333
- findCurrentLoopInfo (moduleTranslation);
4334
- llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4335
- ompBuilder->applyWorkshareLoop (
4336
- ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
4337
- convertToScheduleKind (schedule), chunk, isSimd,
4338
- scheduleMod == omp::ScheduleModifier::monotonic,
4339
- scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4340
- workshareLoopType);
4341
-
4342
- if (!wsloopIP)
4343
- return wsloopIP.takeError ();
4479
+ SmallVector<llvm::CanonicalLoopInfo *> loopInfos =
4480
+ findCurrentLoopInfos (moduleTranslation);
4481
+ for (llvm::CanonicalLoopInfo *loopInfo : loopInfos) {
4482
+ llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP =
4483
+ ompBuilder->applyWorkshareLoop (
4484
+ ompLoc.DL , loopInfo, allocaIP, loopNeedsBarrier,
4485
+ convertToScheduleKind (schedule), chunk, isSimd,
4486
+ scheduleMod == omp::ScheduleModifier::monotonic,
4487
+ scheduleMod == omp::ScheduleModifier::nonmonotonic, isOrdered,
4488
+ workshareLoopType);
4489
+
4490
+ if (!wsloopIP)
4491
+ return wsloopIP.takeError ();
4492
+ }
4344
4493
}
4345
4494
4346
4495
if (failed (cleanupPrivateVars (builder, moduleTranslation,
@@ -5373,6 +5522,9 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
5373
5522
.Case ([&](omp::SimdOp) {
5374
5523
return convertOmpSimd (*op, builder, moduleTranslation);
5375
5524
})
5525
+ .Case ([&](omp::ScanOp) {
5526
+ return convertOmpScan (*op, builder, moduleTranslation);
5527
+ })
5376
5528
.Case ([&](omp::AtomicReadOp) {
5377
5529
return convertOmpAtomicRead (*op, builder, moduleTranslation);
5378
5530
})
0 commit comments