Skip to content

[MLIR][OpenMP] Simplify OpenMP to LLVM dialect conversion #132009

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 20, 2025

Conversation

skatrak
Copy link
Member

@skatrak skatrak commented Mar 19, 2025

This patch makes a few changes to unify the conversion process from the 'omp' to the 'llvm' dialect. The main goal of this change is to consolidate the logic used to identify legal and illegal ops, and to consolidate the conversion logic into a single class.

Changes introduced are the following:

  • Removal of getNumVariableOperands() and getVariableOperand() extra class declarations from OpenMP operations. These are redundant, as they are equivalent to mlir::Operation::getNumOperands() and mlir::Operation::getOperands(), respectively.
  • Consolidation of RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionOpWithVarOperandsConversion, RegionLessOpConversion, AtomicReadOpConversion, MapInfoOpConversion, DeclMapperOpConversion and MultiRegionOpConversion into a single OpenMPOpConversion class. This is possible because all of the previous were doing parts of the same set of operations based on whether they defined any regions, whether they took operands, type attributes, etc.
  • Update of mlir::configureOpenMPToLLVMConversionLegality to use a single generic set of checks for all operations, removing the need to list every operation manually.

This patch makes a few changes to unify the conversion process from the 'omp'
to the 'llvm' dialect. The main goal of this change is to consolidate the logic
used to identify legal and illegal ops, and to consolidate the conversion logic
into a single class.

Changes introduced are the following:
  - Removal of `getNumVariableOperands()` and `getVariableOperand()` extra
class declarations from OpenMP operations. These are redundant, as they are
equivalent to `mlir::Operation::getNumOperands()` and
`mlir::Operation::getOperands()`, respectively.
  - Consolidation of `RegionOpConversion`,
`RegionLessOpWithVarOperandsConversion`, `RegionOpWithVarOperandsConversion`,
`RegionLessOpConversion`, `AtomicReadOpConversion`, `MapInfoOpConversion`,
`DeclMapperOpConversion` and `MultiRegionOpConversion` into a single
`OpenMPOpConversion` class. This is possible because all of the previous were
doing parts of the same set of operations based on whether they defined any
regions, whether they took operands, type attributes, etc.
  - Update of `mlir::configureOpenMPToLLVMConversionLegality` to use a single
generic set of checks for all operations, removing the need to list every
operation manually.
@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2025

@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-flang-openmp

Author: Sergio Afonso (skatrak)

Changes

This patch makes a few changes to unify the conversion process from the 'omp' to the 'llvm' dialect. The main goal of this change is to consolidate the logic used to identify legal and illegal ops, and to consolidate the conversion logic into a single class.

Changes introduced are the following:

  • Removal of getNumVariableOperands() and getVariableOperand() extra class declarations from OpenMP operations. These are redundant, as they are equivalent to mlir::Operation::getNumOperands() and mlir::Operation::getOperands(), respectively.
  • Consolidation of RegionOpConversion, RegionLessOpWithVarOperandsConversion, RegionOpWithVarOperandsConversion, RegionLessOpConversion, AtomicReadOpConversion, MapInfoOpConversion, DeclMapperOpConversion and MultiRegionOpConversion into a single OpenMPOpConversion class. This is possible because all of the previous were doing parts of the same set of operations based on whether they defined any regions, whether they took operands, type attributes, etc.
  • Update of mlir::configureOpenMPToLLVMConversionLegality to use a single generic set of checks for all operations, removing the need to list every operation manually.

Patch is 23.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132009.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (-92)
  • (modified) mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp (+95-250)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 2c2ecdd225f4a..63b38835d133f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -891,17 +891,6 @@ def FlushOp : OpenMP_Op<"flush", clauses = [
 
   // Override inherited assembly format to include `varList`.
   let assemblyFormat = "( `(` $varList^ `:` type($varList) `)` )? attr-dict";
-
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      return getOperation()->getNumOperands();
-    }
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      return getOperand(i);
-    }
-  }] # clausesExtraClassDeclaration;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1001,18 +990,6 @@ def MapBoundsOp : OpenMP_Op<"map.bounds",
     ) attr-dict
   }];
 
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      return getNumOperands();
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      return getOperands()[i];
-    }
-  }];
-
   let hasVerifier = 1;
 }
 
@@ -1098,18 +1075,6 @@ def MapInfoOp : OpenMP_Op<"map.info", [AttrSizedOperandSegments]> {
       | `bounds` `(` $bounds `)`
     ) `->` type($omp_ptr) attr-dict
   }];
-
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      return getNumOperands();
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      return getOperands()[i];
-    }
-  }];
 }
 
 //===---------------------------------------------------------------------===//
@@ -1515,21 +1480,6 @@ def AtomicReadOp : OpenMP_Op<"atomic.read", traits = [
     clausesOptAssemblyFormat #
     ") `:` type($v) `,` type($x) `,` $element_type attr-dict";
 
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      assert(getX() && "expected 'x' operand");
-      assert(getV() && "expected 'v' operand");
-      return 2;
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      assert(i < 2 && "invalid index position for an operand");
-      return i == 0 ? getX() : getV();
-    }
-  }] # clausesExtraClassDeclaration;
-
   let hasVerifier = 1;
 }
 
@@ -1555,21 +1505,6 @@ def AtomicWriteOp : OpenMP_Op<"atomic.write", traits = [
   let assemblyFormat = "$x `=` $expr" # clausesReqAssemblyFormat # " oilist(" #
     clausesOptAssemblyFormat # ") `:` type($x) `,` type($expr) attr-dict";
 
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      assert(getX() && "expected address operand");
-      assert(getExpr() && "expected value operand");
-      return 2;
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      assert(i < 2 && "invalid index position for an operand");
-      return i == 0 ? getX() : getExpr();
-    }
-  }] # clausesExtraClassDeclaration;
-
   let hasVerifier = 1;
 }
 
@@ -1614,20 +1549,6 @@ def AtomicUpdateOp : OpenMP_Op<"atomic.update", traits = [
   let assemblyFormat = clausesAssemblyFormat #
     "$x `:` type($x) $region attr-dict";
 
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      assert(getX() && "expected 'x' operand");
-      return 1;
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      assert(i == 0 && "invalid index position for an operand");
-      return getX();
-    }
-  }] # clausesExtraClassDeclaration;
-
   let hasVerifier = 1;
   let hasRegionVerifier = 1;
   let hasCanonicalizeMethod = 1;
@@ -1715,19 +1636,6 @@ def ThreadprivateOp : OpenMP_Op<"threadprivate",
   let assemblyFormat = [{
     $sym_addr `:` type($sym_addr) `->` type($tls_addr) attr-dict
   }];
-  let extraClassDeclaration = [{
-    /// The number of variable operands.
-    unsigned getNumVariableOperands() {
-      assert(getSymAddr() && "expected one variable operand");
-      return 1;
-    }
-
-    /// The i-th variable operand passed.
-    Value getVariableOperand(unsigned i) {
-      assert(i == 0 && "invalid index position for an operand");
-      return getSymAddr();
-    }
-  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
index 7888745dc6920..218260bd53a1e 100644
--- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
+++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -28,262 +28,101 @@ namespace mlir {
 using namespace mlir;
 
 namespace {
-/// A pattern that converts the region arguments in a single-region OpenMP
-/// operation to the LLVM dialect. The body of the region is not modified and is
-/// expected to either be processed by the conversion infrastructure or already
-/// contain ops compatible with LLVM dialect types.
-template <typename OpType>
-struct RegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
-  using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
 
-  LogicalResult
-  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.create<OpType>(
-        curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
-    rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
-                                           *this->getTypeConverter())))
-      return failure();
-
-    rewriter.eraseOp(curOp);
-    return success();
-  }
-};
-
-template <typename T>
-struct RegionLessOpWithVarOperandsConversion
-    : public ConvertOpToLLVMPattern<T> {
+/// A pattern that converts the result and operand types, attributes, and region
+/// arguments of an OpenMP operation to the LLVM dialect.
+///
+/// Attributes are copied verbatim by default, and only translated if they are
+/// type attributes.
+///
+/// Region bodies, if any, are not modified and expected to either be processed
+/// by the conversion infrastructure or already contain ops compatible with LLVM
+/// dialect types.
+template <typename T, bool SupportsMemRefOperand = true>
+struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
   using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-    SmallVector<Type> resTypes;
-    if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
-      return failure();
-    SmallVector<Value> convertedOperands;
-    assert(curOp.getNumVariableOperands() ==
-               curOp.getOperation()->getNumOperands() &&
-           "unexpected non-variable operands");
-    for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
-      Value originalVariableOperand = curOp.getVariableOperand(idx);
-      if (!originalVariableOperand)
-        return failure();
-      if (isa<MemRefType>(originalVariableOperand.getType())) {
-        // TODO: Support memref type in variable operands
-        return rewriter.notifyMatchFailure(curOp,
-                                           "memref is not supported yet");
-      }
-      convertedOperands.emplace_back(adaptor.getOperands()[idx]);
-    }
 
-    rewriter.replaceOpWithNewOp<T>(curOp, resTypes, convertedOperands,
-                                   curOp->getAttrs());
-    return success();
-  }
-};
-
-template <typename T>
-struct RegionOpWithVarOperandsConversion : public ConvertOpToLLVMPattern<T> {
-  using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(T curOp, typename T::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-    SmallVector<Type> resTypes;
-    if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
-      return failure();
-    SmallVector<Value> convertedOperands;
-    assert(curOp.getNumVariableOperands() ==
-               curOp.getOperation()->getNumOperands() &&
-           "unexpected non-variable operands");
-    for (unsigned idx = 0; idx < curOp.getNumVariableOperands(); ++idx) {
-      Value originalVariableOperand = curOp.getVariableOperand(idx);
-      if (!originalVariableOperand)
-        return failure();
-      if (isa<MemRefType>(originalVariableOperand.getType())) {
-        // TODO: Support memref type in variable operands
-        return rewriter.notifyMatchFailure(curOp,
-                                           "memref is not supported yet");
-      }
-      convertedOperands.emplace_back(adaptor.getOperands()[idx]);
-    }
-    auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
-                                    curOp->getAttrs());
-    rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
-                                           *this->getTypeConverter())))
-      return failure();
-
-    rewriter.eraseOp(curOp);
-    return success();
-  }
-};
-
-template <typename T>
-struct RegionLessOpConversion : public ConvertOpToLLVMPattern<T> {
-  using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
   LogicalResult
   matchAndRewrite(T curOp, typename T::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    // Translate result types.
     const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
     SmallVector<Type> resTypes;
     if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
       return failure();
 
-    rewriter.replaceOpWithNewOp<T>(curOp, resTypes, adaptor.getOperands(),
-                                   curOp->getAttrs());
-    return success();
-  }
-};
-
-struct AtomicReadOpConversion
-    : public ConvertOpToLLVMPattern<omp::AtomicReadOp> {
-  using ConvertOpToLLVMPattern<omp::AtomicReadOp>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(omp::AtomicReadOp curOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-    Type curElementType = curOp.getElementType();
-    auto newOp = rewriter.create<omp::AtomicReadOp>(
-        curOp.getLoc(), TypeRange(), adaptor.getOperands(), curOp->getAttrs());
-    TypeAttr typeAttr = TypeAttr::get(converter->convertType(curElementType));
-    newOp.setElementTypeAttr(typeAttr);
-    rewriter.eraseOp(curOp);
-    return success();
-  }
-};
-
-struct MapInfoOpConversion : public ConvertOpToLLVMPattern<omp::MapInfoOp> {
-  using ConvertOpToLLVMPattern<omp::MapInfoOp>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(omp::MapInfoOp curOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-
-    SmallVector<Type> resTypes;
-    if (failed(converter->convertTypes(curOp->getResultTypes(), resTypes)))
-      return failure();
-
-    // Copy attributes of the curOp except for the typeAttr which should
-    // be converted
-    SmallVector<NamedAttribute> newAttrs;
+    // Translate type attributes.
+    // They are kept unmodified except if they are type attributes.
+    SmallVector<NamedAttribute> convertedAttrs;
     for (NamedAttribute attr : curOp->getAttrs()) {
       if (auto typeAttr = dyn_cast<TypeAttr>(attr.getValue())) {
-        Type newAttr = converter->convertType(typeAttr.getValue());
-        newAttrs.emplace_back(attr.getName(), TypeAttr::get(newAttr));
+        Type convertedType = converter->convertType(typeAttr.getValue());
+        convertedAttrs.emplace_back(attr.getName(),
+                                    TypeAttr::get(convertedType));
       } else {
-        newAttrs.push_back(attr);
+        convertedAttrs.push_back(attr);
       }
     }
 
-    rewriter.replaceOpWithNewOp<omp::MapInfoOp>(
-        curOp, resTypes, adaptor.getOperands(), newAttrs);
-    return success();
-  }
-};
-
-struct DeclMapperOpConversion
-    : public ConvertOpToLLVMPattern<omp::DeclareMapperOp> {
-  using ConvertOpToLLVMPattern<omp::DeclareMapperOp>::ConvertOpToLLVMPattern;
-  LogicalResult
-  matchAndRewrite(omp::DeclareMapperOp curOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    const TypeConverter *converter = ConvertToLLVMPattern::getTypeConverter();
-    SmallVector<NamedAttribute> newAttrs;
-    newAttrs.emplace_back(curOp.getSymNameAttrName(), curOp.getSymNameAttr());
-    newAttrs.emplace_back(
-        curOp.getTypeAttrName(),
-        TypeAttr::get(converter->convertType(curOp.getType())));
-
-    auto newOp = rewriter.create<omp::DeclareMapperOp>(
-        curOp.getLoc(), TypeRange(), adaptor.getOperands(), newAttrs);
-    rewriter.inlineRegionBefore(curOp.getRegion(), newOp.getRegion(),
-                                newOp.getRegion().end());
-    if (failed(rewriter.convertRegionTypes(&newOp.getRegion(),
-                                           *this->getTypeConverter())))
-      return failure();
-
-    rewriter.eraseOp(curOp);
-    return success();
-  }
-};
-
-template <typename OpType>
-struct MultiRegionOpConversion : public ConvertOpToLLVMPattern<OpType> {
-  using ConvertOpToLLVMPattern<OpType>::ConvertOpToLLVMPattern;
-
-  void forwardOpAttrs(OpType curOp, OpType newOp) const {}
+    // Translate operands.
+    SmallVector<Value> convertedOperands;
+    convertedOperands.reserve(curOp->getNumOperands());
+    for (auto [originalOperand, convertedOperand] :
+         llvm::zip_equal(curOp->getOperands(), adaptor.getOperands())) {
+      if (!originalOperand)
+        return failure();
 
-  LogicalResult
-  matchAndRewrite(OpType curOp, typename OpType::Adaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    auto newOp = rewriter.create<OpType>(
-        curOp.getLoc(), TypeRange(), curOp.getSymNameAttr(),
-        TypeAttr::get(this->getTypeConverter()->convertType(
-            curOp.getTypeAttr().getValue())));
-    forwardOpAttrs(curOp, newOp);
+      if constexpr (!SupportsMemRefOperand) {
+        if (isa<MemRefType>(originalOperand.getType())) {
+          // TODO: Support memref type in variable operands
+          return rewriter.notifyMatchFailure(curOp,
+                                             "memref is not supported yet");
+        }
+      }
+      convertedOperands.push_back(convertedOperand);
+    }
 
-    for (unsigned idx = 0; idx < curOp.getNumRegions(); idx++) {
-      rewriter.inlineRegionBefore(curOp.getRegion(idx), newOp.getRegion(idx),
-                                  newOp.getRegion(idx).end());
-      if (failed(rewriter.convertRegionTypes(&newOp.getRegion(idx),
+    // Create new operation.
+    auto newOp = rewriter.create<T>(curOp.getLoc(), resTypes, convertedOperands,
+                                    convertedAttrs);
+
+    // Translate regions.
+    for (auto [originalRegion, convertedRegion] :
+         llvm::zip_equal(curOp->getRegions(), newOp->getRegions())) {
+      rewriter.inlineRegionBefore(originalRegion, convertedRegion,
+                                  convertedRegion.end());
+      if (failed(rewriter.convertRegionTypes(&convertedRegion,
                                              *this->getTypeConverter())))
         return failure();
     }
 
-    rewriter.eraseOp(curOp);
+    // Delete old operation and replace result uses with those of the new one.
+    rewriter.replaceOp(curOp, newOp->getResults());
     return success();
   }
 };
 
-template <>
-void MultiRegionOpConversion<omp::PrivateClauseOp>::forwardOpAttrs(
-    omp::PrivateClauseOp curOp, omp::PrivateClauseOp newOp) const {
-  newOp.setDataSharingType(curOp.getDataSharingType());
-}
 } // namespace
 
 void mlir::configureOpenMPToLLVMConversionLegality(
     ConversionTarget &target, const LLVMTypeConverter &typeConverter) {
   target.addDynamicallyLegalOp<
-      omp::AtomicReadOp, omp::AtomicWriteOp, omp::CancellationPointOp,
-      omp::CancelOp, omp::CriticalDeclareOp, omp::DeclareMapperInfoOp,
-      omp::FlushOp, omp::MapBoundsOp, omp::MapInfoOp, omp::OrderedOp,
-      omp::ScanOp, omp::TargetEnterDataOp, omp::TargetExitDataOp,
-      omp::TargetUpdateOp, omp::ThreadprivateOp, omp::YieldOp>(
-      [&](Operation *op) {
-        return typeConverter.isLegal(op->getOperandTypes()) &&
-               typeConverter.isLegal(op->getResultTypes());
-      });
-  target.addDynamicallyLegalOp<
-      omp::AtomicUpdateOp, omp::CriticalOp, omp::DeclareMapperOp,
-      omp::DeclareReductionOp, omp::DistributeOp, omp::LoopNestOp, omp::LoopOp,
-      omp::MasterOp, omp::OrderedRegionOp, omp::ParallelOp,
-      omp::PrivateClauseOp, omp::SectionOp, omp::SectionsOp, omp::SimdOp,
-      omp::SingleOp, omp::TargetDataOp, omp::TargetOp, omp::TaskgroupOp,
-      omp::TaskloopOp, omp::TaskOp, omp::TeamsOp,
-      omp::WsloopOp>([&](Operation *op) {
-    return std::all_of(op->getRegions().begin(), op->getRegions().end(),
+#define GET_OP_LIST
+#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
+      >([&](Operation *op) {
+    return typeConverter.isLegal(op->getOperandTypes()) &&
+           typeConverter.isLegal(op->getResultTypes()) &&
+           std::all_of(op->getRegions().begin(), op->getRegions().end(),
                        [&](Region &region) {
                          return typeConverter.isLegal(&region);
                        }) &&
-           typeConverter.isLegal(op->getOperandTypes()) &&
-           typeConverter.isLegal(op->getResultTypes());
+           std::all_of(op->getAttrs().begin(), op->getAttrs().end(),
+                       [&](NamedAttribute attr) {
+                         auto typeAttr = dyn_cast<TypeAttr>(attr.getValue());
+                         return !typeAttr ||
+                                typeConverter.isLegal(typeAttr.getValue());
+                       });
   });
-  target.addDynamicallyLegalOp<omp::PrivateClauseOp>(
-      [&](omp::PrivateClauseOp op) -> bool {
-        return std::all_of(op->getRegions().begin(), op->getRegions().end(),
-                           [&](Region &region) {
-                             return typeConverter.isLegal(&region);
-                           }) &&
-               typeConverter.isLegal(op->getOperandTypes()) &&
-               typeConverter.isLegal(op->getResultTypes()) &&
-               typeConverter.isLegal(op.getType());
-      });
 }
 
 void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
@@ -295,36 +134,42 @@ void mlir::populateOpenMPToLLVMConversionPatterns(LLVMTypeConverter &converter,
       [&](omp::MapBoundsType type) -> Type { return type; });
 
   patterns.add<
-      AtomicReadOpConversion, DeclMapperOpConversion, MapInfoOpConversion,
-      MultiRegionOpConversion<omp::DeclareReductionOp>,
-      MultiRegionOpConversion<omp::PrivateClauseOp>,
-      RegionLessOpConversion<omp::CancellationPointOp>,
-      RegionLessOpConversion<omp::CancelOp>,
-      RegionLessOpConversion<omp::CriticalDeclareOp>,
-      RegionLessOpConversion<omp::DeclareMapperInfoOp>,
-      RegionLessOpConversion<omp::OrderedOp>,
-      RegionLessOpConversion<omp::ScanOp>,
-      RegionLessOpConversion<omp::TargetEnterDataOp>,
-      RegionLessOpConversion<omp::TargetExitDataOp>,
-      RegionLessOpConversion<omp::TargetUpdateOp>,
-      RegionLessOpConversion<omp::YieldOp>,
-      RegionLessOpWithVarOperandsConversion<omp::AtomicWriteOp>,
-      RegionLessOpWithVarOperandsConversion<omp::FlushOp>,
-      RegionLessOpWithVarOperandsConversion<omp::MapBoundsOp>,
-      RegionLessOpWithVarOperands...
[truncated]

@skatrak
Copy link
Member Author

skatrak commented Mar 19, 2025

One compromise I made for this initial implementation was adding the SupportsMemRefOperand template parameter to the OpenMPOpConversion class, which is only set to false for AtomicUpdateOp, AtomicWriteOp, FlushOp, MapBoundsOp and ThreadprivateOp.

This mirrors the previous behavior, where conversion would fail if an operand of MemRefType for the operation was found, and this was only checked for that (seemingly arbitrary) set of operations. I would like to propose not checking that in this conversion pass at all and simplify things further. If the types memrefs are converted to are an issue, it seems like we could trigger errors later during OpenMP to LLVM IR translation.

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really nice cleanup. LGTM. Do I understand correctly that this is not NFC because in some situations attributes were not forwarded previously but now they will be?

@skatrak
Copy link
Member Author

skatrak commented Mar 19, 2025

Do I understand correctly that this is not NFC because in some situations attributes were not forwarded previously but now they will be?

Thank you for the quick review! In general this will look at operands, results, attributes and entry block arguments of all operations, so it's very likely that it plugs some holes on different places. I also completed the list of operations passed to configureOpenMPToLLVMConversionLegality and populateOpenMPToLLVMConversionPatterns, since some were missing previously. So that'd be a functional change too.

Copy link
Member

@ergawy ergawy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks! Very nice simplification.

RegionOpWithVarOperandsConversion<omp::AtomicUpdateOp>>(converter);
OpenMPOpConversion<omp::AtomicCaptureOp>,
OpenMPOpConversion<omp::AtomicReadOp>,
OpenMPOpConversion<omp::AtomicUpdateOp, /*SupportsMemRefOperand=*/false>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small suggestion, feel free to ignore if you disagree. We can remove the template argument and constexpr if against the Op types for which the bool is false. Cleans the template a bit more and collects all the ops fro which the bool should have been false in a single location.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the suggestion, with that change I was able to also remove the need to manually list operations, so that's one less thing to remember adding when we create new ops.

@skatrak skatrak merged commit ff3341c into llvm:main Mar 20, 2025
11 checks passed
@skatrak skatrak deleted the simplify-mlir-omp-llvm branch March 20, 2025 14:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants