Skip to content

[mlir][llvm] Port overflowFlags to a native operation property (RELAND) #89410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2024

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Apr 19, 2024

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.

Reland to fix flang

…m#89312)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.
@Mogball Mogball merged commit e553ac4 into llvm:main Apr 19, 2024
4 of 5 checks passed
@llvmbot llvmbot added mlir:llvm mlir flang Flang issues not falling into any other category flang:fir-hlfir flang:codegen labels Apr 19, 2024
@llvmbot
Copy link
Member

llvmbot commented Apr 19, 2024

@llvm/pr-subscribers-flang-fir-hlfir
@llvm/pr-subscribers-flang-codegen

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

Changes

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on operations as native properties.

Reland to fix flang


Patch is 23.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89410.diff

12 Files Affected:

  • (modified) flang/lib/Optimizer/CodeGen/CodeGen.cpp (+4-6)
  • (modified) mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h (+11-11)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/Pattern.h (+9-5)
  • (modified) mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h (+10-6)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td (+29-47)
  • (modified) mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td (+18-5)
  • (modified) mlir/include/mlir/Target/LLVMIR/ModuleImport.h (+1-2)
  • (modified) mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp (-7)
  • (modified) mlir/lib/Conversion/LLVMCommon/Pattern.cpp (+13-6)
  • (modified) mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp (+13-13)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp (+72-4)
  • (modified) mlir/lib/Target/LLVMIR/ModuleImport.cpp (+3-4)
diff --git a/flang/lib/Optimizer/CodeGen/CodeGen.cpp b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
index d909bda89cdeb4..921eac2f8f4b60 100644
--- a/flang/lib/Optimizer/CodeGen/CodeGen.cpp
+++ b/flang/lib/Optimizer/CodeGen/CodeGen.cpp
@@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
     const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
     TypePair baseBoxTyPair =
         baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
-    mlir::LLVM::IntegerOverflowFlagsAttr nsw =
-        mlir::LLVM::IntegerOverflowFlagsAttr::get(
-            rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
+    mlir::LLVM::IntegerOverflowFlags nsw =
+        mlir::LLVM::IntegerOverflowFlags::nsw;
 
     // For each dimension of the array, generate the offset calculation.
     for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
@@ -2396,9 +2395,8 @@ struct CoordinateOpConversion
     auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
     mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
     mlir::Type byteTy = ::getI8Type(coor.getContext());
-    mlir::LLVM::IntegerOverflowFlagsAttr nsw =
-        mlir::LLVM::IntegerOverflowFlagsAttr::get(
-            rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
+    mlir::LLVM::IntegerOverflowFlags nsw =
+        mlir::LLVM::IntegerOverflowFlags::nsw;
 
     for (unsigned i = 1, last = operands.size(); i < last; ++i) {
       if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
diff --git a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
index 0891e2ba7be760..7ffc8613317603 100644
--- a/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
+++ b/mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
@@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
 LLVM::IntegerOverflowFlags
 convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
 
-/// Creates an LLVM overflow attribute from a given arithmetic overflow
-/// attribute.
-LLVM::IntegerOverflowFlagsAttr
-convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
-
 /// Creates an LLVM rounding mode enum value from a given arithmetic rounding
 /// mode enum value.
 LLVM::RoundingMode
@@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
@@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
     // Get the name of the arith overflow attribute.
     StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
     // Remove the source overflow attribute.
-    auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
-        convertedAttr.erase(arithAttrName));
-    if (arithAttr) {
-      StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
-      convertedAttr.set(targetAttrName,
-                        convertArithOverflowAttrToLLVM(arithAttr));
+    if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
+            convertedAttr.erase(arithAttrName))) {
+      overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
     }
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }
 
 private:
   NamedAttrList convertedAttr;
+  LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
 };
 
 template <typename SourceOp, typename TargetOp>
@@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
   }
 
   ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   NamedAttrList convertedAttr;
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index f362167ee93249..f3bf5b66398e09 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -11,6 +11,7 @@
 
 #include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
@@ -18,13 +19,16 @@ class CallOpInterface;
 
 namespace LLVM {
 namespace detail {
+/// Handle generically setting flags as native properties on LLVM operations.
+void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter);
+LogicalResult oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 
 } // namespace detail
 } // namespace LLVM
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
index 279175b6128fc7..964281592cc65e 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
@@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
     std::function<Value(Type, ValueRange)> createOperand,
     ConversionPatternRewriter &rewriter);
 
-LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter);
+LogicalResult vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
 } // namespace detail
 } // namespace LLVM
 
@@ -70,6 +70,9 @@ class AttrConvertPassThrough {
   AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
 
   ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
+  LLVM::IntegerOverflowFlags getOverflowFlags() const {
+    return LLVM::IntegerOverflowFlags::none;
+  }
 
 private:
   ArrayRef<NamedAttribute> srcAttrs;
@@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {
 
     return LLVM::detail::vectorOneToOneRewrite(
         op, TargetOp::getOperationName(), adaptor.getOperands(),
-        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
+        attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
+        attrConvert.getOverflowFlags());
   }
 };
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
index cee752aeb269b7..7085f81e203a1e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
@@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {
 
 def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
   let description = [{
-    Access to op integer overflow flags.
+    This interface defines an LLVM operation with integer overflow flags and
+    provides a uniform API for accessing them.
   }];
 
   let cppNamespace = "::mlir::LLVM";
 
   let methods = [
-    InterfaceMethod<
-      /*desc=*/        "Returns an IntegerOverflowFlagsAttr attribute for the operation",
-      /*returnType=*/  "IntegerOverflowFlagsAttr",
-      /*methodName=*/  "getOverflowAttr",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        return op.getOverflowFlagsAttr();
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Unsigned Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoUnsignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
-      }]
-      >,
-    InterfaceMethod<
-      /*desc=*/        "Returns whether the operation has the No Signed Wrap keyword",
-      /*returnType=*/  "bool",
-      /*methodName=*/  "hasNoSignedWrap",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        auto op = cast<ConcreteOp>(this->getOperation());
-        IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
-        return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
-      }]
-      >,
-    StaticInterfaceMethod<
-      /*desc=*/        [{Returns the name of the IntegerOverflowFlagsAttr attribute
-                         for the operation}],
-      /*returnType=*/  "StringRef",
-      /*methodName=*/  "getIntegerOverflowAttrName",
-      /*args=*/        (ins),
-      /*methodBody=*/  [{}],
-      /*defaultImpl=*/ [{
-        return "overflowFlags";
-      }]
-      >
+    InterfaceMethod<[{
+      Get the integer overflow flags for the operation.
+    }], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
+      return $_op.getProperties().overflowFlags;
+    }]>,
+    InterfaceMethod<[{
+      Set the integer overflow flags for the operation.
+    }], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
+      $_op.getProperties().overflowFlags = flags;
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Unsigned Wrap keyword.
+    }], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nuw);
+    }]>,
+    InterfaceMethod<[{
+      Returns whether the operation has the No Signed Wrap keyword.
+    }], "bool", "hasNoSignedWrap", (ins), [{}], [{
+      return bitEnumContainsAll($_op.getOverflowFlags(),
+                                IntegerOverflowFlags::nsw);
+    }]>,
+    StaticInterfaceMethod<[{
+      Get the attribute name of the overflow flags property.
+    }], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
+      return "overflowFlags";
+    }]>,
   ];
 }
 
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
index f8f9264b3889be..eedae4b9bb7c8e 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
@@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
                                    list<Trait> traits = []> :
     LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
     !listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
-  dag iofArg = (
-    ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
+  dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
   let arguments = !con(commonArgs, iofArg);
+
+  let builders = [
+    OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs, 
+                   "IntegerOverflowFlags":$overflowFlags), [{
+      build($_builder, $_state, type, lhs, rhs);
+      $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+    }]>,
+    OpBuilder<(ins "Value":$lhs, "Value":$rhs, 
+                   "IntegerOverflowFlags":$overflowFlags), [{
+      build($_builder, $_state, lhs, rhs);
+      $_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
+    }]>
+  ];
+
   string mlirBuilder = [{
     auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
-    moduleImport.setIntegerOverflowFlagsAttr(inst, op);
+    moduleImport.setIntegerOverflowFlags(inst, op);
     $res = op;
   }];
   let assemblyFormat = [{
-    $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
-    custom<LLVMOpAttrs>(attr-dict) `:` type($res)
+    $lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
+    `` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
   }];
   string llvmBuilder =
     "$res = builder.Create" # instName #
diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
index 1a188b1d042854..04d098d38155bf 100644
--- a/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
+++ b/mlir/include/mlir/Target/LLVMIR/ModuleImport.h
@@ -183,8 +183,7 @@ class ModuleImport {
   /// Sets the integer overflow flags (nsw/nuw) attribute for the imported
   /// operation `op` given the original instruction `inst`. Asserts if the
   /// operation does not implement the integer overflow flag interface.
-  void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
-                                   Operation *op) const;
+  void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;
 
   /// Sets the fastmath flags attribute for the imported operation `op` given
   /// the original instruction `inst`. Asserts if the operation does not
diff --git a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
index f12eba98480d33..cf60a048f782c6 100644
--- a/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
+++ b/mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
@@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
   return llvmFlags;
 }
 
-LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
-    arith::IntegerOverflowFlagsAttr flagsAttr) {
-  arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
-  return LLVM::IntegerOverflowFlagsAttr::get(
-      flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
-}
-
 LLVM::RoundingMode
 mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
   switch (roundingMode) {
diff --git a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
index 83c31a204efc7e..1886dfa870961a 100644
--- a/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/Pattern.cpp
@@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
 // Detail methods
 //===----------------------------------------------------------------------===//
 
+void LLVM::detail::setNativeProperties(Operation *op,
+                                       IntegerOverflowFlags overflowFlags) {
+  if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
+    iface.setOverflowFlags(overflowFlags);
+}
+
 /// Replaces the given operation "op" with a new operation of type "targetOp"
 /// and given operands.
-LogicalResult
-LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
-                              ValueRange operands,
-                              ArrayRef<NamedAttribute> targetAttrs,
-                              const LLVMTypeConverter &typeConverter,
-                              ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::oneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   unsigned numResults = op->getNumResults();
 
   SmallVector<Type> resultTypes;
@@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
       rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
                       resultTypes, targetAttrs);
 
+  setNativeProperties(newOp, overflowFlags);
+
   // If the operation produced 0 or 1 result, return them immediately.
   if (numResults == 0)
     return rewriter.eraseOp(op), success();
diff --git a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
index 544bcc71aca1b5..626135c10a3e96 100644
--- a/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
@@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
   return success();
 }
 
-LogicalResult
-LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
-                                    ValueRange operands,
-                                    ArrayRef<NamedAttribute> targetAttrs,
-                                    const LLVMTypeConverter &typeConverter,
-                                    ConversionPatternRewriter &rewriter) {
+LogicalResult LLVM::detail::vectorOneToOneRewrite(
+    Operation *op, StringRef targetOp, ValueRange operands,
+    ArrayRef<NamedAttribute> targetAttrs,
+    const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
+    IntegerOverflowFlags overflowFlags) {
   assert(!operands.empty());
 
   // Cannot convert ops if their operands are not of LLVM type.
@@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
   auto llvmNDVectorTy = operands[0].getType();
   if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
     return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
-                           rewriter);
+                           rewriter, overflowFlags);
 
-  auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
-                                                         ValueRange operands) {
-    return rewriter
-        .create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
-                llvm1DVectorTy, targetAttrs)
-        ->getResult(0);
+  auto callback = [op, targetOp, targetAttrs, overflowFlags,
+                   &rewriter](Type llvm1DVectorTy, ValueRange operands) {
+    Operation *newOp =
+        rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
+                        operands, llvm1DVectorTy, targetAttrs);
+    LLVM::detail::setNativeProperties(newOp, overflowFlags);
+    return newOp->getResult(0);
   };
 
   return handleMultidimensionalVectors(op, operands, typeConverter, callback,
diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
index 1db506e286b3c0..78ff24dae68b4c 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
@@ -47,6 +47,74 @@ using mlir::LLVM::linkage::getMaxEnumValForLinkage;
 
 #include "mlir/Dialect/LLVMIR/LLVMOpsDialect.cpp.inc"
 
+//===----------------------------------------------------------------------===//
+// Property Helpers
+//===----------------------------------------------------------------------===//
+
+//===----------------------------------------------------------------------===//
+// IntegerOverflowFlags
+
+namespace mlir {
+static Attribute convertToAttribute(MLIRContext *ctx,
+                                    IntegerOverflowFlags flags) {
+  return IntegerOverflowFlagsAttr::get(ctx, flags);
+}
+
+static LogicalResult
+convertFromAttribute(IntegerOverflowFlags &flags, Attribute attr,
+                     function_ref<InFlightDiagnostic()> emitError) {
+  auto flagsAttr = dyn_cast<IntegerOverflowFlagsAttr>(attr);
+  if (!flagsAttr) {
+    return emitError() << "expected 'overflowFlags' attribute to be an "
+                          "IntegerOverflowFlagsAttr, but got "
+                       << attr;
+  }
+  flags = flagsAttr.getValue();
+  return success();
+}
+} // namespace mlir
+
+static ParseResult parseOverflowFlags(AsmParser &p,
+                 ...
[truncated]

aniplcc pushed a commit to aniplcc/llvm-project that referenced this pull request Apr 21, 2024
…AND) (llvm#89410)

This PR changes the LLVM dialect's IntegerOverflowFlags to be stored on
operations as native properties.

Reland to fix flang
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:codegen flang:fir-hlfir flang Flang issues not falling into any other category mlir:llvm mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants