Skip to content

Commit 32a884c

Browse files
committed
[mlir] Add translation of omp.wsloop to LLVM IR
Introduce a translation of the OpenMP workshare loop construct to LLVM IR. This is a minimalist version to enable the pipeline and currently only supports the static loop schedule (the default in the specification) on non-collapsed loops. Other features will be added on a per-need basis. Reviewed By: kiranchandramohan Differential Revision: https://reviews.llvm.org/D92055
1 parent 8451d48 commit 32a884c

File tree

4 files changed

+168
-1
lines changed

4 files changed

+168
-1
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def WsLoopOp : OpenMP_Op<"wsloop", [AttrSizedOperandSegments]> {
185185
];
186186

187187
let regions = (region AnyRegion:$region);
188+
189+
let extraClassDeclaration = [{
190+
/// Returns the number of loops in the workshare loop nest.
191+
unsigned getNumLoops() { return lowerBound().size(); }
192+
}];
188193
}
189194

190195
def YieldOp : OpenMP_Op<"yield", [NoSideEffect, ReturnLike, Terminator,

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ class ModuleTranslation {
100100
llvm::BasicBlock &continuationIP,
101101
llvm::IRBuilder<> &builder,
102102
LogicalResult &bodyGenStatus);
103+
virtual LogicalResult convertOmpWsLoop(Operation &opInst,
104+
llvm::IRBuilder<> &builder);
105+
103106
/// Converts the type from MLIR LLVM dialect to LLVM.
104107
llvm::Type *convertType(LLVMType type);
105108

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,126 @@ LogicalResult ModuleTranslation::convertOmpMaster(Operation &opInst,
536536
return success();
537537
}
538538

539+
/// Converts an OpenMP workshare loop into LLVM IR using OpenMPIRBuilder.
540+
LogicalResult ModuleTranslation::convertOmpWsLoop(Operation &opInst,
541+
llvm::IRBuilder<> &builder) {
542+
auto loop = cast<omp::WsLoopOp>(opInst);
543+
// TODO: this should be in the op verifier instead.
544+
if (loop.lowerBound().empty())
545+
return failure();
546+
547+
if (loop.getNumLoops() != 1)
548+
return opInst.emitOpError("collapsed loops not yet supported");
549+
550+
if (loop.schedule_val().hasValue() &&
551+
omp::symbolizeClauseScheduleKind(loop.schedule_val().getValue()) !=
552+
omp::ClauseScheduleKind::Static)
553+
return opInst.emitOpError(
554+
"only static (default) loop schedule is currently supported");
555+
556+
llvm::Function *func = builder.GetInsertBlock()->getParent();
557+
llvm::LLVMContext &llvmContext = llvmModule->getContext();
558+
559+
// Find the loop configuration.
560+
llvm::Value *lowerBound = valueMapping.lookup(loop.lowerBound()[0]);
561+
llvm::Value *upperBound = valueMapping.lookup(loop.upperBound()[0]);
562+
llvm::Value *step = valueMapping.lookup(loop.step()[0]);
563+
llvm::Type *ivType = step->getType();
564+
llvm::Value *chunk = loop.schedule_chunk_var()
565+
? valueMapping[loop.schedule_chunk_var()]
566+
: llvm::ConstantInt::get(ivType, 1);
567+
568+
// Set up the source location value for OpenMP runtime.
569+
llvm::DISubprogram *subprogram =
570+
builder.GetInsertBlock()->getParent()->getSubprogram();
571+
const llvm::DILocation *diLoc =
572+
debugTranslation->translateLoc(opInst.getLoc(), subprogram);
573+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder.saveIP(),
574+
llvm::DebugLoc(diLoc));
575+
576+
// Generator of the canonical loop body. Produces an SESE region of basic
577+
// blocks.
578+
// TODO: support error propagation in OpenMPIRBuilder and use it instead of
579+
// relying on captured variables.
580+
LogicalResult bodyGenStatus = success();
581+
auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
582+
llvm::IRBuilder<>::InsertPointGuard guard(builder);
583+
584+
// Make sure further conversions know about the induction variable.
585+
valueMapping[loop.getRegion().front().getArgument(0)] = iv;
586+
587+
llvm::BasicBlock *entryBlock = ip.getBlock();
588+
llvm::BasicBlock *exitBlock =
589+
entryBlock->splitBasicBlock(ip.getPoint(), "omp.wsloop.exit");
590+
591+
// Convert the body of the loop.
592+
Region &region = loop.region();
593+
for (Block &bb : region) {
594+
llvm::BasicBlock *llvmBB =
595+
llvm::BasicBlock::Create(llvmContext, "omp.wsloop.region", func);
596+
blockMapping[&bb] = llvmBB;
597+
598+
// Retarget the branch of the entry block to the entry block of the
599+
// converted region (regions are single-entry).
600+
if (bb.isEntryBlock()) {
601+
auto *branch = cast<llvm::BranchInst>(entryBlock->getTerminator());
602+
branch->setSuccessor(0, llvmBB);
603+
}
604+
}
605+
606+
// Block conversion creates a new IRBuilder every time so need not bother
607+
// about maintaining the insertion point.
608+
llvm::SetVector<Block *> blocks = topologicalSort(region);
609+
for (Block *bb : blocks) {
610+
if (failed(convertBlock(*bb, bb->isEntryBlock()))) {
611+
bodyGenStatus = failure();
612+
return;
613+
}
614+
615+
// Special handling for `omp.yield` terminators (we may have more than
616+
// one): they return the control to the parent WsLoop operation so replace
617+
// them with the branch to the exit block. We handle this here to avoid
618+
// relying on inter-function communication through the ModuleTranslation
619+
// class to set up the correct insertion point. This is also consistent
620+
// with MLIR's idiom of handling special region terminators in the same
621+
// code that handles the region-owning operation.
622+
if (isa<omp::YieldOp>(bb->getTerminator())) {
623+
llvm::BasicBlock *llvmBB = blockMapping[bb];
624+
builder.SetInsertPoint(llvmBB, llvmBB->end());
625+
builder.CreateBr(exitBlock);
626+
}
627+
}
628+
629+
connectPHINodes(region, valueMapping, blockMapping, branchMapping);
630+
};
631+
632+
// Delegate actual loop construction to the OpenMP IRBuilder.
633+
// TODO: this currently assumes WsLoop is semantically similar to SCF loop,
634+
// i.e. it has a positive step, uses signed integer semantics, and its upper
635+
// bound is not included. Reconsider this code when WsLoop clearly supports
636+
// more cases.
637+
llvm::BasicBlock *insertBlock = builder.GetInsertBlock();
638+
llvm::CanonicalLoopInfo *loopInfo = ompBuilder->createCanonicalLoop(
639+
ompLoc, bodyGen, lowerBound, upperBound, step, /*IsSigned=*/true,
640+
/*InclusiveStop=*/false);
641+
if (failed(bodyGenStatus))
642+
return failure();
643+
644+
// TODO: get the alloca insertion point from the parallel operation builder.
645+
// If we insert the allocas at the top of the current function, they will be passed as
646+
// extra arguments into the function the parallel operation builder outlines.
647+
// Put them at the start of the current block for now.
648+
llvm::OpenMPIRBuilder::InsertPointTy allocaIP(
649+
insertBlock, insertBlock->getFirstInsertionPt());
650+
loopInfo = ompBuilder->createStaticWorkshareLoop(
651+
ompLoc, loopInfo, allocaIP,
652+
!loop.nowait().hasValue() || loop.nowait().getValue(), chunk);
653+
654+
// Continue building IR after the loop.
655+
builder.restoreIP(loopInfo->getAfterIP());
656+
return success();
657+
}
658+
539659
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
540660
/// (including OpenMP runtime calls).
541661
LogicalResult
@@ -577,6 +697,13 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
577697
.Case(
578698
[&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
579699
.Case([&](omp::MasterOp) { return convertOmpMaster(opInst, builder); })
700+
.Case([&](omp::WsLoopOp) { return convertOmpWsLoop(opInst, builder); })
701+
.Case([&](omp::YieldOp op) {
702+
// Yields are loop terminators that can be just omitted. The loop
703+
// structure was created in the function that handles WsLoopOp.
704+
assert(op.getNumOperands() == 0 && "unexpected yield with operands");
705+
return success();
706+
})
580707
.Default([&](Operation *inst) {
581708
return inst->emitError("unsupported OpenMP operation: ")
582709
<< inst->getName();

mlir/test/Target/openmp-llvm.mlir

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
1+
// RUN: mlir-translate -mlir-to-llvmir -split-input-file %s | FileCheck %s
22

33
// CHECK-LABEL: define void @test_stand_alone_directives()
44
llvm.func @test_stand_alone_directives() {
@@ -291,3 +291,35 @@ llvm.func @test_omp_master() -> () {
291291
}
292292
llvm.return
293293
}
294+
295+
// -----
296+
297+
// CHECK: %struct.ident_t = type
298+
// CHECK: @[[$parallel_loc:.*]] = private unnamed_addr constant {{.*}} c";LLVMDialectModule;wsloop_simple;{{[0-9]+}};{{[0-9]+}};;\00"
299+
// CHECK: @[[$parallel_loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$parallel_loc]], {{.*}}
300+
301+
// CHECK: @[[$wsloop_loc:.*]] = private unnamed_addr constant {{.*}} c";LLVMDialectModule;wsloop_simple;{{[0-9]+}};{{[0-9]+}};;\00"
302+
// CHECK: @[[$wsloop_loc_struct:.*]] = private unnamed_addr constant %struct.ident_t {{.*}} @[[$wsloop_loc]], {{.*}}
303+
304+
// CHECK-LABEL: @wsloop_simple
305+
llvm.func @wsloop_simple(%arg0: !llvm.ptr<float>) {
306+
%0 = llvm.mlir.constant(42 : index) : !llvm.i64
307+
%1 = llvm.mlir.constant(10 : index) : !llvm.i64
308+
%2 = llvm.mlir.constant(1 : index) : !llvm.i64
309+
omp.parallel {
310+
"omp.wsloop"(%1, %0, %2) ( {
311+
^bb0(%arg1: !llvm.i64):
312+
// The form of the emitted IR is controlled by OpenMPIRBuilder and
313+
// tested there. Just check that the right functions are called.
314+
// CHECK: call i32 @__kmpc_global_thread_num
315+
// CHECK: call void @__kmpc_for_static_init_{{.*}}(%struct.ident_t* @[[$wsloop_loc_struct]],
316+
%3 = llvm.mlir.constant(2.000000e+00 : f32) : !llvm.float
317+
%4 = llvm.getelementptr %arg0[%arg1] : (!llvm.ptr<float>, !llvm.i64) -> !llvm.ptr<float>
318+
llvm.store %3, %4 : !llvm.ptr<float>
319+
omp.yield
320+
// CHECK: call void @__kmpc_for_static_fini(%struct.ident_t* @[[$wsloop_loc_struct]],
321+
}) {operand_segment_sizes = dense<[1, 1, 1, 0, 0, 0, 0, 0, 0]> : vector<9xi32>} : (!llvm.i64, !llvm.i64, !llvm.i64) -> ()
322+
omp.terminator
323+
}
324+
llvm.return
325+
}

0 commit comments

Comments
 (0)