@@ -2266,31 +2266,68 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
2266
2266
return storeOp;
2267
2267
}
2268
2268
2269
- struct CreateBodyOfOpInfo {
2269
+ struct OpWithBodyGenInfo {
2270
+ // / A type for a code-gen callback function. This takes as argument the op for
2271
+ // / which the code is being generated and returns the arguments of the op's
2272
+ // / region.
2273
+ using GenOMPRegionEntryCBFn =
2274
+ std::function<llvm::SmallVector<const Fortran::semantics::Symbol *>(
2275
+ mlir::Operation *)>;
2276
+
2277
+ OpWithBodyGenInfo (Fortran::lower::AbstractConverter &converter,
2278
+ mlir::Location loc, Fortran::lower::pft::Evaluation &eval)
2279
+ : converter(converter), loc(loc), eval(eval) {}
2280
+
2281
+ OpWithBodyGenInfo &setGenNested (bool value) {
2282
+ genNested = value;
2283
+ return *this ;
2284
+ }
2285
+
2286
+ OpWithBodyGenInfo &setOuterCombined (bool value) {
2287
+ outerCombined = value;
2288
+ return *this ;
2289
+ }
2290
+
2291
+ OpWithBodyGenInfo &setClauses (const Fortran::parser::OmpClauseList *value) {
2292
+ clauses = value;
2293
+ return *this ;
2294
+ }
2295
+
2296
+ OpWithBodyGenInfo &setDataSharingProcessor (DataSharingProcessor *value) {
2297
+ dsp = value;
2298
+ return *this ;
2299
+ }
2300
+
2301
+ OpWithBodyGenInfo &setGenRegionEntryCb (GenOMPRegionEntryCBFn value) {
2302
+ genRegionEntryCB = value;
2303
+ return *this ;
2304
+ }
2305
+
2306
+ // / [inout] converter to use for the clauses.
2270
2307
Fortran::lower::AbstractConverter &converter;
2271
- mlir::Location &loc;
2308
+ // / [in] location in source code.
2309
+ mlir::Location loc;
2310
+ // / [in] current PFT node/evaluation.
2272
2311
Fortran::lower::pft::Evaluation &eval;
2312
+ // / [in] whether to generate FIR for nested evaluations
2273
2313
bool genNested = true ;
2274
- const Fortran::parser::OmpClauseList *clauses = nullptr ;
2275
- const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {};
2314
+ // / [in] is this an outer operation - prevents privatization.
2276
2315
bool outerCombined = false ;
2316
+ // / [in] list of clauses to process.
2317
+ const Fortran::parser::OmpClauseList *clauses = nullptr ;
2318
+ // / [in] if provided, processes the construct's data-sharing attributes.
2277
2319
DataSharingProcessor *dsp = nullptr ;
2320
+ // / [in] if provided, emits the op's region entry. Otherwise, an emtpy block
2321
+ // / is created in the region.
2322
+ GenOMPRegionEntryCBFn genRegionEntryCB = nullptr ;
2278
2323
};
2279
2324
2280
2325
// / Create the body (block) for an OpenMP Operation.
2281
2326
// /
2282
- // / \param [in] op - the operation the body belongs to.
2283
- // / \param [inout] converter - converter to use for the clauses.
2284
- // / \param [in] loc - location in source code.
2285
- // / \param [in] eval - current PFT node/evaluation.
2286
- // / \param [in] genNested - whether to generate FIR for nested evaluations
2287
- // / \oaran [in] clauses - list of clauses to process.
2288
- // / \param [in] args - block arguments (induction variable[s]) for the
2289
- // // region.
2290
- // / \param [in] outerCombined - is this an outer operation - prevents
2291
- // / privatization.
2327
+ // / \param [in] op - the operation the body belongs to.
2328
+ // / \param [in] info - options controlling code-gen for the construction.
2292
2329
template <typename Op>
2293
- static void createBodyOfOp (Op &op, CreateBodyOfOpInfo info) {
2330
+ static void createBodyOfOp (Op &op, OpWithBodyGenInfo & info) {
2294
2331
fir::FirOpBuilder &firOpBuilder = info.converter .getFirOpBuilder ();
2295
2332
2296
2333
auto insertMarker = [](fir::FirOpBuilder &builder) {
@@ -2303,28 +2340,15 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) {
2303
2340
// argument. Also update the symbol's address with the mlir argument value.
2304
2341
// e.g. For loops the argument is the induction variable. And all further
2305
2342
// uses of the induction variable should use this mlir value.
2306
- if (info.args .size ()) {
2307
- std::size_t loopVarTypeSize = 0 ;
2308
- for (const Fortran::semantics::Symbol *arg : info.args )
2309
- loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
2310
- mlir::Type loopVarType = getLoopVarType (info.converter , loopVarTypeSize);
2311
- llvm::SmallVector<mlir::Type> tiv (info.args .size (), loopVarType);
2312
- llvm::SmallVector<mlir::Location> locs (info.args .size (), info.loc );
2313
- firOpBuilder.createBlock (&op.getRegion (), {}, tiv, locs);
2314
- // The argument is not currently in memory, so make a temporary for the
2315
- // argument, and store it there, then bind that location to the argument.
2316
- mlir::Operation *storeOp = nullptr ;
2317
- for (auto [argIndex, argSymbol] : llvm::enumerate (info.args )) {
2318
- mlir::Value indexVal =
2319
- fir::getBase (op.getRegion ().front ().getArgument (argIndex));
2320
- storeOp = createAndSetPrivatizedLoopVar (info.converter , info.loc ,
2321
- indexVal, argSymbol);
2343
+ auto regionArgs =
2344
+ [&]() -> llvm::SmallVector<const Fortran::semantics::Symbol *> {
2345
+ if (info.genRegionEntryCB != nullptr ) {
2346
+ return info.genRegionEntryCB (op);
2322
2347
}
2323
- firOpBuilder.setInsertionPointAfter (storeOp);
2324
- } else {
2325
- firOpBuilder.createBlock (&op.getRegion ());
2326
- }
2327
2348
2349
+ firOpBuilder.createBlock (&op.getRegion ());
2350
+ return {};
2351
+ }();
2328
2352
// Mark the earliest insertion point.
2329
2353
mlir::Operation *marker = insertMarker (firOpBuilder);
2330
2354
@@ -2421,8 +2445,8 @@ static void createBodyOfOp(Op &op, CreateBodyOfOpInfo info) {
2421
2445
assert (tempDsp.has_value ());
2422
2446
tempDsp->processStep2 (op, isLoop);
2423
2447
} else {
2424
- if (isLoop && info. args .size () > 0 )
2425
- info.dsp ->setLoopIV (info.converter .getSymbolAddress (*info. args [0 ]));
2448
+ if (isLoop && regionArgs .size () > 0 )
2449
+ info.dsp ->setLoopIV (info.converter .getSymbolAddress (*regionArgs [0 ]));
2426
2450
info.dsp ->processStep2 (op, isLoop);
2427
2451
}
2428
2452
}
@@ -2497,24 +2521,11 @@ static void genBodyOfTargetDataOp(
2497
2521
genNestedEvaluations (converter, eval);
2498
2522
}
2499
2523
2500
- struct GenOpWithBodyInfo {
2501
- Fortran::lower::AbstractConverter &converter;
2502
- Fortran::lower::pft::Evaluation &eval;
2503
- bool genNested = false ;
2504
- mlir::Location currentLocation;
2505
- bool outerCombined = false ;
2506
- const Fortran::parser::OmpClauseList *clauseList = nullptr ;
2507
- };
2508
-
2509
2524
template <typename OpTy, typename ... Args>
2510
- static OpTy genOpWithBody (GenOpWithBodyInfo info, Args &&...args) {
2525
+ static OpTy genOpWithBody (OpWithBodyGenInfo & info, Args &&...args) {
2511
2526
auto op = info.converter .getFirOpBuilder ().create <OpTy>(
2512
- info.currentLocation , std::forward<Args>(args)...);
2513
- createBodyOfOp<OpTy>(
2514
- op, {info.converter , info.currentLocation , info.eval , info.genNested ,
2515
- info.clauseList ,
2516
- /* args*/ llvm::SmallVector<const Fortran::semantics::Symbol *>{},
2517
- info.outerCombined });
2527
+ info.loc , std::forward<Args>(args)...);
2528
+ createBodyOfOp<OpTy>(op, info);
2518
2529
return op;
2519
2530
}
2520
2531
@@ -2523,7 +2534,8 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
2523
2534
Fortran::lower::pft::Evaluation &eval, bool genNested,
2524
2535
mlir::Location currentLocation) {
2525
2536
return genOpWithBody<mlir::omp::MasterOp>(
2526
- {converter, eval, genNested, currentLocation},
2537
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2538
+ .setGenNested (genNested),
2527
2539
/* resultTypes=*/ mlir::TypeRange ());
2528
2540
}
2529
2541
@@ -2532,7 +2544,8 @@ genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
2532
2544
Fortran::lower::pft::Evaluation &eval, bool genNested,
2533
2545
mlir::Location currentLocation) {
2534
2546
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2535
- {converter, eval, genNested, currentLocation},
2547
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2548
+ .setGenNested (genNested),
2536
2549
/* simd=*/ false );
2537
2550
}
2538
2551
@@ -2560,7 +2573,10 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
2560
2573
cp.processReduction (currentLocation, reductionVars, reductionDeclSymbols);
2561
2574
2562
2575
return genOpWithBody<mlir::omp::ParallelOp>(
2563
- {converter, eval, genNested, currentLocation, outerCombined, &clauseList},
2576
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2577
+ .setGenNested (genNested)
2578
+ .setOuterCombined (outerCombined)
2579
+ .setClauses (&clauseList),
2564
2580
/* resultTypes=*/ mlir::TypeRange (), ifClauseOperand,
2565
2581
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2566
2582
reductionVars,
@@ -2579,8 +2595,9 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
2579
2595
// Currently only private/firstprivate clause is handled, and
2580
2596
// all privatization is done within `omp.section` operations.
2581
2597
return genOpWithBody<mlir::omp::SectionOp>(
2582
- {converter, eval, genNested, currentLocation,
2583
- /* outerCombined=*/ false , §ionsClauseList});
2598
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2599
+ .setGenNested (genNested)
2600
+ .setClauses (§ionsClauseList));
2584
2601
}
2585
2602
2586
2603
static mlir::omp::SingleOp
@@ -2600,8 +2617,9 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
2600
2617
ClauseProcessor (converter, endClauseList).processNowait (nowaitAttr);
2601
2618
2602
2619
return genOpWithBody<mlir::omp::SingleOp>(
2603
- {converter, eval, genNested, currentLocation,
2604
- /* outerCombined=*/ false , &beginClauseList},
2620
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2621
+ .setGenNested (genNested)
2622
+ .setClauses (&beginClauseList),
2605
2623
allocateOperands, allocatorOperands, nowaitAttr);
2606
2624
}
2607
2625
@@ -2633,8 +2651,9 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
2633
2651
currentLocation, llvm::omp::Directive::OMPD_task);
2634
2652
2635
2653
return genOpWithBody<mlir::omp::TaskOp>(
2636
- {converter, eval, genNested, currentLocation,
2637
- /* outerCombined=*/ false , &clauseList},
2654
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2655
+ .setGenNested (genNested)
2656
+ .setClauses (&clauseList),
2638
2657
ifClauseOperand, finalClauseOperand, untiedAttr, mergeableAttr,
2639
2658
/* in_reduction_vars=*/ mlir::ValueRange (),
2640
2659
/* in_reductions=*/ nullptr , priorityClauseOperand,
@@ -2656,8 +2675,9 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
2656
2675
cp.processTODO <Fortran::parser::OmpClause::TaskReduction>(
2657
2676
currentLocation, llvm::omp::Directive::OMPD_taskgroup);
2658
2677
return genOpWithBody<mlir::omp::TaskGroupOp>(
2659
- {converter, eval, genNested, currentLocation,
2660
- /* outerCombined=*/ false , &clauseList},
2678
+ OpWithBodyGenInfo (converter, currentLocation, eval)
2679
+ .setGenNested (genNested)
2680
+ .setClauses (&clauseList),
2661
2681
/* task_reduction_vars=*/ mlir::ValueRange (),
2662
2682
/* task_reductions=*/ nullptr , allocateOperands, allocatorOperands);
2663
2683
}
@@ -3040,7 +3060,10 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
3040
3060
currentLocation, llvm::omp::Directive::OMPD_teams);
3041
3061
3042
3062
return genOpWithBody<mlir::omp::TeamsOp>(
3043
- {converter, eval, genNested, currentLocation, outerCombined, &clauseList},
3063
+ OpWithBodyGenInfo (converter, currentLocation, eval)
3064
+ .setGenNested (genNested)
3065
+ .setOuterCombined (outerCombined)
3066
+ .setClauses (&clauseList),
3044
3067
/* num_teams_lower=*/ nullptr , numTeamsClauseOperand, ifClauseOperand,
3045
3068
threadLimitClauseOperand, allocateOperands, allocatorOperands,
3046
3069
reductionVars,
@@ -3237,6 +3260,33 @@ static void convertLoopBounds(Fortran::lower::AbstractConverter &converter,
3237
3260
}
3238
3261
}
3239
3262
3263
+ static llvm::SmallVector<const Fortran::semantics::Symbol *>
3264
+ genLoopVars (mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
3265
+ mlir::Location &loc,
3266
+ const llvm::SmallVector<const Fortran::semantics::Symbol *> &args) {
3267
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder ();
3268
+ auto ®ion = op->getRegion (0 );
3269
+
3270
+ std::size_t loopVarTypeSize = 0 ;
3271
+ for (const Fortran::semantics::Symbol *arg : args)
3272
+ loopVarTypeSize = std::max (loopVarTypeSize, arg->GetUltimate ().size ());
3273
+ mlir::Type loopVarType = getLoopVarType (converter, loopVarTypeSize);
3274
+ llvm::SmallVector<mlir::Type> tiv (args.size (), loopVarType);
3275
+ llvm::SmallVector<mlir::Location> locs (args.size (), loc);
3276
+ firOpBuilder.createBlock (®ion, {}, tiv, locs);
3277
+ // The argument is not currently in memory, so make a temporary for the
3278
+ // argument, and store it there, then bind that location to the argument.
3279
+ mlir::Operation *storeOp = nullptr ;
3280
+ for (auto [argIndex, argSymbol] : llvm::enumerate (args)) {
3281
+ mlir::Value indexVal = fir::getBase (region.front ().getArgument (argIndex));
3282
+ storeOp =
3283
+ createAndSetPrivatizedLoopVar (converter, loc, indexVal, argSymbol);
3284
+ }
3285
+ firOpBuilder.setInsertionPointAfter (storeOp);
3286
+
3287
+ return args;
3288
+ }
3289
+
3240
3290
static void
3241
3291
createSimdLoop (Fortran::lower::AbstractConverter &converter,
3242
3292
Fortran::lower::pft::Evaluation &eval,
@@ -3284,10 +3334,16 @@ createSimdLoop(Fortran::lower::AbstractConverter &converter,
3284
3334
3285
3335
auto *nestedEval = getCollapsedLoopEval (
3286
3336
eval, Fortran::lower::getCollapseValue (loopOpClauseList));
3337
+
3338
+ auto ivCallback = [&](mlir::Operation *op) {
3339
+ return genLoopVars (op, converter, loc, iv);
3340
+ };
3341
+
3287
3342
createBodyOfOp<mlir::omp::SimdLoopOp>(
3288
- simdLoopOp, {converter, loc, *nestedEval,
3289
- /* genNested=*/ true , &loopOpClauseList, iv,
3290
- /* outerCombined=*/ false , &dsp});
3343
+ simdLoopOp, OpWithBodyGenInfo (converter, loc, *nestedEval)
3344
+ .setClauses (&loopOpClauseList)
3345
+ .setDataSharingProcessor (&dsp)
3346
+ .setGenRegionEntryCb (ivCallback));
3291
3347
}
3292
3348
3293
3349
static void createWsLoop (Fortran::lower::AbstractConverter &converter,
@@ -3360,10 +3416,16 @@ static void createWsLoop(Fortran::lower::AbstractConverter &converter,
3360
3416
3361
3417
auto *nestedEval = getCollapsedLoopEval (
3362
3418
eval, Fortran::lower::getCollapseValue (beginClauseList));
3363
- createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp,
3364
- {converter, loc, *nestedEval,
3365
- /* genNested=*/ true , &beginClauseList, iv,
3366
- /* outerCombined=*/ false , &dsp});
3419
+
3420
+ auto ivCallback = [&](mlir::Operation *op) {
3421
+ return genLoopVars (op, converter, loc, iv);
3422
+ };
3423
+
3424
+ createBodyOfOp<mlir::omp::WsLoopOp>(
3425
+ wsLoopOp, OpWithBodyGenInfo (converter, loc, *nestedEval)
3426
+ .setClauses (&beginClauseList)
3427
+ .setDataSharingProcessor (&dsp)
3428
+ .setGenRegionEntryCb (ivCallback));
3367
3429
}
3368
3430
3369
3431
static void createSimdWsLoop (
@@ -3644,8 +3706,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3644
3706
currentLocation, mlir::FlatSymbolRefAttr::get (firOpBuilder.getContext (),
3645
3707
global.getSymName ()));
3646
3708
}();
3647
- createBodyOfOp<mlir::omp::CriticalOp>(criticalOp,
3648
- {converter, currentLocation, eval} );
3709
+ auto genInfo = OpWithBodyGenInfo (converter, currentLocation, eval);
3710
+ createBodyOfOp<mlir::omp::CriticalOp>(criticalOp, genInfo );
3649
3711
}
3650
3712
3651
3713
static void
@@ -3687,11 +3749,11 @@ genOMP(Fortran::lower::AbstractConverter &converter,
3687
3749
}
3688
3750
3689
3751
// SECTIONS construct
3690
- genOpWithBody<mlir::omp::SectionsOp>({converter, eval,
3691
- /* genNested= */ false , currentLocation} ,
3692
- /* reduction_vars=*/ mlir::ValueRange (),
3693
- /* reductions=*/ nullptr , allocateOperands,
3694
- allocatorOperands, nowaitClauseOperand);
3752
+ genOpWithBody<mlir::omp::SectionsOp>(
3753
+ OpWithBodyGenInfo (converter , currentLocation, eval). setGenNested ( false ) ,
3754
+ /* reduction_vars=*/ mlir::ValueRange (),
3755
+ /* reductions=*/ nullptr , allocateOperands, allocatorOperands ,
3756
+ nowaitClauseOperand);
3695
3757
3696
3758
const auto §ionBlocks =
3697
3759
std::get<Fortran::parser::OmpSectionBlocks>(sectionsConstruct.t );
0 commit comments