Skip to content

Commit f8562e2

Browse files
authored
[flang][OpenMP][NFC] Further refactoring for genOpWithBody & (#80839)
`createBodyOfOp` This refactors the arguments to the above functions in 2 ways: - Combines the 2 structs of arguments into one since they were almost identical. - Replaces the `args` argument with a callback to a rebion-body generation function. This is a preparation for delayed privatization as we will need different callbacks for ws loops and parallel ops with delayed privatization.
1 parent 88c830a commit f8562e2

File tree

1 file changed

+141
-79
lines changed

1 file changed

+141
-79
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 141 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -2266,31 +2266,68 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
22662266
return storeOp;
22672267
}
22682268

2269-
struct CreateBodyOfOpInfo {
2269+
struct OpWithBodyGenInfo {
2270+
/// A type for a code-gen callback function. This takes as argument the op for
2271+
/// which the code is being generated and returns the arguments of the op's
2272+
/// region.
2273+
using GenOMPRegionEntryCBFn =
2274+
std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2275+
mlir::Operation *)>;
2276+
2277+
OpWithBodyGenInfo(Fortran::lower::AbstractConverter &converter,
2278+
mlir::Location loc, Fortran::lower::pft::Evaluation &eval)
2279+
: converter(converter), loc(loc), eval(eval) {}
2280+
2281+
OpWithBodyGenInfo &setGenNested(bool value) {
2282+
genNested = value;
2283+
return *this;
2284+
}
2285+
2286+
OpWithBodyGenInfo &setOuterCombined(bool value) {
2287+
outerCombined = value;
2288+
return *this;
2289+
}
2290+
2291+
OpWithBodyGenInfo &setClauses(const Fortran::parser::OmpClauseList *value) {
2292+
clauses = value;
2293+
return *this;
2294+
}
2295+
2296+
OpWithBodyGenInfo &setDataSharingProcessor(DataSharingProcessor *value) {
2297+
dsp = value;
2298+
return *this;
2299+
}
2300+
2301+
OpWithBodyGenInfo &setGenRegionEntryCb(GenOMPRegionEntryCBFn value) {
2302+
genRegionEntryCB = value;
2303+
return *this;
2304+
}
2305+
2306+
/// [inout] converter to use for the clauses.
22702307
Fortran::lower::AbstractConverter &converter;
2271-
mlir::Location &loc;
2308+
/// [in] location in source code.
2309+
mlir::Location loc;
2310+
/// [in] current PFT node/evaluation.
22722311
Fortran::lower::pft::Evaluation &eval;
2312+
/// [in] whether to generate FIR for nested evaluations
22732313
bool genNested = true;
2274-
const Fortran::parser::OmpClauseList *clauses = nullptr;
2275-
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {};
2314+
/// [in] is this an outer operation - prevents privatization.
22762315
bool outerCombined = false;
2316+
/// [in] list of clauses to process.
2317+
const Fortran::parser::OmpClauseList *clauses = nullptr;
2318+
/// [in] if provided, processes the construct's data-sharing attributes.
22772319
DataSharingProcessor *dsp = nullptr;
2320+
/// [in] if provided, emits the op's region entry. Otherwise, an emtpy block
2321+
/// is created in the region.
2322+
GenOMPRegionEntryCBFn genRegionEntryCB = nullptr;
22782323
};
22792324

22802325
/// Create the body (block) for an OpenMP Operation.
22812326
///
2282-
/// \param [in] op - the operation the body belongs to.
2283-
/// \param [inout] converter - converter to use for the clauses.
2284-
/// \param [in] loc - location in source code.
2285-
/// \param [in] eval - current PFT node/evaluation.
2286-
/// \param [in] genNested - whether to generate FIR for nested evaluations
2287-
/// \oaran [in] clauses - list of clauses to process.
2288-
/// \param [in] args - block arguments (induction variable[s]) for the
2289-
//// region.
2290-
/// \param [in] outerCombined - is this an outer operation - prevents
2291-
/// privatization.
2327+
/// \param [in] op - the operation the body belongs to.
2328+
/// \param [in] info - options controlling code-gen for the construction.
22922329
template <typename Op>
2293-
static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) {
2330+
static void createBodyOfOp(Op &op, OpWithBodyGenInfo &info) {
22942331
fir::FirOpBuilder &firOpBuilder = info.converter.getFirOpBuilder();
22952332

22962333
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2303,28 +2340,15 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) {
23032340
// argument. Also update the symbol's address with the mlir argument value.
23042341
// e.g. For loops the argument is the induction variable. And all further
23052342
// uses of the induction variable should use this mlir value.
2306-
if (info.args.size()) {
2307-
std::size_t loopVarTypeSize = 0;
2308-
for (const Fortran::semantics::Symbol *arg : info.args)
2309-
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
2310-
mlir::Type loopVarType = getLoopVarType(info.converter, loopVarTypeSize);
2311-
llvm::SmallVector<mlir::Type> tiv(info.args.size(), loopVarType);
2312-
llvm::SmallVector<mlir::Location> locs(info.args.size(), info.loc);
2313-
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2314-
// The argument is not currently in memory, so make a temporary for the
2315-
// argument, and store it there, then bind that location to the argument.
2316-
mlir::Operation *storeOp = nullptr;
2317-
for (auto [argIndex, argSymbol] : llvm::enumerate(info.args)) {
2318-
mlir::Value indexVal =
2319-
fir::getBase(op.getRegion().front().getArgument(argIndex));
2320-
storeOp = createAndSetPrivatizedLoopVar(info.converter, info.loc,
2321-
indexVal, argSymbol);
2343+
auto regionArgs =
2344+
[&]() -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2345+
if (info.genRegionEntryCB != nullptr) {
2346+
return info.genRegionEntryCB(op);
23222347
}
2323-
firOpBuilder.setInsertionPointAfter(storeOp);
2324-
} else {
2325-
firOpBuilder.createBlock(&op.getRegion());
2326-
}
23272348

2349+
firOpBuilder.createBlock(&op.getRegion());
2350+
return {};
2351+
}();
23282352
// Mark the earliest insertion point.
23292353
mlir::Operation *marker = insertMarker(firOpBuilder);
23302354

@@ -2421,8 +2445,8 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) {
24212445
assert(tempDsp.has_value());
24222446
tempDsp->processStep2(op, isLoop);
24232447
} else {
2424-
if (isLoop && info.args.size() > 0)
2425-
info.dsp->setLoopIV(info.converter.getSymbolAddress(*info.args[0]));
2448+
if (isLoop && regionArgs.size() > 0)
2449+
info.dsp->setLoopIV(info.converter.getSymbolAddress(*regionArgs[0]));
24262450
info.dsp->processStep2(op, isLoop);
24272451
}
24282452
}
@@ -2497,24 +2521,11 @@ static void genBodyOfTargetDataOp(
24972521
genNestedEvaluations(converter, eval);
24982522
}
24992523

2500-
struct GenOpWithBodyInfo {
2501-
Fortran::lower::AbstractConverter &converter;
2502-
Fortran::lower::pft::Evaluation &eval;
2503-
bool genNested = false;
2504-
mlir::Location currentLocation;
2505-
bool outerCombined = false;
2506-
const Fortran::parser::OmpClauseList *clauseList = nullptr;
2507-
};
2508-
25092524
template <typename OpTy, typename... Args>
2510-
static OpTy genOpWithBody(GenOpWithBodyInfo info, Args &&...args) {
2525+
static OpTy genOpWithBody(OpWithBodyGenInfo &info, Args &&...args) {
25112526
auto op = info.converter.getFirOpBuilder().create<OpTy>(
2512-
info.currentLocation, std::forward<Args>(args)...);
2513-
createBodyOfOp<OpTy>(
2514-
op, {info.converter, info.currentLocation, info.eval, info.genNested,
2515-
info.clauseList,
2516-
/*args*/ llvm::SmallVector<const Fortran::semantics::Symbol *>{},
2517-
info.outerCombined});
2527+
info.loc, std::forward<Args>(args)...);
2528+
createBodyOfOp<OpTy>(op, info);
25182529
return op;
25192530
}
25202531

@@ -2523,7 +2534,8 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
25232534
Fortran::lower::pft::Evaluation &eval, bool genNested,
25242535
mlir::Location currentLocation) {
25252536
return genOpWithBody<mlir::omp::MasterOp>(
2526-
{converter, eval, genNested, currentLocation},
2537+
OpWithBodyGenInfo(converter, currentLocation, eval)
2538+
.setGenNested(genNested),
25272539
/*resultTypes=*/mlir::TypeRange());
25282540
}
25292541

@@ -2532,7 +2544,8 @@ genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
25322544
Fortran::lower::pft::Evaluation &eval, bool genNested,
25332545
mlir::Location currentLocation) {
25342546
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2535-
{converter, eval, genNested, currentLocation},
2547+
OpWithBodyGenInfo(converter, currentLocation, eval)
2548+
.setGenNested(genNested),
25362549
/*simd=*/false);
25372550
}
25382551

@@ -2560,7 +2573,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25602573
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
25612574

25622575
return genOpWithBody<mlir::omp::ParallelOp>(
2563-
{converter, eval, genNested, currentLocation, outerCombined, &clauseList},
2576+
OpWithBodyGenInfo(converter, currentLocation, eval)
2577+
.setGenNested(genNested)
2578+
.setOuterCombined(outerCombined)
2579+
.setClauses(&clauseList),
25642580
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25652581
numThreadsClauseOperand, allocateOperands, allocatorOperands,
25662582
reductionVars,
@@ -2579,8 +2595,9 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
25792595
// Currently only private/firstprivate clause is handled, and
25802596
// all privatization is done within `omp.section` operations.
25812597
return genOpWithBody<mlir::omp::SectionOp>(
2582-
{converter, eval, genNested, currentLocation,
2583-
/*outerCombined=*/false, &sectionsClauseList});
2598+
OpWithBodyGenInfo(converter, currentLocation, eval)
2599+
.setGenNested(genNested)
2600+
.setClauses(&sectionsClauseList));
25842601
}
25852602

25862603
static mlir::omp::SingleOp
@@ -2600,8 +2617,9 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
26002617
ClauseProcessor(converter, endClauseList).processNowait(nowaitAttr);
26012618

26022619
return genOpWithBody<mlir::omp::SingleOp>(
2603-
{converter, eval, genNested, currentLocation,
2604-
/*outerCombined=*/false, &beginClauseList},
2620+
OpWithBodyGenInfo(converter, currentLocation, eval)
2621+
.setGenNested(genNested)
2622+
.setClauses(&beginClauseList),
26052623
allocateOperands, allocatorOperands, nowaitAttr);
26062624
}
26072625

@@ -2633,8 +2651,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
26332651
currentLocation, llvm::omp::Directive::OMPD_task);
26342652

26352653
return genOpWithBody<mlir::omp::TaskOp>(
2636-
{converter, eval, genNested, currentLocation,
2637-
/*outerCombined=*/false, &clauseList},
2654+
OpWithBodyGenInfo(converter, currentLocation, eval)
2655+
.setGenNested(genNested)
2656+
.setClauses(&clauseList),
26382657
ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
26392658
/*in_reduction_vars=*/mlir::ValueRange(),
26402659
/*in_reductions=*/nullptr, priorityClauseOperand,
@@ -2656,8 +2675,9 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
26562675
cp.processTODO<Fortran::parser::OmpClause::TaskReduction>(
26572676
currentLocation, llvm::omp::Directive::OMPD_taskgroup);
26582677
return genOpWithBody<mlir::omp::TaskGroupOp>(
2659-
{converter, eval, genNested, currentLocation,
2660-
/*outerCombined=*/false, &clauseList},
2678+
OpWithBodyGenInfo(converter, currentLocation, eval)
2679+
.setGenNested(genNested)
2680+
.setClauses(&clauseList),
26612681
/*task_reduction_vars=*/mlir::ValueRange(),
26622682
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
26632683
}
@@ -3040,7 +3060,10 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
30403060
currentLocation, llvm::omp::Directive::OMPD_teams);
30413061

30423062
return genOpWithBody<mlir::omp::TeamsOp>(
3043-
{converter, eval, genNested, currentLocation, outerCombined, &clauseList},
3063+
OpWithBodyGenInfo(converter, currentLocation, eval)
3064+
.setGenNested(genNested)
3065+
.setOuterCombined(outerCombined)
3066+
.setClauses(&clauseList),
30443067
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
30453068
threadLimitClauseOperand, allocateOperands, allocatorOperands,
30463069
reductionVars,
@@ -3237,6 +3260,33 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
32373260
}
32383261
}
32393262

3263+
static llvm::SmallVector<const Fortran::semantics::Symbol *>
3264+
genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3265+
mlir::Location &loc,
3266+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args) {
3267+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
3268+
auto &region = op->getRegion(0);
3269+
3270+
std::size_t loopVarTypeSize = 0;
3271+
for (const Fortran::semantics::Symbol *arg : args)
3272+
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
3273+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
3274+
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
3275+
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
3276+
firOpBuilder.createBlock(&region, {}, tiv, locs);
3277+
// The argument is not currently in memory, so make a temporary for the
3278+
// argument, and store it there, then bind that location to the argument.
3279+
mlir::Operation *storeOp = nullptr;
3280+
for (auto [argIndex, argSymbol] : llvm::enumerate(args)) {
3281+
mlir::Value indexVal = fir::getBase(region.front().getArgument(argIndex));
3282+
storeOp =
3283+
createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
3284+
}
3285+
firOpBuilder.setInsertionPointAfter(storeOp);
3286+
3287+
return args;
3288+
}
3289+
32403290
static void
32413291
createSimdLoop(Fortran::lower::AbstractConverter &converter,
32423292
Fortran::lower::pft::Evaluation &eval,
@@ -3284,10 +3334,16 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
32843334

32853335
auto *nestedEval = getCollapsedLoopEval(
32863336
eval, Fortran::lower::getCollapseValue(loopOpClauseList));
3337+
3338+
auto ivCallback = [&](mlir::Operation *op) {
3339+
return genLoopVars(op, converter, loc, iv);
3340+
};
3341+
32873342
createBodyOfOp<mlir::omp::SimdLoopOp>(
3288-
simdLoopOp, {converter, loc, *nestedEval,
3289-
/*genNested=*/true, &loopOpClauseList, iv,
3290-
/*outerCombined=*/false, &dsp});
3343+
simdLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval)
3344+
.setClauses(&loopOpClauseList)
3345+
.setDataSharingProcessor(&dsp)
3346+
.setGenRegionEntryCb(ivCallback));
32913347
}
32923348

32933349
static void createWsLoop(Fortran::lower::AbstractConverter &converter,
@@ -3360,10 +3416,16 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
33603416

33613417
auto *nestedEval = getCollapsedLoopEval(
33623418
eval, Fortran::lower::getCollapseValue(beginClauseList));
3363-
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp,
3364-
{converter, loc, *nestedEval,
3365-
/*genNested=*/true, &beginClauseList, iv,
3366-
/*outerCombined=*/false, &dsp});
3419+
3420+
auto ivCallback = [&](mlir::Operation *op) {
3421+
return genLoopVars(op, converter, loc, iv);
3422+
};
3423+
3424+
createBodyOfOp<mlir::omp::WsLoopOp>(
3425+
wsLoopOp, OpWithBodyGenInfo(converter, loc, *nestedEval)
3426+
.setClauses(&beginClauseList)
3427+
.setDataSharingProcessor(&dsp)
3428+
.setGenRegionEntryCb(ivCallback));
33673429
}
33683430

33693431
static void createSimdWsLoop(
@@ -3644,8 +3706,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36443706
currentLocation, mlir::FlatSymbolRefAttr::get(firOpBuilder.getContext(),
36453707
global.getSymName()));
36463708
}();
3647-
createBodyOfOp<mlir::omp::CriticalOp>(criticalOp,
3648-
{converter, currentLocation, eval});
3709+
auto genInfo = OpWithBodyGenInfo(converter, currentLocation, eval);
3710+
createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, genInfo);
36493711
}
36503712

36513713
static void
@@ -3687,11 +3749,11 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36873749
}
36883750

36893751
// SECTIONS construct
3690-
genOpWithBody<mlir::omp::SectionsOp>({converter, eval,
3691-
/*genNested=*/false, currentLocation},
3692-
/*reduction_vars=*/mlir::ValueRange(),
3693-
/*reductions=*/nullptr, allocateOperands,
3694-
allocatorOperands, nowaitClauseOperand);
3752+
genOpWithBody<mlir::omp::SectionsOp>(
3753+
OpWithBodyGenInfo(converter, currentLocation, eval).setGenNested(false),
3754+
/*reduction_vars=*/mlir::ValueRange(),
3755+
/*reductions=*/nullptr, allocateOperands, allocatorOperands,
3756+
nowaitClauseOperand);
36953757

36963758
const auto &sectionBlocks =
36973759
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t);

0 commit comments

Comments
 (0)