Skip to content

[mlir][llvm] Port overflowFlags to a native operation property (RELAND) #89410

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions flang/lib/Optimizer/CodeGen/CodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2110,9 +2110,8 @@ struct XArrayCoorOpConversion
const bool baseIsBoxed = coor.getMemref().getType().isa<fir::BaseBoxType>();
TypePair baseBoxTyPair =
baseIsBoxed ? getBoxTypePair(coor.getMemref().getType()) : TypePair{};
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get(
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlags::nsw;

// For each dimension of the array, generate the offset calculation.
for (unsigned i = 0; i < rank; ++i, ++indexOffset, ++shapeOffset,
Expand Down Expand Up @@ -2396,9 +2395,8 @@ struct CoordinateOpConversion
auto cpnTy = fir::dyn_cast_ptrOrBoxEleTy(boxObjTy);
mlir::Type llvmPtrTy = ::getLlvmPtrType(coor.getContext());
mlir::Type byteTy = ::getI8Type(coor.getContext());
mlir::LLVM::IntegerOverflowFlagsAttr nsw =
mlir::LLVM::IntegerOverflowFlagsAttr::get(
rewriter.getContext(), mlir::LLVM::IntegerOverflowFlags::nsw);
mlir::LLVM::IntegerOverflowFlags nsw =
mlir::LLVM::IntegerOverflowFlags::nsw;

for (unsigned i = 1, last = operands.size(); i < last; ++i) {
if (auto arrTy = cpnTy.dyn_cast<fir::SequenceType>()) {
Expand Down
22 changes: 11 additions & 11 deletions mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,6 @@ convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
LLVM::IntegerOverflowFlags
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);

/// Creates an LLVM overflow attribute from a given arithmetic overflow
/// attribute.
LLVM::IntegerOverflowFlagsAttr
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);

/// Creates an LLVM rounding mode enum value from a given arithmetic rounding
/// mode enum value.
LLVM::RoundingMode
Expand Down Expand Up @@ -72,6 +67,9 @@ class AttrConvertFastMathToLLVM {
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
NamedAttrList convertedAttr;
Expand All @@ -89,19 +87,18 @@ class AttrConvertOverflowToLLVM {
// Get the name of the arith overflow attribute.
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
// Remove the source overflow attribute.
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName));
if (arithAttr) {
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
convertedAttr.set(targetAttrName,
convertArithOverflowAttrToLLVM(arithAttr));
if (auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
convertedAttr.erase(arithAttrName))) {
overflowFlags = convertArithOverflowFlagsToLLVM(arithAttr.getValue());
}
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const { return overflowFlags; }

private:
NamedAttrList convertedAttr;
LLVM::IntegerOverflowFlags overflowFlags = LLVM::IntegerOverflowFlags::none;
};

template <typename SourceOp, typename TargetOp>
Expand Down Expand Up @@ -132,6 +129,9 @@ class AttrConverterConstrainedFPToLLVM {
}

ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
NamedAttrList convertedAttr;
Expand Down
14 changes: 9 additions & 5 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,24 @@

#include "mlir/Conversion/LLVMCommon/MemRefBuilder.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class CallOpInterface;

namespace LLVM {
namespace detail {
/// Handle generically setting flags as native properties on LLVM operations.
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags);

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);

} // namespace detail
} // namespace LLVM
Expand Down
16 changes: 10 additions & 6 deletions mlir/include/mlir/Conversion/LLVMCommon/VectorPattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ LogicalResult handleMultidimensionalVectors(
std::function<Value(Type, ValueRange)> createOperand,
ConversionPatternRewriter &rewriter);

LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter);
LogicalResult vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
} // namespace detail
} // namespace LLVM

Expand All @@ -70,6 +70,9 @@ class AttrConvertPassThrough {
AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}

ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
LLVM::IntegerOverflowFlags getOverflowFlags() const {
return LLVM::IntegerOverflowFlags::none;
}

private:
ArrayRef<NamedAttribute> srcAttrs;
Expand Down Expand Up @@ -100,7 +103,8 @@ class VectorConvertToLLVMPattern : public ConvertOpToLLVMPattern<SourceOp> {

return LLVM::detail::vectorOneToOneRewrite(
op, TargetOp::getOperationName(), adaptor.getOperands(),
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter);
attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
attrConvert.getOverflowFlags());
}
};
} // namespace mlir
Expand Down
76 changes: 29 additions & 47 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,58 +50,40 @@ def FastmathFlagsInterface : OpInterface<"FastmathFlagsInterface"> {

def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface"> {
let description = [{
Access to op integer overflow flags.
This interface defines an LLVM operation with integer overflow flags and
provides a uniform API for accessing them.
}];

let cppNamespace = "::mlir::LLVM";

let methods = [
InterfaceMethod<
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
/*returnType=*/ "IntegerOverflowFlagsAttr",
/*methodName=*/ "getOverflowAttr",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
return op.getOverflowFlagsAttr();
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoUnsignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
}]
>,
InterfaceMethod<
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
/*returnType=*/ "bool",
/*methodName=*/ "hasNoSignedWrap",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
auto op = cast<ConcreteOp>(this->getOperation());
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
}]
>,
StaticInterfaceMethod<
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
for the operation}],
/*returnType=*/ "StringRef",
/*methodName=*/ "getIntegerOverflowAttrName",
/*args=*/ (ins),
/*methodBody=*/ [{}],
/*defaultImpl=*/ [{
return "overflowFlags";
}]
>
InterfaceMethod<[{
Get the integer overflow flags for the operation.
}], "IntegerOverflowFlags", "getOverflowFlags", (ins), [{}], [{
return $_op.getProperties().overflowFlags;
}]>,
InterfaceMethod<[{
Set the integer overflow flags for the operation.
}], "void", "setOverflowFlags", (ins "IntegerOverflowFlags":$flags), [{}], [{
$_op.getProperties().overflowFlags = flags;
}]>,
InterfaceMethod<[{
Returns whether the operation has the No Unsigned Wrap keyword.
}], "bool", "hasNoUnsignedWrap", (ins), [{}], [{
return bitEnumContainsAll($_op.getOverflowFlags(),
IntegerOverflowFlags::nuw);
}]>,
InterfaceMethod<[{
Returns whether the operation has the No Signed Wrap keyword.
}], "bool", "hasNoSignedWrap", (ins), [{}], [{
return bitEnumContainsAll($_op.getOverflowFlags(),
IntegerOverflowFlags::nsw);
}]>,
StaticInterfaceMethod<[{
Get the attribute name of the overflow flags property.
}], "StringRef", "getOverflowFlagsAttrName", (ins), [{}], [{
return "overflowFlags";
}]>,
];
}

Expand Down
23 changes: 18 additions & 5 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,30 @@ class LLVM_IntArithmeticOpWithOverflowFlag<string mnemonic, string instName,
list<Trait> traits = []> :
LLVM_ArithmeticOpBase<AnySignlessInteger, mnemonic, instName,
!listconcat([DeclareOpInterfaceMethods<IntegerOverflowFlagsInterface>], traits)> {
dag iofArg = (
ins DefaultValuedAttr<LLVM_IntegerOverflowFlagsAttr, "{}">:$overflowFlags);
dag iofArg = (ins EnumProperty<"IntegerOverflowFlags">:$overflowFlags);
let arguments = !con(commonArgs, iofArg);

let builders = [
OpBuilder<(ins "Type":$type, "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, type, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>,
OpBuilder<(ins "Value":$lhs, "Value":$rhs,
"IntegerOverflowFlags":$overflowFlags), [{
build($_builder, $_state, lhs, rhs);
$_state.getOrAddProperties<Properties>().overflowFlags = overflowFlags;
}]>
];

string mlirBuilder = [{
auto op = $_builder.create<$_qualCppClassName>($_location, $lhs, $rhs);
moduleImport.setIntegerOverflowFlagsAttr(inst, op);
moduleImport.setIntegerOverflowFlags(inst, op);
$res = op;
}];
let assemblyFormat = [{
$lhs `,` $rhs (`overflow` `` $overflowFlags^)?
custom<LLVMOpAttrs>(attr-dict) `:` type($res)
$lhs `,` $rhs `` custom<OverflowFlags>($overflowFlags)
`` custom<LLVMOpAttrs>(attr-dict) `:` type($res)
}];
string llvmBuilder =
"$res = builder.Create" # instName #
Expand Down
3 changes: 1 addition & 2 deletions mlir/include/mlir/Target/LLVMIR/ModuleImport.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ class ModuleImport {
/// Sets the integer overflow flags (nsw/nuw) attribute for the imported
/// operation `op` given the original instruction `inst`. Asserts if the
/// operation does not implement the integer overflow flag interface.
void setIntegerOverflowFlagsAttr(llvm::Instruction *inst,
Operation *op) const;
void setIntegerOverflowFlags(llvm::Instruction *inst, Operation *op) const;

/// Sets the fastmath flags attribute for the imported operation `op` given
/// the original instruction `inst`. Asserts if the operation does not
Expand Down
7 changes: 0 additions & 7 deletions mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,6 @@ LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
return llvmFlags;
}

LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
arith::IntegerOverflowFlagsAttr flagsAttr) {
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
return LLVM::IntegerOverflowFlagsAttr::get(
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
}

LLVM::RoundingMode
mlir::arith::convertArithRoundingModeToLLVM(arith::RoundingMode roundingMode) {
switch (roundingMode) {
Expand Down
19 changes: 13 additions & 6 deletions mlir/lib/Conversion/LLVMCommon/Pattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -329,14 +329,19 @@ LogicalResult ConvertToLLVMPattern::copyUnrankedDescriptors(
// Detail methods
//===----------------------------------------------------------------------===//

void LLVM::detail::setNativeProperties(Operation *op,
IntegerOverflowFlags overflowFlags) {
if (auto iface = dyn_cast<IntegerOverflowFlagsInterface>(op))
iface.setOverflowFlags(overflowFlags);
}

/// Replaces the given operation "op" with a new operation of type "targetOp"
/// and given operands.
LogicalResult
LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
LogicalResult LLVM::detail::oneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
unsigned numResults = op->getNumResults();

SmallVector<Type> resultTypes;
Expand All @@ -352,6 +357,8 @@ LLVM::detail::oneToOneRewrite(Operation *op, StringRef targetOp,
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
resultTypes, targetAttrs);

setNativeProperties(newOp, overflowFlags);

// If the operation produced 0 or 1 result, return them immediately.
if (numResults == 0)
return rewriter.eraseOp(op), success();
Expand Down
26 changes: 13 additions & 13 deletions mlir/lib/Conversion/LLVMCommon/VectorPattern.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,11 @@ LogicalResult LLVM::detail::handleMultidimensionalVectors(
return success();
}

LogicalResult
LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter) {
LogicalResult LLVM::detail::vectorOneToOneRewrite(
Operation *op, StringRef targetOp, ValueRange operands,
ArrayRef<NamedAttribute> targetAttrs,
const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
IntegerOverflowFlags overflowFlags) {
assert(!operands.empty());

// Cannot convert ops if their operands are not of LLVM type.
Expand All @@ -118,14 +117,15 @@ LLVM::detail::vectorOneToOneRewrite(Operation *op, StringRef targetOp,
auto llvmNDVectorTy = operands[0].getType();
if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
rewriter);
rewriter, overflowFlags);

auto callback = [op, targetOp, targetAttrs, &rewriter](Type llvm1DVectorTy,
ValueRange operands) {
return rewriter
.create(op->getLoc(), rewriter.getStringAttr(targetOp), operands,
llvm1DVectorTy, targetAttrs)
->getResult(0);
auto callback = [op, targetOp, targetAttrs, overflowFlags,
&rewriter](Type llvm1DVectorTy, ValueRange operands) {
Operation *newOp =
rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
operands, llvm1DVectorTy, targetAttrs);
LLVM::detail::setNativeProperties(newOp, overflowFlags);
return newOp->getResult(0);
};

return handleMultidimensionalVectors(op, operands, typeConverter, callback,
Expand Down
Loading
Loading