Skip to content

Commit 53921ac

Browse files
author
git apple-llvm automerger
committed
Merge commit 'f59b5b8d597d' from llvm.org/main into next
2 parents 5d6519a + f59b5b8 commit 53921ac

File tree

7 files changed

+285
-87
lines changed

7 files changed

+285
-87
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -222,6 +222,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
222222

223223
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
224224

225+
//===----------------------------------------------------------------------===//
226+
// target_region_flags enum.
227+
//===----------------------------------------------------------------------===//
228+
229+
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
230+
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
231+
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
232+
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
233+
234+
def TargetRegionFlags : OpenMP_BitEnumAttr<
235+
"TargetRegionFlags",
236+
"target region property flags", [
237+
TargetRegionFlagsNone,
238+
TargetRegionFlagsGeneric,
239+
TargetRegionFlagsSpmd,
240+
TargetRegionFlagsTripCount
241+
]>;
242+
225243
//===----------------------------------------------------------------------===//
226244
// variable_capture_kind enum.
227245
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
13121312
///
13131313
/// \param capturedOp result of a still valid (no modifications made to any
13141314
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1315-
static llvm::omp::OMPTgtExecModeFlags
1315+
static ::mlir::omp::TargetRegionFlags
13161316
getKernelExecFlags(Operation *capturedOp);
13171317
}] # clausesExtraClassDeclaration;
13181318

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 123 additions & 81 deletions
Original file line number | Diff line number | Diff line change
@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
19081908
return emitError("target containing multiple 'omp.teams' nested ops");
19091909

19101910
// Check that host_eval values are only used in legal ways.
1911-
llvm::omp::OMPTgtExecModeFlags execFlags =
1912-
getKernelExecFlags(getInnermostCapturedOmpOp());
1911+
Operation *capturedOp = getInnermostCapturedOmpOp();
1912+
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
19131913
for (Value hostEvalArg :
19141914
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
19151915
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
19241924
"and 'thread_limit' in 'omp.teams'";
19251925
}
19261926
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1927-
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1927+
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1928+
parallelOp->isAncestor(capturedOp) &&
19281929
hostEvalArg == parallelOp.getNumThreads())
19291930
continue;
19301931

@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
19331934
"'omp.parallel' when representing target SPMD";
19341935
}
19351936
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1936-
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1937+
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
1938+
loopNestOp.getOperation() == capturedOp &&
19371939
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
19381940
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
19391941
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
19401942
continue;
19411943

19421944
return emitOpError() << "host_eval argument only legal as loop bounds "
1943-
"and steps in 'omp.loop_nest' when "
1944-
"representing target SPMD or Generic-SPMD";
1945+
"and steps in 'omp.loop_nest' when trip count "
1946+
"must be evaluated in the host";
19451947
}
19461948

19471949
return emitOpError() << "host_eval argument illegal use in '"
@@ -1951,42 +1953,21 @@ LogicalResult TargetOp::verifyRegions() {
19511953
return success();
19521954
}
19531955

1954-
/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
1955-
/// effects, but don't include a memory write effect.
1956-
static bool siblingAllowedInCapture(Operation *op) {
1957-
if (!op)
1958-
return false;
1956+
static Operation *
1957+
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1958+
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
1959+
assert(rootOp && "expected valid operation");
19591960

1960-
bool isOmpDialect =
1961-
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
1962-
op->getDialect();
1963-
1964-
if (isOmpDialect)
1965-
return op->hasTrait<OpTrait::IsTerminator>();
1966-
1967-
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1968-
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
1969-
memOp.getEffects(effects);
1970-
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
1971-
return isa<MemoryEffects::Write>(effect.getEffect()) &&
1972-
isa<SideEffects::AutomaticAllocationScopeResource>(
1973-
effect.getResource());
1974-
});
1975-
}
1976-
return true;
1977-
}
1978-
1979-
Operation *TargetOp::getInnermostCapturedOmpOp() {
1980-
Dialect *ompDialect = (*this)->getDialect();
1961+
Dialect *ompDialect = rootOp->getDialect();
19811962
Operation *capturedOp = nullptr;
19821963
DominanceInfo domInfo;
19831964

19841965
// Process in pre-order to check operations from outermost to innermost,
19851966
// ensuring we only enter the region of an operation if it meets the criteria
19861967
// for being captured. We stop the exploration of nested operations as soon as
19871968
// we process a region holding no operations to be captured.
1988-
walk<WalkOrder::PreOrder>([&](Operation *op) {
1989-
if (op == *this)
1969+
rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
1970+
if (op == rootOp)
19901971
return WalkResult::advance();
19911972

19921973
// Ignore operations of other dialects or omp operations with no regions,
@@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20011982
// (i.e. its block's successors can reach it) or if it's not guaranteed to
20021983
// be executed before all exits of the region (i.e. it doesn't dominate all
20031984
// blocks with no successors reachable from the entry block).
2004-
Region *parentRegion = op->getParentRegion();
2005-
Block *parentBlock = op->getBlock();
2006-
2007-
for (Block *successor : parentBlock->getSuccessors())
2008-
if (successor->isReachable(parentBlock))
2009-
return WalkResult::interrupt();
2010-
2011-
for (Block &block : *parentRegion)
2012-
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2013-
!domInfo.dominates(parentBlock, &block))
2014-
return WalkResult::interrupt();
1985+
if (checkSingleMandatoryExec) {
1986+
Region *parentRegion = op->getParentRegion();
1987+
Block *parentBlock = op->getBlock();
1988+
1989+
for (Block *successor : parentBlock->getSuccessors())
1990+
if (successor->isReachable(parentBlock))
1991+
return WalkResult::interrupt();
1992+
1993+
for (Block &block : *parentRegion)
1994+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1995+
!domInfo.dominates(parentBlock, &block))
1996+
return WalkResult::interrupt();
1997+
}
20151998

20161999
// Don't capture this op if it has a not-allowed sibling, and stop recursing
20172000
// into nested operations.
20182001
for (Operation &sibling : op->getParentRegion()->getOps())
2019-
if (&sibling != op && !siblingAllowedInCapture(&sibling))
2002+
if (&sibling != op && !siblingAllowedFn(&sibling))
20202003
return WalkResult::interrupt();
20212004

20222005
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20292012
return capturedOp;
20302013
}
20312014

2032-
llvm::omp::OMPTgtExecModeFlags
2033-
TargetOp::getKernelExecFlags(Operation *capturedOp) {
2034-
using namespace llvm::omp;
2015+
Operation *TargetOp::getInnermostCapturedOmpOp() {
2016+
auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2017+
2018+
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
2019+
// effects, but don't include a memory write effect.
2020+
return findCapturedOmpOp(
2021+
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2022+
if (!sibling)
2023+
return false;
2024+
2025+
if (ompDialect == sibling->getDialect())
2026+
return sibling->hasTrait<OpTrait::IsTerminator>();
2027+
2028+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2029+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2030+
effects;
2031+
memOp.getEffects(effects);
2032+
return !llvm::any_of(
2033+
effects, [&](MemoryEffects::EffectInstance &effect) {
2034+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2035+
isa<SideEffects::AutomaticAllocationScopeResource>(
2036+
effect.getResource());
2037+
});
2038+
}
2039+
return true;
2040+
});
2041+
}
20352042

2043+
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
20362044
// A non-null captured op is only valid if it resides inside of a TargetOp
20372045
// and is the result of calling getInnermostCapturedOmpOp() on it.
20382046
TargetOp targetOp =
@@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
20412049
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
20422050
"unexpected captured op");
20432051

2044-
// Make sure this region is capturing a loop. Otherwise, it's a generic
2045-
// kernel.
2052+
// If it's not capturing a loop, it's a default target region.
20462053
if (!isa_and_present<LoopNestOp>(capturedOp))
2047-
return OMP_TGT_EXEC_MODE_GENERIC;
2054+
return TargetRegionFlags::generic;
20482055

2049-
SmallVector<LoopWrapperInterface> wrappers;
2050-
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
2051-
assert(!wrappers.empty());
2056+
// Get the innermost non-simd loop wrapper.
2057+
SmallVector<LoopWrapperInterface> loopWrappers;
2058+
cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2059+
assert(!loopWrappers.empty());
20522060

2053-
// Ignore optional SIMD leaf construct.
2054-
auto *innermostWrapper = wrappers.begin();
2061+
LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
20552062
if (isa<SimdOp>(innermostWrapper))
20562063
innermostWrapper = std::next(innermostWrapper);
20572064

2058-
long numWrappers = std::distance(innermostWrapper, wrappers.end());
2059-
2060-
// Detect Generic-SPMD: target-teams-distribute[-simd].
2061-
// Detect SPMD: target-teams-loop.
2062-
if (numWrappers == 1) {
2063-
if (!isa<DistributeOp, LoopOp>(innermostWrapper))
2064-
return OMP_TGT_EXEC_MODE_GENERIC;
2065-
2066-
Operation *teamsOp = (*innermostWrapper)->getParentOp();
2067-
if (!isa_and_present<TeamsOp>(teamsOp))
2068-
return OMP_TGT_EXEC_MODE_GENERIC;
2065+
auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2066+
if (numWrappers != 1 && numWrappers != 2)
2067+
return TargetRegionFlags::generic;
20692068

2070-
if (teamsOp->getParentOp() == targetOp.getOperation())
2071-
return isa<DistributeOp>(innermostWrapper)
2072-
? OMP_TGT_EXEC_MODE_GENERIC_SPMD
2073-
: OMP_TGT_EXEC_MODE_SPMD;
2074-
}
2075-
2076-
// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
2069+
// Detect target-teams-distribute-parallel-wsloop[-simd].
20772070
if (numWrappers == 2) {
20782071
if (!isa<WsloopOp>(innermostWrapper))
2079-
return OMP_TGT_EXEC_MODE_GENERIC;
2072+
return TargetRegionFlags::generic;
20802073

20812074
innermostWrapper = std::next(innermostWrapper);
20822075
if (!isa<DistributeOp>(innermostWrapper))
2083-
return OMP_TGT_EXEC_MODE_GENERIC;
2076+
return TargetRegionFlags::generic;
20842077

20852078
Operation *parallelOp = (*innermostWrapper)->getParentOp();
20862079
if (!isa_and_present<ParallelOp>(parallelOp))
2087-
return OMP_TGT_EXEC_MODE_GENERIC;
2080+
return TargetRegionFlags::generic;
20882081

20892082
Operation *teamsOp = parallelOp->getParentOp();
20902083
if (!isa_and_present<TeamsOp>(teamsOp))
2091-
return OMP_TGT_EXEC_MODE_GENERIC;
2084+
return TargetRegionFlags::generic;
20922085

20932086
if (teamsOp->getParentOp() == targetOp.getOperation())
2094-
return OMP_TGT_EXEC_MODE_SPMD;
2087+
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2088+
}
2089+
// Detect target-teams-distribute[-simd] and target-teams-loop.
2090+
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2091+
Operation *teamsOp = (*innermostWrapper)->getParentOp();
2092+
if (!isa_and_present<TeamsOp>(teamsOp))
2093+
return TargetRegionFlags::generic;
2094+
2095+
if (teamsOp->getParentOp() != targetOp.getOperation())
2096+
return TargetRegionFlags::generic;
2097+
2098+
if (isa<LoopOp>(innermostWrapper))
2099+
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2100+
2101+
// Find single immediately nested captured omp.parallel and add spmd flag
2102+
// (generic-spmd case).
2103+
//
2104+
// TODO: This shouldn't have to be done here, as it is too easy to break.
2105+
// The openmp-opt pass should be updated to be able to promote kernels like
2106+
// this from "Generic" to "Generic-SPMD". However, the use of the
2107+
// `kmpc_distribute_static_loop` family of functions produced by the
2108+
// OMPIRBuilder for these kernels prevents that from working.
2109+
Dialect *ompDialect = targetOp->getDialect();
2110+
Operation *nestedCapture = findCapturedOmpOp(
2111+
capturedOp, /*checkSingleMandatoryExec=*/false,
2112+
[&](Operation *sibling) {
2113+
return sibling && (ompDialect != sibling->getDialect() ||
2114+
sibling->hasTrait<OpTrait::IsTerminator>());
2115+
});
2116+
2117+
TargetRegionFlags result =
2118+
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2119+
2120+
if (!nestedCapture)
2121+
return result;
2122+
2123+
while (nestedCapture->getParentOp() != capturedOp)
2124+
nestedCapture = nestedCapture->getParentOp();
2125+
2126+
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127+
: result;
2128+
}
2129+
// Detect target-parallel-wsloop[-simd].
2130+
else if (isa<WsloopOp>(innermostWrapper)) {
2131+
Operation *parallelOp = (*innermostWrapper)->getParentOp();
2132+
if (!isa_and_present<ParallelOp>(parallelOp))
2133+
return TargetRegionFlags::generic;
2134+
2135+
if (parallelOp->getParentOp() == targetOp.getOperation())
2136+
return TargetRegionFlags::spmd;
20952137
}
20962138

2097-
return OMP_TGT_EXEC_MODE_GENERIC;
2139+
return TargetRegionFlags::generic;
20982140
}
20992141

21002142
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4646,7 +4646,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
46464646
combinedMaxThreadsVal = maxThreadsVal;
46474647

46484648
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
4649-
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
4649+
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
4650+
assert(
4651+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
4652+
omp::TargetRegionFlags::spmd) &&
4653+
"invalid kernel flags");
4654+
attrs.ExecFlags =
4655+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
4656+
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
4657+
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
4658+
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
4659+
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
46504660
attrs.MinTeams = minTeamsVal;
46514661
attrs.MaxTeams.front() = maxTeamsVal;
46524662
attrs.MinThreads = 1;
@@ -4691,8 +4701,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
46914701
if (numThreads)
46924702
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
46934703

4694-
if (targetOp.getKernelExecFlags(capturedOp) !=
4695-
llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
4704+
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
4705+
omp::TargetRegionFlags::trip_count)) {
46964706
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
46974707
attrs.LoopTripCount = nullptr;
46984708

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2320,7 +2320,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
23202320
// -----
23212321

23222322
func.func @omp_target_host_eval_loop1(%x : i32) {
2323-
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
2323+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
23242324
omp.target host_eval(%x -> %arg0 : i32) {
23252325
omp.wsloop {
23262326
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2335,7 +2335,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
23352335
// -----
23362336

23372337
func.func @omp_target_host_eval_loop2(%x : i32) {
2338-
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
2338+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
23392339
omp.target host_eval(%x -> %arg0 : i32) {
23402340
omp.teams {
23412341
^bb0:

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2864,6 +2864,23 @@ func.func @omp_target_host_eval(%x : i32) {
28642864
omp.terminator
28652865
}
28662866

2867+
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
2868+
// CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
2869+
// CHECK: omp.wsloop {
2870+
// CHECK: omp.loop_nest
2871+
omp.target host_eval(%x -> %arg0 : i32) {
2872+
%y = arith.constant 2 : i32
2873+
omp.parallel num_threads(%arg0 : i32) {
2874+
omp.wsloop {
2875+
omp.loop_nest (%iv) : i32 = (%y) to (%y) step (%y) {
2876+
omp.yield
2877+
}
2878+
}
2879+
omp.terminator
2880+
}
2881+
omp.terminator
2882+
}
2883+
28672884
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
28682885
// CHECK: omp.teams {
28692886
// CHECK: omp.distribute {

0 commit comments

Comments
 (0)