@@ -1908,8 +1908,8 @@ LogicalResult TargetOp::verifyRegions() {
1908
1908
return emitError (" target containing multiple 'omp.teams' nested ops" );
1909
1909
1910
1910
// Check that host_eval values are only used in legal ways.
1911
- llvm::omp::OMPTgtExecModeFlags execFlags =
1912
- getKernelExecFlags (getInnermostCapturedOmpOp () );
1911
+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1912
+ TargetRegionFlags execFlags = getKernelExecFlags (capturedOp );
1913
1913
for (Value hostEvalArg :
1914
1914
cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1915
1915
for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1924,7 +1924,8 @@ LogicalResult TargetOp::verifyRegions() {
1924
1924
" and 'thread_limit' in 'omp.teams'" ;
1925
1925
}
1926
1926
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1927
- if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1927
+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::spmd) &&
1928
+ parallelOp->isAncestor (capturedOp) &&
1928
1929
hostEvalArg == parallelOp.getNumThreads ())
1929
1930
continue ;
1930
1931
@@ -1933,15 +1934,16 @@ LogicalResult TargetOp::verifyRegions() {
1933
1934
" 'omp.parallel' when representing target SPMD" ;
1934
1935
}
1935
1936
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1936
- if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1937
+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
1938
+ loopNestOp.getOperation () == capturedOp &&
1937
1939
(llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1938
1940
llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1939
1941
llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1940
1942
continue ;
1941
1943
1942
1944
return emitOpError () << " host_eval argument only legal as loop bounds "
1943
- " and steps in 'omp.loop_nest' when "
1944
- " representing target SPMD or Generic-SPMD " ;
1945
+ " and steps in 'omp.loop_nest' when trip count "
1946
+ " must be evaluated in the host " ;
1945
1947
}
1946
1948
1947
1949
return emitOpError () << " host_eval argument illegal use in '"
@@ -1951,42 +1953,21 @@ LogicalResult TargetOp::verifyRegions() {
1951
1953
return success ();
1952
1954
}
1953
1955
1954
- // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1955
- // / effects, but don't include a memory write effect.
1956
- static bool siblingAllowedInCapture (Operation *op) {
1957
- if (!op)
1958
- return false ;
1956
+ static Operation *
1957
+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1958
+ llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1959
+ assert (rootOp && " expected valid operation" );
1959
1960
1960
- bool isOmpDialect =
1961
- op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1962
- op->getDialect ();
1963
-
1964
- if (isOmpDialect)
1965
- return op->hasTrait <OpTrait::IsTerminator>();
1966
-
1967
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1968
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1969
- memOp.getEffects (effects);
1970
- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1971
- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1972
- isa<SideEffects::AutomaticAllocationScopeResource>(
1973
- effect.getResource ());
1974
- });
1975
- }
1976
- return true ;
1977
- }
1978
-
1979
- Operation *TargetOp::getInnermostCapturedOmpOp () {
1980
- Dialect *ompDialect = (*this )->getDialect ();
1961
+ Dialect *ompDialect = rootOp->getDialect ();
1981
1962
Operation *capturedOp = nullptr ;
1982
1963
DominanceInfo domInfo;
1983
1964
1984
1965
// Process in pre-order to check operations from outermost to innermost,
1985
1966
// ensuring we only enter the region of an operation if it meets the criteria
1986
1967
// for being captured. We stop the exploration of nested operations as soon as
1987
1968
// we process a region holding no operations to be captured.
1988
- walk<WalkOrder::PreOrder>([&](Operation *op) {
1989
- if (op == * this )
1969
+ rootOp-> walk <WalkOrder::PreOrder>([&](Operation *op) {
1970
+ if (op == rootOp )
1990
1971
return WalkResult::advance ();
1991
1972
1992
1973
// Ignore operations of other dialects or omp operations with no regions,
@@ -2001,22 +1982,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2001
1982
// (i.e. its block's successors can reach it) or if it's not guaranteed to
2002
1983
// be executed before all exits of the region (i.e. it doesn't dominate all
2003
1984
// blocks with no successors reachable from the entry block).
2004
- Region *parentRegion = op->getParentRegion ();
2005
- Block *parentBlock = op->getBlock ();
2006
-
2007
- for (Block *successor : parentBlock->getSuccessors ())
2008
- if (successor->isReachable (parentBlock))
2009
- return WalkResult::interrupt ();
2010
-
2011
- for (Block &block : *parentRegion)
2012
- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2013
- !domInfo.dominates (parentBlock, &block))
2014
- return WalkResult::interrupt ();
1985
+ if (checkSingleMandatoryExec) {
1986
+ Region *parentRegion = op->getParentRegion ();
1987
+ Block *parentBlock = op->getBlock ();
1988
+
1989
+ for (Block *successor : parentBlock->getSuccessors ())
1990
+ if (successor->isReachable (parentBlock))
1991
+ return WalkResult::interrupt ();
1992
+
1993
+ for (Block &block : *parentRegion)
1994
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1995
+ !domInfo.dominates (parentBlock, &block))
1996
+ return WalkResult::interrupt ();
1997
+ }
2015
1998
2016
1999
// Don't capture this op if it has a not-allowed sibling, and stop recursing
2017
2000
// into nested operations.
2018
2001
for (Operation &sibling : op->getParentRegion ()->getOps ())
2019
- if (&sibling != op && !siblingAllowedInCapture (&sibling))
2002
+ if (&sibling != op && !siblingAllowedFn (&sibling))
2020
2003
return WalkResult::interrupt ();
2021
2004
2022
2005
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2029,10 +2012,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2029
2012
return capturedOp;
2030
2013
}
2031
2014
2032
- llvm::omp::OMPTgtExecModeFlags
2033
- TargetOp::getKernelExecFlags (Operation *capturedOp) {
2034
- using namespace llvm ::omp;
2015
+ Operation *TargetOp::getInnermostCapturedOmpOp () { // Delegates to findCapturedOmpOp, rooted at this op, with checkSingleMandatoryExec enabled.
2016
+ auto *ompDialect = getContext ()->getLoadedDialect <omp::OpenMPDialect>();
2017
+
2018
+ // Sibling filter passed to findCapturedOmpOp: an OpenMP-dialect sibling is allowed only if it is a terminator.
2019
+ // Any other sibling is allowed unless it implements MemoryEffectOpInterface and reports a write to AutomaticAllocationScopeResource.
2020
+ return findCapturedOmpOp (
2021
+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2022
+ if (!sibling) // A null sibling is never allowed.
2023
+ return false ;
2024
+
2025
+ if (ompDialect == sibling->getDialect ()) // OpenMP-dialect op: only terminators may sit next to the captured op.
2026
+ return sibling->hasTrait <OpTrait::IsTerminator>();
2027
+
2028
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) { // Op from another dialect that declares its memory effects.
2029
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2030
+ effects;
2031
+ memOp.getEffects (effects);
2032
+ return !llvm::any_of ( // Reject only if some effect is a Write on the automatic allocation scope resource.
2033
+ effects, [&](MemoryEffects::EffectInstance &effect) {
2034
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2035
+ isa<SideEffects::AutomaticAllocationScopeResource>(
2036
+ effect.getResource ());
2037
+ });
2038
+ }
2039
+ return true ; // Ops that do not implement MemoryEffectOpInterface are allowed.
2040
+ });
2041
+ }
2035
2042
2043
+ TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2036
2044
// A non-null captured op is only valid if it resides inside of a TargetOp
2037
2045
// and is the result of calling getInnermostCapturedOmpOp() on it.
2038
2046
TargetOp targetOp =
@@ -2041,60 +2049,94 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
2041
2049
(targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
2042
2050
" unexpected captured op" );
2043
2051
2044
- // Make sure this region is capturing a loop. Otherwise, it's a generic
2045
- // kernel.
2052
+ // If it's not capturing a loop, it's a default target region.
2046
2053
if (!isa_and_present<LoopNestOp>(capturedOp))
2047
- return OMP_TGT_EXEC_MODE_GENERIC ;
2054
+ return TargetRegionFlags::generic ;
2048
2055
2049
- SmallVector<LoopWrapperInterface> wrappers;
2050
- cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
2051
- assert (!wrappers.empty ());
2056
+ // Get the innermost non-simd loop wrapper.
2057
+ SmallVector<LoopWrapperInterface> loopWrappers;
2058
+ cast<LoopNestOp>(capturedOp).gatherWrappers (loopWrappers);
2059
+ assert (!loopWrappers.empty ());
2052
2060
2053
- // Ignore optional SIMD leaf construct.
2054
- auto *innermostWrapper = wrappers.begin ();
2061
+ LoopWrapperInterface *innermostWrapper = loopWrappers.begin ();
2055
2062
if (isa<SimdOp>(innermostWrapper))
2056
2063
innermostWrapper = std::next (innermostWrapper);
2057
2064
2058
- long numWrappers = std::distance (innermostWrapper, wrappers.end ());
2059
-
2060
- // Detect Generic-SPMD: target-teams-distribute[-simd].
2061
- // Detect SPMD: target-teams-loop.
2062
- if (numWrappers == 1 ) {
2063
- if (!isa<DistributeOp, LoopOp>(innermostWrapper))
2064
- return OMP_TGT_EXEC_MODE_GENERIC;
2065
-
2066
- Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2067
- if (!isa_and_present<TeamsOp>(teamsOp))
2068
- return OMP_TGT_EXEC_MODE_GENERIC;
2065
+ auto numWrappers = std::distance (innermostWrapper, loopWrappers.end ());
2066
+ if (numWrappers != 1 && numWrappers != 2 )
2067
+ return TargetRegionFlags::generic;
2069
2068
2070
- if (teamsOp->getParentOp () == targetOp.getOperation ())
2071
- return isa<DistributeOp>(innermostWrapper)
2072
- ? OMP_TGT_EXEC_MODE_GENERIC_SPMD
2073
- : OMP_TGT_EXEC_MODE_SPMD;
2074
- }
2075
-
2076
- // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
2069
+ // Detect target-teams-distribute-parallel-wsloop[-simd].
2077
2070
if (numWrappers == 2 ) {
2078
2071
if (!isa<WsloopOp>(innermostWrapper))
2079
- return OMP_TGT_EXEC_MODE_GENERIC ;
2072
+ return TargetRegionFlags::generic ;
2080
2073
2081
2074
innermostWrapper = std::next (innermostWrapper);
2082
2075
if (!isa<DistributeOp>(innermostWrapper))
2083
- return OMP_TGT_EXEC_MODE_GENERIC ;
2076
+ return TargetRegionFlags::generic ;
2084
2077
2085
2078
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2086
2079
if (!isa_and_present<ParallelOp>(parallelOp))
2087
- return OMP_TGT_EXEC_MODE_GENERIC ;
2080
+ return TargetRegionFlags::generic ;
2088
2081
2089
2082
Operation *teamsOp = parallelOp->getParentOp ();
2090
2083
if (!isa_and_present<TeamsOp>(teamsOp))
2091
- return OMP_TGT_EXEC_MODE_GENERIC ;
2084
+ return TargetRegionFlags::generic ;
2092
2085
2093
2086
if (teamsOp->getParentOp () == targetOp.getOperation ())
2094
- return OMP_TGT_EXEC_MODE_SPMD;
2087
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2088
+ }
2089
+ // Detect target-teams-distribute[-simd] and target-teams-loop.
2090
+ else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2091
+ Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2092
+ if (!isa_and_present<TeamsOp>(teamsOp))
2093
+ return TargetRegionFlags::generic;
2094
+
2095
+ if (teamsOp->getParentOp () != targetOp.getOperation ())
2096
+ return TargetRegionFlags::generic;
2097
+
2098
+ if (isa<LoopOp>(innermostWrapper))
2099
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2100
+
2101
+ // Find single immediately nested captured omp.parallel and add spmd flag
2102
+ // (generic-spmd case).
2103
+ //
2104
+ // TODO: This shouldn't have to be done here, as it is too easy to break.
2105
+ // The openmp-opt pass should be updated to be able to promote kernels like
2106
+ // this from "Generic" to "Generic-SPMD". However, the use of the
2107
+ // `kmpc_distribute_static_loop` family of functions produced by the
2108
+ // OMPIRBuilder for these kernels prevents that from working.
2109
+ Dialect *ompDialect = targetOp->getDialect ();
2110
+ Operation *nestedCapture = findCapturedOmpOp (
2111
+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2112
+ [&](Operation *sibling) {
2113
+ return sibling && (ompDialect != sibling->getDialect () ||
2114
+ sibling->hasTrait <OpTrait::IsTerminator>());
2115
+ });
2116
+
2117
+ TargetRegionFlags result =
2118
+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2119
+
2120
+ if (!nestedCapture)
2121
+ return result;
2122
+
2123
+ while (nestedCapture->getParentOp () != capturedOp)
2124
+ nestedCapture = nestedCapture->getParentOp ();
2125
+
2126
+ return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2127
+ : result;
2128
+ }
2129
+ // Detect target-parallel-wsloop[-simd].
2130
+ else if (isa<WsloopOp>(innermostWrapper)) {
2131
+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2132
+ if (!isa_and_present<ParallelOp>(parallelOp))
2133
+ return TargetRegionFlags::generic;
2134
+
2135
+ if (parallelOp->getParentOp () == targetOp.getOperation ())
2136
+ return TargetRegionFlags::spmd;
2095
2137
}
2096
2138
2097
- return OMP_TGT_EXEC_MODE_GENERIC ;
2139
+ return TargetRegionFlags::generic ;
2098
2140
}
2099
2141
2100
2142
// ===----------------------------------------------------------------------===//
0 commit comments