Skip to content

Commit 8e5d75f

Browse files
committed
Add region args.
1 parent 299fb04 commit 8e5d75f

File tree

3 files changed

+164
-70
lines changed

3 files changed

+164
-70
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 132 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,14 @@ static void genNestedEvaluations(Fortran::lower::AbstractConverter &converter,
147147
//===----------------------------------------------------------------------===//
148148

149149
class DataSharingProcessor {
150+
public:
151+
struct DelayedPrivatizationInfo {
152+
llvm::SetVector<mlir::SymbolRefAttr> privatizers;
153+
llvm::SetVector<mlir::Value> hostAddresses;
154+
llvm::SetVector<const Fortran::semantics::Symbol *> hostSymbols;
155+
};
156+
157+
private:
150158
bool hasLastPrivateOp;
151159
mlir::OpBuilder::InsertPoint lastPrivIP;
152160
mlir::OpBuilder::InsertPoint insPt;
@@ -163,8 +171,8 @@ class DataSharingProcessor {
163171

164172
bool useDelayedPrivatizationWhenPossible;
165173
Fortran::lower::SymMap *symTable;
166-
llvm::SetVector<mlir::SymbolRefAttr> privatizers;
167-
llvm::SetVector<mlir::Value> privateSymHostAddrsses;
174+
175+
DelayedPrivatizationInfo delayedPrivatizationInfo;
168176

169177
bool needBarrier();
170178
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
@@ -214,12 +222,8 @@ class DataSharingProcessor {
214222
loopIV = iv;
215223
}
216224

217-
const llvm::SetVector<mlir::SymbolRefAttr> &getPrivatizers() const {
218-
return privatizers;
219-
};
220-
221-
const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
222-
return privateSymHostAddrsses;
225+
const DelayedPrivatizationInfo &getDelayedPrivatizationInfo() const {
226+
return delayedPrivatizationInfo;
223227
}
224228
};
225229

@@ -547,8 +551,10 @@ void DataSharingProcessor::privatize() {
547551
symTable->popScope();
548552
firOpBuilder.restoreInsertionPoint(ip);
549553

550-
privatizers.insert(mlir::SymbolRefAttr::get(privatizerOp));
551-
privateSymHostAddrsses.insert(hsb.getAddr());
554+
delayedPrivatizationInfo.privatizers.insert(
555+
mlir::SymbolRefAttr::get(privatizerOp));
556+
delayedPrivatizationInfo.hostAddresses.insert(hsb.getAddr());
557+
delayedPrivatizationInfo.hostSymbols.insert(sym);
552558
} else {
553559
cloneSymbol(sym);
554560
copyFirstPrivateSymbol(sym);
@@ -2305,7 +2311,9 @@ static void createBodyOfOp(
23052311
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
23062312
Fortran::lower::pft::Evaluation &eval, bool genNested,
23072313
const Fortran::parser::OmpClauseList *clauses = nullptr,
2308-
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2314+
std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2315+
mlir::Operation *)>
2316+
genRegionEntryCB = nullptr,
23092317
bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
23102318
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
23112319

@@ -2319,27 +2327,15 @@ static void createBodyOfOp(
23192327
// argument. Also update the symbol's address with the mlir argument value.
23202328
// e.g. For loops the argument is the induction variable. And all further
23212329
// uses of the induction variable should use this mlir value.
2322-
if (args.size()) {
2323-
std::size_t loopVarTypeSize = 0;
2324-
for (const Fortran::semantics::Symbol *arg : args)
2325-
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
2326-
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2327-
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
2328-
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
2329-
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2330-
// The argument is not currently in memory, so make a temporary for the
2331-
// argument, and store it there, then bind that location to the argument.
2332-
mlir::Operation *storeOp = nullptr;
2333-
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
2334-
mlir::Value indexVal =
2335-
fir::getBase(op.getRegion().front().getArgument(argIndex));
2336-
storeOp =
2337-
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
2330+
auto regionArgs =
2331+
[&]() -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2332+
if (genRegionEntryCB != nullptr) {
2333+
return genRegionEntryCB(op);
23382334
}
2339-
firOpBuilder.setInsertionPointAfter(storeOp);
2340-
} else {
2335+
23412336
firOpBuilder.createBlock(&op.getRegion());
2342-
}
2337+
return {};
2338+
}();
23432339

23442340
// Mark the earliest insertion point.
23452341
mlir::Operation *marker = insertMarker(firOpBuilder);
@@ -2437,8 +2433,8 @@ static void createBodyOfOp(
24372433
assert(tempDsp.has_value());
24382434
tempDsp->processStep2(op, isLoop);
24392435
} else {
2440-
if (isLoop && args.size() > 0)
2441-
dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
2436+
if (isLoop && regionArgs.size() > 0)
2437+
dsp->setLoopIV(converter.getSymbolAddress(*regionArgs[0]));
24422438
dsp->processStep2(op, isLoop);
24432439
}
24442440
}
@@ -2514,41 +2510,44 @@ static void genBodyOfTargetDataOp(
25142510
}
25152511

25162512
template <typename OpTy, typename... Args>
2517-
static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
2518-
Fortran::lower::pft::Evaluation &eval, bool genNested,
2519-
mlir::Location currentLocation, bool outerCombined,
2520-
const Fortran::parser::OmpClauseList *clauseList,
2521-
DataSharingProcessor *dsp, Args &&...args) {
2513+
static OpTy genOpWithBody(
2514+
Fortran::lower::AbstractConverter &converter,
2515+
Fortran::lower::pft::Evaluation &eval, bool genNested,
2516+
mlir::Location currentLocation, bool outerCombined,
2517+
const Fortran::parser::OmpClauseList *clauseList,
2518+
std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2519+
mlir::Operation *)>
2520+
genRegionEntryCB,
2521+
DataSharingProcessor *dsp, Args &&...args) {
25222522
auto op = converter.getFirOpBuilder().create<OpTy>(
25232523
currentLocation, std::forward<Args>(args)...);
25242524
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
2525-
clauseList,
2526-
/*args=*/{}, outerCombined, dsp);
2525+
clauseList, genRegionEntryCB, outerCombined, dsp);
25272526
return op;
25282527
}
25292528

25302529
static mlir::omp::MasterOp
25312530
genMasterOp(Fortran::lower::AbstractConverter &converter,
25322531
Fortran::lower::pft::Evaluation &eval, bool genNested,
25332532
mlir::Location currentLocation) {
2534-
return genOpWithBody<mlir::omp::MasterOp>(converter, eval, genNested,
2535-
currentLocation,
2536-
/*outerCombined=*/false,
2537-
/*clauseList=*/nullptr,
2538-
/*dsp=*/nullptr,
2539-
/*resultTypes=*/mlir::TypeRange());
2533+
return genOpWithBody<mlir::omp::MasterOp>(
2534+
converter, eval, genNested, currentLocation,
2535+
/*outerCombined=*/false,
2536+
/*clauseList=*/nullptr, /*genRegionEntryCB=*/nullptr,
2537+
/*dsp=*/nullptr,
2538+
/*resultTypes=*/mlir::TypeRange());
25402539
}
25412540

25422541
static mlir::omp::OrderedRegionOp
25432542
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
25442543
Fortran::lower::pft::Evaluation &eval, bool genNested,
25452544
mlir::Location currentLocation) {
2546-
return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
2547-
currentLocation,
2548-
/*outerCombined=*/false,
2549-
/*clauseList=*/nullptr,
2550-
/*dsp=*/nullptr,
2551-
/*simd=*/false);
2545+
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2546+
converter, eval, genNested, currentLocation,
2547+
/*outerCombined=*/false,
2548+
/*clauseList=*/nullptr, /*genRegionEntryCB=*/nullptr,
2549+
/*dsp=*/nullptr,
2550+
/*simd=*/false);
25522551
}
25532552

25542553
static mlir::omp::ParallelOp
@@ -2584,16 +2583,44 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25842583
dsp.processStep1();
25852584
}
25862585

2587-
llvm::SmallVector<mlir::Attribute> privatizers(dsp.getPrivatizers().begin(),
2588-
dsp.getPrivatizers().end());
2586+
const auto &delayedPrivatizationInfo = dsp.getDelayedPrivatizationInfo();
2587+
llvm::SmallVector<mlir::Attribute> privatizers(
2588+
delayedPrivatizationInfo.privatizers.begin(),
2589+
delayedPrivatizationInfo.privatizers.end());
25892590

25902591
llvm::SmallVector<mlir::Value> privateSymAddresses(
2591-
dsp.getPrivateSymHostAddrsses().begin(),
2592-
dsp.getPrivateSymHostAddrsses().end());
2592+
delayedPrivatizationInfo.hostAddresses.begin(),
2593+
delayedPrivatizationInfo.hostAddresses.end());
2594+
2595+
auto genRegionEntryCB = [&](mlir::Operation *op) {
2596+
auto parallelOp = llvm::cast<mlir::omp::ParallelOp>(op);
2597+
auto privateVars = parallelOp.getPrivateVars();
2598+
auto &region = parallelOp.getRegion();
2599+
llvm::SmallVector<mlir::Type> privateVarTypes;
2600+
llvm::SmallVector<mlir::Location> privateVarLocs;
2601+
2602+
for (auto privateVar : privateVars) {
2603+
privateVarTypes.push_back(privateVar.getType());
2604+
privateVarLocs.push_back(privateVar.getLoc());
2605+
}
2606+
2607+
converter.getFirOpBuilder().createBlock(&region, {}, privateVarTypes,
2608+
privateVarLocs);
2609+
2610+
int argIdx = 0;
2611+
for (const auto *sym : delayedPrivatizationInfo.hostSymbols) {
2612+
converter.bindSymbol(*sym, region.getArgument(argIdx));
2613+
++argIdx;
2614+
}
2615+
2616+
return llvm::SmallVector<const Fortran::semantics::Symbol *>(
2617+
delayedPrivatizationInfo.hostSymbols.begin(),
2618+
delayedPrivatizationInfo.hostSymbols.end());
2619+
};
25932620

25942621
return genOpWithBody<mlir::omp::ParallelOp>(
25952622
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2596-
&dsp,
2623+
genRegionEntryCB, &dsp,
25972624
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25982625
numThreadsClauseOperand, allocateOperands, allocatorOperands,
25992626
reductionVars,
@@ -2618,6 +2645,7 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
26182645
return genOpWithBody<mlir::omp::SectionOp>(
26192646
converter, eval, genNested, currentLocation,
26202647
/*outerCombined=*/false, &sectionsClauseList,
2648+
/*genRegionEntryCB=*/nullptr,
26212649
/*dsp=*/nullptr);
26222650
}
26232651

@@ -2639,8 +2667,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
26392667

26402668
return genOpWithBody<mlir::omp::SingleOp>(
26412669
converter, eval, genNested, currentLocation,
2642-
/*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
2643-
allocateOperands, allocatorOperands, nowaitAttr);
2670+
/*outerCombined=*/false, &beginClauseList, /*genRegionEntryCB=*/nullptr,
2671+
/*dsp=*/nullptr, allocateOperands, allocatorOperands, nowaitAttr);
26442672
}
26452673

26462674
static mlir::omp::TaskOp
@@ -2672,8 +2700,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
26722700

26732701
return genOpWithBody<mlir::omp::TaskOp>(
26742702
converter, eval, genNested, currentLocation,
2675-
/*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
2676-
finalClauseOperand, untiedAttr, mergeableAttr,
2703+
/*outerCombined=*/false, &clauseList, /*genRegionEntryCB=*/nullptr,
2704+
/*dsp=*/nullptr, ifClauseOperand, finalClauseOperand, untiedAttr,
2705+
mergeableAttr,
26772706
/*in_reduction_vars=*/mlir::ValueRange(),
26782707
/*in_reductions=*/nullptr, priorityClauseOperand,
26792708
dependTypeOperands.empty()
@@ -2695,7 +2724,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
26952724
currentLocation, llvm::omp::Directive::OMPD_taskgroup);
26962725
return genOpWithBody<mlir::omp::TaskGroupOp>(
26972726
converter, eval, genNested, currentLocation,
2698-
/*outerCombined=*/false, &clauseList,
2727+
/*outerCombined=*/false, &clauseList, /*genRegionEntryCB=*/nullptr,
26992728
/*dsp=*/nullptr,
27002729
/*task_reduction_vars=*/mlir::ValueRange(),
27012730
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
@@ -3076,6 +3105,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
30763105

30773106
return genOpWithBody<mlir::omp::TeamsOp>(
30783107
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
3108+
/*genRegionEntryCB=*/nullptr,
30793109
/*dsp=*/nullptr,
30803110
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
30813111
threadLimitClauseOperand, allocateOperands, allocatorOperands,
@@ -3273,6 +3303,33 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
32733303
}
32743304
}
32753305

3306+
static llvm::SmallVector<const Fortran::semantics::Symbol *> genCodeForIterVar(
3307+
mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3308+
mlir::Location &loc,
3309+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args) {
3310+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3311+
auto &region = op->getRegion(0);
3312+
3313+
std::size_t loopVarTypeSize = 0;
3314+
for (const Fortran::semantics::Symbol *arg : args)
3315+
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
3316+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
3317+
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
3318+
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
3319+
firOpBuilder.createBlock(&region, {}, tiv, locs);
3320+
// The argument is not currently in memory, so make a temporary for the
3321+
// argument, and store it there, then bind that location to the argument.
3322+
mlir::Operation *storeOp = nullptr;
3323+
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
3324+
mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex));
3325+
storeOp =
3326+
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
3327+
}
3328+
firOpBuilder.setInsertionPointAfter(storeOp);
3329+
3330+
return args;
3331+
}
3332+
32763333
static void
32773334
createSimdLoop(Fortran::lower::AbstractConverter &converter,
32783335
Fortran::lower::pft::Evaluation &eval,
@@ -3320,9 +3377,14 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
33203377

33213378
auto *nestedEval = getCollapsedLoopEval(
33223379
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
3380+
3381+
auto ivCallback = [&](mlir::Operation *op) {
3382+
return genCodeForIterVar(op, converter, loc, iv);
3383+
};
3384+
33233385
createBodyOfOp<mlir::omp::SimdLoopOp>(simdLoopOp, converter, loc, *nestedEval,
33243386
/*genNested=*/true, &loopOpClauseList,
3325-
iv, /*outer=*/false, &dsp);
3387+
ivCallback, /*outer=*/false, &dsp);
33263388
}
33273389

33283390
static void createWsLoop(Fortran::lower::AbstractConverter &converter,
@@ -3395,8 +3457,14 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
33953457

33963458
auto *nestedEval = getCollapsedLoopEval(
33973459
eval, Fortran::lower::getCollapseValue(beginClauseList));
3460+
3461+
auto ivCallback = [&](mlir::Operation *op) {
3462+
return genCodeForIterVar(op, converter, loc, iv);
3463+
};
3464+
33983465
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, loc, *nestedEval,
3399-
/*genNested=*/true, &beginClauseList, iv,
3466+
/*genNested=*/true, &beginClauseList,
3467+
ivCallback,
34003468
/*outer=*/false, &dsp);
34013469
}
34023470

@@ -3725,6 +3793,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
37253793
/*genNested=*/false, currentLocation,
37263794
/*outerCombined=*/false,
37273795
/*clauseList=*/nullptr,
3796+
/*genRegionEntryCB=*/nullptr,
37283797
/*dsp=*/nullptr,
37293798
/*reduction_vars=*/mlir::ValueRange(),
37303799
/*reductions=*/nullptr, allocateOperands,

flang/test/Lower/OpenMP/FIR/delayed_privatization.f90

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ subroutine delayed_privatization()
2929
! %c222_i32 = arith.constant 222 : i32
3030
! fir.store %c222_i32 to %1 : !fir.ref<i32>
3131
! omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
32-
! %2 = fir.load %0 : !fir.ref<i32>
33-
! %3 = fir.load %1 : !fir.ref<i32>
32+
! ^bb0(%arg0: !fir.ref<i32>, %arg1: !fir.ref<i32>):
33+
! %2 = fir.load %arg0 : !fir.ref<i32>
34+
! %3 = fir.load %arg1 : !fir.ref<i32>
3435
! %4 = arith.addi %2, %3 : i32
3536
! %c2_i32 = arith.constant 2 : i32
3637
! %5 = arith.addi %4, %c2_i32 : i32
37-
! fir.store %5 to %0 : !fir.ref<i32>
38+
! fir.store %5 to %arg0 : !fir.ref<i32>
3839
! omp.terminator
3940
! }
4041
! return
@@ -53,7 +54,6 @@ subroutine delayed_privatization()
5354
! fir.store %1 to %0 : !fir.ref<i32>
5455
! omp.yield(%0 : !fir.ref<i32>)
5556
! }) : () -> ()
56-
!}
5757
!
5858
! -----------------------------
5959
! ### Conversion to LLVM + OMP:
@@ -69,12 +69,13 @@ subroutine delayed_privatization()
6969
! %5 = llvm.mlir.constant(222 : i32) : i32
7070
! llvm.store %5, %3 : i32, !llvm.ptr
7171
! omp.parallel private(@var1.privatizer %1, @var2.privatizer %3 : !llvm.ptr, !llvm.ptr) {
72-
! %6 = llvm.load %1 : !llvm.ptr -> i32
73-
! %7 = llvm.load %3 : !llvm.ptr -> i32
72+
! ^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
73+
! %6 = llvm.load %arg0 : !llvm.ptr -> i32
74+
! %7 = llvm.load %arg1 : !llvm.ptr -> i32
7475
! %8 = llvm.add %6, %7 : i32
7576
! %9 = llvm.mlir.constant(2 : i32) : i32
7677
! %10 = llvm.add %8, %9 : i32
77-
! llvm.store %10, %1 : i32, !llvm.ptr
78+
! llvm.store %10, %arg0 : i32, !llvm.ptr
7879
! omp.terminator
7980
! }
8081
! llvm.return

0 commit comments

Comments
 (0)