Skip to content

Commit ef847ab

Browse files
committed
[WIP] Delayed privatization.
This is a PoC for delayed privatization in OpenMP. Instead of directly emitting privatization code in the frontend, we add a new op that outlines the privatization logic for a symbol, together with a call-like mapping from the host symbol to the outlined, function-like privatizer op. Later, we would inline the delayed privatizer function-like op into the OpenMP region to get essentially the same code that is generated directly by the frontend at the moment.
1 parent 5a07774 commit ef847ab

File tree

8 files changed

+326
-30
lines changed

8 files changed

+326
-30
lines changed

flang/include/flang/Lower/AbstractConverter.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Common/Fortran.h"
1717
#include "flang/Lower/LoweringOptions.h"
1818
#include "flang/Lower/PFTDefs.h"
19+
#include "flang/Lower/SymbolMap.h"
1920
#include "flang/Optimizer/Builder/BoxValue.h"
2021
#include "flang/Semantics/symbol.h"
2122
#include "mlir/IR/Builders.h"
@@ -295,6 +296,9 @@ class AbstractConverter {
295296
return loweringOptions;
296297
}
297298

299+
virtual Fortran::lower::SymbolBox
300+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) = 0;
301+
298302
private:
299303
/// Options controlling lowering behavior.
300304
const Fortran::lower::LoweringOptions &loweringOptions;

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1070,7 +1070,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
10701070
/// Find the symbol in one level up of symbol map such as for host-association
10711071
/// in OpenMP code or return null.
10721072
Fortran::lower::SymbolBox
1073-
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) {
1073+
lookupOneLevelUpSymbol(const Fortran::semantics::Symbol &sym) override {
10741074
if (Fortran::lower::SymbolBox v = localSymbols.lookupOneLevelUpSymbol(sym))
10751075
return v;
10761076
return {};

flang/lib/Lower/OpenMP.cpp

Lines changed: 108 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,11 @@ class DataSharingProcessor {
161161
const Fortran::parser::OmpClauseList &opClauseList;
162162
Fortran::lower::pft::Evaluation &eval;
163163

164+
bool useDelayedPrivatizationWhenPossible;
165+
Fortran::lower::SymMap *symTable;
166+
llvm::SetVector<mlir::SymbolRefAttr> privateInitializers;
167+
llvm::SetVector<mlir::Value> privateSymHostAddrsses;
168+
164169
bool needBarrier();
165170
void collectSymbols(Fortran::semantics::Symbol::Flag flag);
166171
void collectOmpObjectListSymbol(
@@ -182,10 +187,14 @@ class DataSharingProcessor {
182187
public:
183188
DataSharingProcessor(Fortran::lower::AbstractConverter &converter,
184189
const Fortran::parser::OmpClauseList &opClauseList,
185-
Fortran::lower::pft::Evaluation &eval)
190+
Fortran::lower::pft::Evaluation &eval,
191+
bool useDelayedPrivatizationWhenPossible = false,
192+
Fortran::lower::SymMap *symTable = nullptr)
186193
: hasLastPrivateOp(false), converter(converter),
187194
firOpBuilder(converter.getFirOpBuilder()), opClauseList(opClauseList),
188-
eval(eval) {}
195+
eval(eval), useDelayedPrivatizationWhenPossible(
196+
useDelayedPrivatizationWhenPossible),
197+
symTable(symTable) {}
189198
// Privatisation is split into two steps.
190199
// Step1 performs cloning of all privatisation clauses and copying for
191200
// firstprivates. Step1 is performed at the place where process/processStep1
@@ -204,6 +213,14 @@ class DataSharingProcessor {
204213
assert(!loopIV && "Loop iteration variable already set");
205214
loopIV = iv;
206215
}
216+
217+
const llvm::SetVector<mlir::SymbolRefAttr> &getPrivateInitializers() const {
218+
return privateInitializers;
219+
};
220+
221+
const llvm::SetVector<mlir::Value> &getPrivateSymHostAddrsses() const {
222+
return privateSymHostAddrsses;
223+
}
207224
};
208225

209226
void DataSharingProcessor::processStep1() {
@@ -496,8 +513,46 @@ void DataSharingProcessor::privatize() {
496513
copyFirstPrivateSymbol(&*mem);
497514
}
498515
} else {
499-
cloneSymbol(sym);
500-
copyFirstPrivateSymbol(sym);
516+
if (useDelayedPrivatizationWhenPossible) {
517+
auto ip = firOpBuilder.saveInsertionPoint();
518+
519+
auto moduleOp = firOpBuilder.getInsertionBlock()
520+
->getParentOp()
521+
->getParentOfType<mlir::ModuleOp>();
522+
523+
firOpBuilder.setInsertionPoint(&moduleOp.getBodyRegion().front(),
524+
moduleOp.getBodyRegion().front().end());
525+
526+
Fortran::lower::SymbolBox hsb = converter.lookupOneLevelUpSymbol(*sym);
527+
assert(hsb && "Host symbol box not found");
528+
529+
auto symType = hsb.getAddr().getType();
530+
auto symLoc = hsb.getAddr().getLoc();
531+
auto privatizerOp = firOpBuilder.create<mlir::omp::PrivateClauseOp>(
532+
symLoc, symType, sym->name().ToString());
533+
firOpBuilder.setInsertionPointToEnd(&privatizerOp.getBody().front());
534+
535+
symTable->pushScope();
536+
symTable->addSymbol(*sym, privatizerOp.getArgument(0));
537+
symTable->pushScope();
538+
539+
cloneSymbol(sym);
540+
copyFirstPrivateSymbol(sym);
541+
542+
firOpBuilder.create<mlir::omp::YieldOp>(
543+
hsb.getAddr().getLoc(),
544+
symTable->shallowLookupSymbol(*sym).getAddr());
545+
546+
symTable->popScope();
547+
symTable->popScope();
548+
firOpBuilder.restoreInsertionPoint(ip);
549+
550+
privateInitializers.insert(mlir::SymbolRefAttr::get(privatizerOp));
551+
privateSymHostAddrsses.insert(hsb.getAddr());
552+
} else {
553+
cloneSymbol(sym);
554+
copyFirstPrivateSymbol(sym);
555+
}
501556
}
502557
}
503558
}
@@ -2463,12 +2518,12 @@ static OpTy genOpWithBody(Fortran::lower::AbstractConverter &converter,
24632518
Fortran::lower::pft::Evaluation &eval, bool genNested,
24642519
mlir::Location currentLocation, bool outerCombined,
24652520
const Fortran::parser::OmpClauseList *clauseList,
2466-
Args &&...args) {
2521+
DataSharingProcessor *dsp, Args &&...args) {
24672522
auto op = converter.getFirOpBuilder().create<OpTy>(
24682523
currentLocation, std::forward<Args>(args)...);
24692524
createBodyOfOp<OpTy>(op, converter, currentLocation, eval, genNested,
24702525
clauseList,
2471-
/*args=*/{}, outerCombined);
2526+
/*args=*/{}, outerCombined, dsp);
24722527
return op;
24732528
}
24742529

@@ -2480,21 +2535,25 @@ genMasterOp(Fortran::lower::AbstractConverter &converter,
24802535
currentLocation,
24812536
/*outerCombined=*/false,
24822537
/*clauseList=*/nullptr,
2538+
/*dsp=*/nullptr,
24832539
/*resultTypes=*/mlir::TypeRange());
24842540
}
24852541

24862542
static mlir::omp::OrderedRegionOp
24872543
genOrderedRegionOp(Fortran::lower::AbstractConverter &converter,
24882544
Fortran::lower::pft::Evaluation &eval, bool genNested,
24892545
mlir::Location currentLocation) {
2490-
return genOpWithBody<mlir::omp::OrderedRegionOp>(
2491-
converter, eval, genNested, currentLocation,
2492-
/*outerCombined=*/false,
2493-
/*clauseList=*/nullptr, /*simd=*/false);
2546+
return genOpWithBody<mlir::omp::OrderedRegionOp>(converter, eval, genNested,
2547+
currentLocation,
2548+
/*outerCombined=*/false,
2549+
/*clauseList=*/nullptr,
2550+
/*dsp=*/nullptr,
2551+
/*simd=*/false);
24942552
}
24952553

24962554
static mlir::omp::ParallelOp
24972555
genParallelOp(Fortran::lower::AbstractConverter &converter,
2556+
Fortran::lower::SymMap &symTable,
24982557
Fortran::lower::pft::Evaluation &eval, bool genNested,
24992558
mlir::Location currentLocation,
25002559
const Fortran::parser::OmpClauseList &clauseList,
@@ -2516,16 +2575,37 @@ genParallelOp(Fortran::lower::AbstractConverter &converter,
25162575
if (!outerCombined)
25172576
cp.processReduction(currentLocation, reductionVars, reductionDeclSymbols);
25182577

2578+
bool privatize = !outerCombined;
2579+
DataSharingProcessor dsp(converter, clauseList, eval,
2580+
/*useDelayedPrivatizationWhenPossible=*/true,
2581+
&symTable);
2582+
2583+
if (privatize) {
2584+
dsp.processStep1();
2585+
}
2586+
2587+
llvm::SmallVector<mlir::Attribute> privateInits(
2588+
dsp.getPrivateInitializers().begin(), dsp.getPrivateInitializers().end());
2589+
2590+
llvm::SmallVector<mlir::Value> privateSymAddresses(
2591+
dsp.getPrivateSymHostAddrsses().begin(),
2592+
dsp.getPrivateSymHostAddrsses().end());
2593+
25192594
return genOpWithBody<mlir::omp::ParallelOp>(
25202595
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
2596+
&dsp,
25212597
/*resultTypes=*/mlir::TypeRange(), ifClauseOperand,
25222598
numThreadsClauseOperand, allocateOperands, allocatorOperands,
2523-
reductionVars,
2599+
reductionVars, privateSymAddresses,
25242600
reductionDeclSymbols.empty()
25252601
? nullptr
25262602
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
25272603
reductionDeclSymbols),
2528-
procBindKindAttr);
2604+
procBindKindAttr,
2605+
privateInits.empty()
2606+
? nullptr
2607+
: mlir::ArrayAttr::get(converter.getFirOpBuilder().getContext(),
2608+
privateInits));
25292609
}
25302610

25312611
static mlir::omp::SectionOp
@@ -2537,7 +2617,8 @@ genSectionOp(Fortran::lower::AbstractConverter &converter,
25372617
// all privatization is done within `omp.section` operations.
25382618
return genOpWithBody<mlir::omp::SectionOp>(
25392619
converter, eval, genNested, currentLocation,
2540-
/*outerCombined=*/false, &sectionsClauseList);
2620+
/*outerCombined=*/false, &sectionsClauseList,
2621+
/*dsp=*/nullptr);
25412622
}
25422623

25432624
static mlir::omp::SingleOp
@@ -2558,8 +2639,8 @@ genSingleOp(Fortran::lower::AbstractConverter &converter,
25582639

25592640
return genOpWithBody<mlir::omp::SingleOp>(
25602641
converter, eval, genNested, currentLocation,
2561-
/*outerCombined=*/false, &beginClauseList, allocateOperands,
2562-
allocatorOperands, nowaitAttr);
2642+
/*outerCombined=*/false, &beginClauseList, /*dsp=*/nullptr,
2643+
allocateOperands, allocatorOperands, nowaitAttr);
25632644
}
25642645

25652646
static mlir::omp::TaskOp
@@ -2591,8 +2672,8 @@ genTaskOp(Fortran::lower::AbstractConverter &converter,
25912672

25922673
return genOpWithBody<mlir::omp::TaskOp>(
25932674
converter, eval, genNested, currentLocation,
2594-
/*outerCombined=*/false, &clauseList, ifClauseOperand, finalClauseOperand,
2595-
untiedAttr, mergeableAttr,
2675+
/*outerCombined=*/false, &clauseList, /*dsp=*/nullptr, ifClauseOperand,
2676+
finalClauseOperand, untiedAttr, mergeableAttr,
25962677
/*in_reduction_vars=*/mlir::ValueRange(),
25972678
/*in_reductions=*/nullptr, priorityClauseOperand,
25982679
dependTypeOperands.empty()
@@ -2615,6 +2696,7 @@ genTaskGroupOp(Fortran::lower::AbstractConverter &converter,
26152696
return genOpWithBody<mlir::omp::TaskGroupOp>(
26162697
converter, eval, genNested, currentLocation,
26172698
/*outerCombined=*/false, &clauseList,
2699+
/*dsp=*/nullptr,
26182700
/*task_reduction_vars=*/mlir::ValueRange(),
26192701
/*task_reductions=*/nullptr, allocateOperands, allocatorOperands);
26202702
}
@@ -2994,6 +3076,7 @@ genTeamsOp(Fortran::lower::AbstractConverter &converter,
29943076

29953077
return genOpWithBody<mlir::omp::TeamsOp>(
29963078
converter, eval, genNested, currentLocation, outerCombined, &clauseList,
3079+
/*dsp=*/nullptr,
29973080
/*num_teams_lower=*/nullptr, numTeamsClauseOperand, ifClauseOperand,
29983081
threadLimitClauseOperand, allocateOperands, allocatorOperands,
29993082
reductionVars,
@@ -3392,8 +3475,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
33923475
if ((llvm::omp::allParallelSet & llvm::omp::loopConstructSet)
33933476
.test(ompDirective)) {
33943477
validDirective = true;
3395-
genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
3396-
loopOpClauseList,
3478+
genParallelOp(converter, symTable, eval, /*genNested=*/false,
3479+
currentLocation, loopOpClauseList,
33973480
/*outerCombined=*/true);
33983481
}
33993482
}
@@ -3481,8 +3564,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
34813564
genOrderedRegionOp(converter, eval, /*genNested=*/true, currentLocation);
34823565
break;
34833566
case llvm::omp::Directive::OMPD_parallel:
3484-
genParallelOp(converter, eval, /*genNested=*/true, currentLocation,
3485-
beginClauseList);
3567+
genParallelOp(converter, symTable, eval, /*genNested=*/true,
3568+
currentLocation, beginClauseList);
34863569
break;
34873570
case llvm::omp::Directive::OMPD_single:
34883571
genSingleOp(converter, eval, /*genNested=*/true, currentLocation,
@@ -3541,8 +3624,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
35413624
.test(directive.v)) {
35423625
bool outerCombined =
35433626
directive.v != llvm::omp::Directive::OMPD_target_parallel;
3544-
genParallelOp(converter, eval, /*genNested=*/false, currentLocation,
3545-
beginClauseList, outerCombined);
3627+
genParallelOp(converter, symTable, eval, /*genNested=*/false,
3628+
currentLocation, beginClauseList, outerCombined);
35463629
combinedDirective = true;
35473630
}
35483631
if ((llvm::omp::workShareSet & llvm::omp::blockConstructSet)
@@ -3625,7 +3708,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36253708

36263709
// Parallel wrapper of PARALLEL SECTIONS construct
36273710
if (dir == llvm::omp::Directive::OMPD_parallel_sections) {
3628-
genParallelOp(converter, eval,
3711+
genParallelOp(converter, symTable, eval,
36293712
/*genNested=*/false, currentLocation, sectionsClauseList,
36303713
/*outerCombined=*/true);
36313714
} else {
@@ -3642,6 +3725,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
36423725
/*genNested=*/false, currentLocation,
36433726
/*outerCombined=*/false,
36443727
/*clauseList=*/nullptr,
3728+
/*dsp=*/nullptr,
36453729
/*reduction_vars=*/mlir::ValueRange(),
36463730
/*reductions=*/nullptr, allocateOperands,
36473731
allocatorOperands, nowaitClauseOperand);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
subroutine delayed_privatization()
2+
integer :: var1
3+
integer :: var2
4+
5+
!$OMP PARALLEL FIRSTPRIVATE(var1, var2)
6+
var1 = var1 + var2 + 2
7+
!$OMP END PARALLEL
8+
9+
end subroutine
10+
11+
! This is what flang emits with the PoC:
12+
! --------------------------------------
13+
!
14+
!func.func @_QPdelayed_privatization() {
15+
! %0 = fir.alloca i32 {bindc_name = "var1", uniq_name = "_QFdelayed_privatizationEvar1"}
16+
! %1 = fir.alloca i32 {bindc_name = "var2", uniq_name = "_QFdelayed_privatizationEvar2"}
17+
! omp.parallel private(@var1.privatizer %0, @var2.privatizer %1 : !fir.ref<i32>, !fir.ref<i32>) {
18+
! %2 = fir.load %0 : !fir.ref<i32>
19+
! %3 = fir.load %1 : !fir.ref<i32>
20+
! %4 = arith.addi %2, %3 : i32
21+
! %c2_i32 = arith.constant 2 : i32
22+
! %5 = arith.addi %4, %c2_i32 : i32
23+
! fir.store %5 to %0 : !fir.ref<i32>
24+
! omp.terminator
25+
! }
26+
! return
27+
!}
28+
!
29+
!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var1.privatizer"}> ({
30+
!^bb0(%arg0: !fir.ref<i32>):
31+
! %0 = fir.alloca i32 {bindc_name = "var1", pinned, uniq_name = "_QFdelayed_privatizationEvar1"}
32+
! %1 = fir.load %arg0 : !fir.ref<i32>
33+
! fir.store %1 to %0 : !fir.ref<i32>
34+
! omp.yield(%0 : !fir.ref<i32>)
35+
!}) : () -> ()
36+
!
37+
!"omp.private"() <{function_type = (!fir.ref<i32>) -> !fir.ref<i32>, sym_name = "var2.privatizer"}> ({
38+
!^bb0(%arg0: !fir.ref<i32>):
39+
! %0 = fir.alloca i32 {bindc_name = "var2", pinned, uniq_name = "_QFdelayed_privatizationEvar2"}
40+
! %1 = fir.load %arg0 : !fir.ref<i32>
41+
! fir.store %1 to %0 : !fir.ref<i32>
42+
! omp.yield(%0 : !fir.ref<i32>)
43+
!}) : () -> ()

0 commit comments

Comments
 (0)