Skip to content

Commit 2d6e1c4

Browse files
[Flang][MLIR][OpenMP] WIP: Privatisation for index variables
1 parent decf027 commit 2d6e1c4

File tree

7 files changed

+217
-14
lines changed

7 files changed

+217
-14
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2075,6 +2075,7 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
20752075
firOpBuilder.setInsertionPointToStart(firOpBuilder.getAllocaBlock());
20762076

20772077
mlir::Type tempTy = converter.genType(*sym);
2078+
llvm::outs() << "Temp type = " << tempTy << "\n";
20782079
mlir::Value temp = firOpBuilder.create<fir::AllocaOp>(
20792080
loc, tempTy, /*pinned=*/true, /*lengthParams=*/mlir::ValueRange{},
20802081
/*shapeParams*/ mlir::ValueRange{},
@@ -2088,6 +2089,103 @@ createAndSetPrivatizedLoopVar(Fortran::lower::AbstractConverter &converter,
20882089
return storeOp;
20892090
}
20902091

2092+
/// Create the body (block) for an OpenMP Loop Operation.
2093+
///
2094+
/// \param [in] op - the operation the body belongs to.
2095+
/// \param [inout] converter - converter to use for the clauses.
2096+
/// \param [in] loc - location in source code.
2097+
/// \param [in] eval - current PFT node/evaluation.
2098+
/// \oaran [in] clauses - list of clauses to process.
2099+
/// \param [in] args - block arguments (induction variable[s]) for the
2100+
//// region.
2101+
/// \param [in] outerCombined - is this an outer operation - prevents
2102+
/// privatization.
2103+
template <typename Op>
2104+
static void createBodyOfLoopOp(
2105+
Op &op, Fortran::lower::AbstractConverter &converter, mlir::Location &loc,
2106+
Fortran::lower::pft::Evaluation &eval,
2107+
const Fortran::parser::OmpClauseList *clauses = nullptr,
2108+
const llvm::SmallVector<const Fortran::semantics::Symbol *> &args = {},
2109+
bool outerCombined = false, DataSharingProcessor *dsp = nullptr) {
2110+
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
2111+
// If an argument for the region is provided then create the block with that
2112+
// argument. Also update the symbol's address with the mlir argument value.
2113+
// e.g. For loops the argument is the induction variable. And all further
2114+
// uses of the induction variable should use this mlir value.
2115+
mlir::Operation *storeOp = nullptr;
2116+
assert(args.size() > 0);
2117+
std::size_t loopVarTypeSize = 0;
2118+
for (const Fortran::semantics::Symbol *arg : args)
2119+
loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
2120+
mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
2121+
llvm::SmallVector<mlir::Type> tiv;
2122+
llvm::SmallVector<mlir::Location> locs;
2123+
for (int i = 0; i < (int)args.size(); i++) {
2124+
tiv.push_back(loopVarType);
2125+
locs.push_back(loc);
2126+
}
2127+
int offset = 0;
2128+
// The argument is not currently in memory, so make a temporary for the
2129+
// argument, and store it there, then bind that location to the argument.
2130+
for (const Fortran::semantics::Symbol *arg : args) {
2131+
mlir::Type symType = converter.genType(*arg);
2132+
mlir::Type symRefType = firOpBuilder.getRefType(symType);
2133+
tiv.push_back(symRefType);
2134+
locs.push_back(loc);
2135+
offset++;
2136+
}
2137+
firOpBuilder.createBlock(&op.getRegion(), {}, tiv, locs);
2138+
2139+
int argIndex = 0;
2140+
for (const Fortran::semantics::Symbol *arg : args) {
2141+
mlir::Value addrVal =
2142+
fir::getBase(op.getRegion().front().getArgument(argIndex+offset));
2143+
converter.bindSymbol(*arg, addrVal);
2144+
mlir::Type symType = converter.genType(*arg);
2145+
mlir::Value indexVal =
2146+
fir::getBase(op.getRegion().front().getArgument(argIndex));
2147+
mlir::Value cvtVal = firOpBuilder.createConvert(loc, symType, indexVal);
2148+
addrVal = converter.getSymbolAddress(*arg);
2149+
storeOp = firOpBuilder.create<fir::StoreOp>(loc, cvtVal, addrVal);
2150+
argIndex++;
2151+
}
2152+
// Set the insert for the terminator operation to go at the end of the
2153+
// block - this is either empty or the block with the stores above,
2154+
// the end of the block works for both.
2155+
mlir::Block &block = op.getRegion().back();
2156+
firOpBuilder.setInsertionPointToEnd(&block);
2157+
2158+
// If it is an unstructured region and is not the outer region of a combined
2159+
// construct, create empty blocks for all evaluations.
2160+
if (eval.lowerAsUnstructured() && !outerCombined)
2161+
Fortran::lower::createEmptyRegionBlocks<mlir::omp::TerminatorOp,
2162+
mlir::omp::YieldOp>(
2163+
firOpBuilder, eval.getNestedEvaluations());
2164+
2165+
// Insert the terminator.
2166+
Fortran::lower::genOpenMPTerminator(firOpBuilder, op.getOperation(), loc);
2167+
// Reset the insert point to before the terminator.
2168+
resetBeforeTerminator(firOpBuilder, storeOp, block);
2169+
2170+
// Handle privatization. Do not privatize if this is the outer operation.
2171+
if (clauses && !outerCombined) {
2172+
constexpr bool isLoop = std::is_same_v<Op, mlir::omp::WsLoopOp> ||
2173+
std::is_same_v<Op, mlir::omp::SimdLoopOp>;
2174+
if (!dsp) {
2175+
DataSharingProcessor proc(converter, *clauses, eval);
2176+
proc.processStep1();
2177+
proc.processStep2(op, isLoop);
2178+
} else {
2179+
if (isLoop && args.size() > 0)
2180+
dsp->setLoopIV(converter.getSymbolAddress(*args[0]));
2181+
dsp->processStep2(op, isLoop);
2182+
}
2183+
2184+
if (storeOp)
2185+
firOpBuilder.setInsertionPointAfter(storeOp);
2186+
}
2187+
}
2188+
20912189
/// Create the body (block) for an OpenMP Operation.
20922190
///
20932191
/// \param [in] op - the operation the body belongs to.
@@ -2914,7 +3012,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
29143012
const Fortran::parser::OpenMPLoopConstruct &loopConstruct) {
29153013
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
29163014
llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
2917-
linearStepVars, reductionVars;
3015+
linearStepVars, privateVars, reductionVars;
29183016
mlir::Value scheduleChunkClauseOperand;
29193017
mlir::IntegerAttr orderedClauseOperand;
29203018
mlir::omp::ClauseOrderKindAttr orderClauseOperand;
@@ -3023,9 +3121,23 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
30233121
return;
30243122
}
30253123

3124+
// Collect the loops to collapse.
3125+
Fortran::lower::pft::Evaluation *doConstructEval =
3126+
&eval.getFirstNestedEvaluation();
3127+
Fortran::lower::pft::Evaluation *doLoop =
3128+
&doConstructEval->getFirstNestedEvaluation();
3129+
auto *doStmt = doLoop->getIf<Fortran::parser::NonLabelDoStmt>();
3130+
assert(doStmt && "Expected do loop to be in the nested evaluation");
3131+
const auto &loopControl =
3132+
std::get<std::optional<Fortran::parser::LoopControl>>(doStmt->t);
3133+
const Fortran::parser::LoopControl::Bounds *bounds =
3134+
std::get_if<Fortran::parser::LoopControl::Bounds>(&loopControl->u);
3135+
assert(bounds && "Expected bounds for worksharing do loop");
3136+
privateVars.push_back(converter.getSymbolAddress(*bounds->name.thing.symbol));
3137+
30263138
auto wsLoopOp = firOpBuilder.create<mlir::omp::WsLoopOp>(
30273139
currentLocation, lowerBound, upperBound, step, linearVars, linearStepVars,
3028-
reductionVars,
3140+
privateVars, reductionVars,
30293141
reductionDeclSymbols.empty()
30303142
? nullptr
30313143
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
@@ -3061,7 +3173,7 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
30613173
wsLoopOp.setNowaitAttr(nowaitClauseOperand);
30623174
}
30633175

3064-
createBodyOfOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
3176+
createBodyOfLoopOp<mlir::omp::WsLoopOp>(wsLoopOp, converter, currentLocation,
30653177
eval, &loopOpClauseList, iv,
30663178
/*outer=*/false, &dsp);
30673179
}

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
479479
Variadic<IntLikeType>:$step,
480480
Variadic<AnyType>:$linear_vars,
481481
Variadic<I32>:$linear_step_vars,
482+
Variadic<OpenMP_PointerLikeType>:$privates,
482483
Variadic<OpenMP_PointerLikeType>:$reduction_vars,
483484
OptionalAttr<SymbolRefArrayAttr>:$reductions,
484485
OptionalAttr<ScheduleKindAttr>:$schedule_val,
@@ -517,6 +518,7 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments,
517518
|`nowait` $nowait
518519
|`ordered` `(` $ordered_val `)`
519520
|`order` `(` custom<ClauseAttr>($order_val) `)`
521+
|`private` `(` custom<PrivateEntries>($privates, type($privates)) `)`
520522
|`reduction` `(`
521523
custom<ReductionVarList>(
522524
$reduction_vars, type($reduction_vars), $reductions

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,67 @@ void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
178178
p << stringifyEnum(attr.getValue());
179179
}
180180

181+
static ParseResult
182+
parsePrivateEntries(OpAsmParser &parser,
183+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &privateOperands,
184+
SmallVectorImpl<Type> &privateOperandTypes) {
185+
OpAsmParser::UnresolvedOperand arg;
186+
OpAsmParser::UnresolvedOperand blockArg;
187+
Type argType;
188+
auto parseEntries = [&]() -> ParseResult {
189+
if (parser.parseOperand(arg) || parser.parseArrow() ||
190+
parser.parseOperand(blockArg))
191+
return failure();
192+
privateOperands.push_back(arg);
193+
return success();
194+
};
195+
196+
auto parseTypes = [&]() -> ParseResult {
197+
if (parser.parseType(argType))
198+
return failure();
199+
privateOperandTypes.push_back(argType);
200+
return success();
201+
};
202+
203+
if (parser.parseCommaSeparatedList(parseEntries))
204+
return failure();
205+
206+
if (parser.parseColon())
207+
return failure();
208+
209+
if (parser.parseCommaSeparatedList(parseTypes))
210+
return failure();
211+
212+
return success();
213+
}
214+
215+
static void printPrivateEntries(OpAsmPrinter &p, Operation *op,
216+
OperandRange privateOperands,
217+
TypeRange privateOperandTypes) {
218+
auto &region = op->getRegion(0);
219+
220+
unsigned argIndex = 0;
221+
unsigned offset = 0;
222+
if (auto wsLoop = dyn_cast<WsLoopOp>(op))
223+
offset = wsLoop.getNumLoops();
224+
for (const auto &privOperand : privateOperands) {
225+
const auto &blockArg = region.front().getArgument(argIndex+offset);
226+
p << privOperand << " -> " << blockArg;
227+
argIndex++;
228+
if (argIndex < privateOperands.size())
229+
p << ", ";
230+
}
231+
p << " : ";
232+
233+
argIndex = 0;
234+
for (const auto &privOperandType : privateOperandTypes) {
235+
p << privOperandType;
236+
argIndex++;
237+
if (argIndex < privateOperands.size())
238+
p << ", ";
239+
}
240+
}
241+
181242
//===----------------------------------------------------------------------===//
182243
// Parser and printer for Linear Clause
183244
//===----------------------------------------------------------------------===//
@@ -1086,7 +1147,14 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
10861147
ValueRange steps, TypeRange loopVarTypes,
10871148
UnitAttr inclusive) {
10881149
auto args = region.front().getArguments();
1089-
p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound
1150+
p << " (";
1151+
unsigned numLoops = steps.size();
1152+
for (unsigned i=0; i<numLoops; i++) {
1153+
if (i != 0)
1154+
p << ", ";
1155+
p << args[i];
1156+
}
1157+
p << ") : " << args[0].getType() << " = (" << lowerBound
10901158
<< ") to (" << upperBound << ") ";
10911159
if (inclusive)
10921160
p << "inclusive ";
@@ -1269,7 +1337,8 @@ void WsLoopOp::build(OpBuilder &builder, OperationState &state,
12691337
ValueRange step, ArrayRef<NamedAttribute> attributes) {
12701338
build(builder, state, lowerBound, upperBound, step,
12711339
/*linear_vars=*/ValueRange(),
1272-
/*linear_step_vars=*/ValueRange(), /*reduction_vars=*/ValueRange(),
1340+
/*linear_step_vars=*/ValueRange(), /*private_vars=*/ValueRange(),
1341+
/*reduction_vars=*/ValueRange(),
12731342
/*reductions=*/nullptr, /*schedule_val=*/nullptr,
12741343
/*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
12751344
/*simd_modifier=*/false, /*nowait=*/false, /*ordered_val=*/nullptr,

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,6 +833,24 @@ static void collectReductionInfo(
833833
}
834834
}
835835

836+
/// Allocate space for privatized reduction variables.
837+
void
838+
allocPrivatizationVars(omp::WsLoopOp loop, llvm::IRBuilderBase &builder,
839+
LLVM::ModuleTranslation &moduleTranslation,
840+
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP) {
841+
unsigned offset = loop.getNumLoops();
842+
unsigned numArgs = loop.getRegion().front().getNumArguments();
843+
llvm::IRBuilderBase::InsertPointGuard guard(builder);
844+
builder.restoreIP(allocaIP);
845+
for (unsigned i = offset; i < numArgs; ++i) {
846+
if (auto op = loop.getPrivates()[i-offset].getDefiningOp<LLVM::AllocaOp>()) {
847+
llvm::Value *var = builder.CreateAlloca(moduleTranslation.convertType(op.getResultPtrElementType()));
848+
// moduleTranslation.convertType(loop.getPrivates()[i-offset].getType()));
849+
moduleTranslation.mapValue(loop.getRegion().front().getArgument(i), var);
850+
}
851+
}
852+
}
853+
836854
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
837855
static LogicalResult
838856
convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -861,6 +879,8 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,
861879
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
862880
findAllocaInsertPoint(builder, moduleTranslation);
863881

882+
allocPrivatizationVars(loop, builder, moduleTranslation, allocaIP);
883+
864884
SmallVector<llvm::Value *> privateReductionVariables;
865885
DenseMap<Value, llvm::Value *> reductionVariableMap;
866886
allocReductionVars(loop, builder, moduleTranslation, allocaIP, reductionDecls,

mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ func.func @wsloop(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4:
7979
// CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> ()
8080
"test.payload"(%arg6, %arg7) : (index, index) -> ()
8181
omp.yield
82-
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
82+
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0, 0>} : (index, index, index, index, index, index) -> ()
8383
omp.terminator
8484
}
8585
return

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,39 +141,39 @@ func.func @omp_wsloop(%lb : index, %ub : index, %step : index, %data_var : memre
141141
"omp.wsloop" (%lb, %ub, %step) ({
142142
^bb0(%iv: index):
143143
omp.yield
144-
}) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, ordered_val = 1} :
144+
}) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, ordered_val = 1} :
145145
(index, index, index) -> ()
146146

147147
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(static)
148148
// CHECK-SAME: for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
149149
"omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var) ({
150150
^bb0(%iv: index):
151151
omp.yield
152-
}) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0>, schedule_val = #omp<schedulekind static>} :
152+
}) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,0>, schedule_val = #omp<schedulekind static>} :
153153
(index, index, index, memref<i32>, i32) -> ()
154154

155155
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>, %{{.*}} = %{{.*}} : memref<i32>) schedule(static)
156156
// CHECK-SAME: for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
157157
"omp.wsloop" (%lb, %ub, %step, %data_var, %data_var, %linear_var, %linear_var) ({
158158
^bb0(%iv: index):
159159
omp.yield
160-
}) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0>, schedule_val = #omp<schedulekind static>} :
160+
}) {operandSegmentSizes = array<i32: 1,1,1,2,2,0,0,0>, schedule_val = #omp<schedulekind static>} :
161161
(index, index, index, memref<i32>, memref<i32>, i32, i32) -> ()
162162

163163
// CHECK: omp.wsloop linear(%{{.*}} = %{{.*}} : memref<i32>) schedule(dynamic = %{{.*}}) ordered(2)
164164
// CHECK-SAME: for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
165165
"omp.wsloop" (%lb, %ub, %step, %data_var, %linear_var, %chunk_var) ({
166166
^bb0(%iv: index):
167167
omp.yield
168-
}) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
168+
}) {operandSegmentSizes = array<i32: 1,1,1,1,1,0,0,1>, schedule_val = #omp<schedulekind dynamic>, ordered_val = 2} :
169169
(index, index, index, memref<i32>, i32, i32) -> ()
170170

171171
// CHECK: omp.wsloop schedule(auto) nowait
172172
// CHECK-SAME: for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
173173
"omp.wsloop" (%lb, %ub, %step) ({
174174
^bb0(%iv: index):
175175
omp.yield
176-
}) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
176+
}) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0,0>, nowait, schedule_val = #omp<schedulekind auto>} :
177177
(index, index, index) -> ()
178178

179179
return

mlir/test/Target/LLVMIR/openmp-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) {
310310
llvm.store %3, %4 : f32, !llvm.ptr
311311
omp.yield
312312
// CHECK: call void @__kmpc_for_static_fini(ptr @[[$loc_struct]],
313-
}) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
313+
}) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
314314
omp.terminator
315315
}
316316
llvm.return
@@ -330,7 +330,7 @@ llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) {
330330
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
331331
llvm.store %3, %4 : f32, !llvm.ptr
332332
omp.yield
333-
}) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
333+
}) {operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
334334
llvm.return
335335
}
336336

@@ -348,7 +348,7 @@ llvm.func @wsloop_inclusive_2(%arg0: !llvm.ptr) {
348348
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr, i64) -> !llvm.ptr, f32
349349
llvm.store %3, %4 : f32, !llvm.ptr
350350
omp.yield
351-
}) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
351+
}) {inclusive, operandSegmentSizes = array<i32: 1, 1, 1, 0, 0, 0, 0, 0>} : (i64, i64, i64) -> ()
352352
llvm.return
353353
}
354354

0 commit comments

Comments
 (0)