Skip to content

Commit c83bdc7

Browse files
authored
[MLIR][OpenMP] Normalize lowering of omp.loop_nest (#127217)
This patch refactors the translation of `omp.loop_nest` operations into LLVM IR so that it is handled similarly to other operations. Before this change, the responsibility of translating the loop nest fell into each loop wrapper, causing code duplication. This patch centralizes that handling of the loop. One consequence of this was fixing an issue lowering non-inclusive `omp.simd` loops. As a result, it is now expected that the handling of composite constructs is performed collaboratively among translating functions for each operation involved. At the moment, only `do/for simd` is supported by ignoring SIMD information, and this behavior is preserved. The translation of loop wrapper operations needs access to the `llvm::CanonicalLoopInfo` loop information structure in order to apply transformations to it. This is now created in the nested call to `convertOmpLoopNest`, so it needs to be passed up to all associated loop wrapper translation functions. This is done via the creation of an `OpenMPLoopInfoStackFrame` within `convertHostOrTargetOperation`, associated to the outermost loop wrapper. This structure is updated by `convertOmpLoopNest`, making the result available to all loop wrappers after their body has been translated.
1 parent 72768d9 commit c83bdc7

File tree

4 files changed

+285
-285
lines changed

4 files changed

+285
-285
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,12 @@ class ModuleTranslation {
290290
/// Calls `callback` for every ModuleTranslation stack frame of type `T`
291291
/// starting from the top of the stack.
292292
template <typename T>
293-
WalkResult
294-
stackWalk(llvm::function_ref<WalkResult(const T &)> callback) const {
293+
WalkResult stackWalk(llvm::function_ref<WalkResult(T &)> callback) {
295294
static_assert(std::is_base_of<StackFrame, T>::value,
296295
"expected T derived from StackFrame");
297296
if (!callback)
298297
return WalkResult::skip();
299-
for (const std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
298+
for (std::unique_ptr<StackFrame> &frame : llvm::reverse(stack)) {
300299
if (T *ptr = dyn_cast_or_null<T>(frame.get())) {
301300
WalkResult result = callback(*ptr);
302301
if (result.wasInterrupted())

0 commit comments

Comments
 (0)