Skip to content

Commit a7262d2

Browse files
Hardcode84yiwu0b11
andauthored
[mlir][arith] Add overflow flags support to arith ops (llvm#77211)
Add overflow flags support to the following ops: * `arith.addi` * `arith.subi` * `arith.muli` Example of new syntax: ``` %res = arith.addi %arg1, %arg2 overflow<nsw> : i64 ``` Similar to existing LLVM dialect syntax ``` %res = llvm.add %arg1, %arg2 overflow<nsw> : i64 ``` Tablegen canonicalization patterns updated to always drop flags, proper support with tests will be added later. Updated LLVMIR translation as part of this commit as it currenly written in a way that it will crash when new attributes added to arith ops otherwise. Discussion https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025 --------- Co-authored-by: Yi Wu <[email protected]>
1 parent b5d4332 commit a7262d2

File tree

11 files changed

+321
-73
lines changed

11 files changed

+321
-73
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818

1919
namespace mlir {
2020
namespace arith {
21-
// Map arithmetic fastmath enum values to LLVMIR enum values.
21+
/// Maps arithmetic fastmath enum values to LLVM enum values.
2222
LLVM::FastmathFlags
2323
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
2424

25-
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
25+
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
26+
/// attribute.
2627
LLVM::FastmathFlagsAttr
2728
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
2829

30+
/// Maps arithmetic overflow enum values to LLVM enum values.
31+
LLVM::IntegerOverflowFlags
32+
convertArithOveflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
33+
34+
/// Creates an LLVM overflow attribute from a given arithmetic overflow
35+
/// attribute.
36+
LLVM::IntegerOverflowFlagsAttr
37+
convertArithOveflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
38+
2939
// Attribute converter that populates a NamedAttrList by removing the fastmath
3040
// attribute from the source operation attributes, and replacing it with an
3141
// equivalent LLVM fastmath attribute.
@@ -36,19 +46,46 @@ class AttrConvertFastMathToLLVM {
3646
// Copy the source attributes.
3747
convertedAttr = NamedAttrList{srcOp->getAttrs()};
3848
// Get the name of the arith fastmath attribute.
39-
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
49+
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
4050
// Remove the source fastmath attribute.
41-
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
51+
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
4252
convertedAttr.erase(arithFMFAttrName));
4353
if (arithFMFAttr) {
44-
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
54+
StringRef targetAttrName = TargetOp::getFastmathAttrName();
4555
convertedAttr.set(targetAttrName,
4656
convertArithFastMathAttrToLLVM(arithFMFAttr));
4757
}
4858
}
4959

5060
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
5161

62+
private:
63+
NamedAttrList convertedAttr;
64+
};
65+
66+
// Attribute converter that populates a NamedAttrList by removing the overflow
67+
// attribute from the source operation attributes, and replacing it with an
68+
// equivalent LLVM overflow attribute.
69+
template <typename SourceOp, typename TargetOp>
70+
class AttrConvertOverflowToLLVM {
71+
public:
72+
AttrConvertOverflowToLLVM(SourceOp srcOp) {
73+
// Copy the source attributes.
74+
convertedAttr = NamedAttrList{srcOp->getAttrs()};
75+
// Get the name of the arith overflow attribute.
76+
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
77+
// Remove the source overflow attribute.
78+
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
79+
convertedAttr.erase(arithAttrName));
80+
if (arithAttr) {
81+
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
82+
convertedAttr.set(targetAttrName,
83+
convertArithOveflowAttrToLLVM(arithAttr));
84+
}
85+
}
86+
87+
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
88+
5289
private:
5390
NamedAttrList convertedAttr;
5491
};

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
133133
let assemblyFormat = "`<` $value `>`";
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// IntegerOverflowFlags
138+
//===----------------------------------------------------------------------===//
139+
140+
def IOFnone : I32BitEnumAttrCaseNone<"none">;
141+
def IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
142+
def IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
143+
144+
def IntegerOverflowFlags : I32BitEnumAttr<
145+
"IntegerOverflowFlags",
146+
"Integer overflow arith flags",
147+
[IOFnone, IOFnsw, IOFnuw]> {
148+
let separator = ", ";
149+
let cppNamespace = "::mlir::arith";
150+
let genSpecializedAttr = 0;
151+
let printBitEnumPrimaryGroups = 1;
152+
}
153+
154+
def Arith_IntegerOverflowAttr :
155+
EnumAttr<Arith_Dialect, IntegerOverflowFlags, "overflow"> {
156+
let assemblyFormat = "`<` $value `>`";
157+
}
158+
136159
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
137137
let results = (outs BoolLikeOfAnyRank:$result);
138138
}
139139

140+
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
141+
Arith_BinaryOp<mnemonic, traits #
142+
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
143+
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
144+
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
145+
DefaultValuedAttr<
146+
Arith_IntegerOverflowAttr,
147+
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
148+
Results<(outs SignlessIntegerLike:$result)> {
149+
150+
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
151+
attr-dict `:` type($result) }];
152+
}
153+
140154
//===----------------------------------------------------------------------===//
141155
// ConstantOp
142156
//===----------------------------------------------------------------------===//
@@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
192206
// AddIOp
193207
//===----------------------------------------------------------------------===//
194208

195-
def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
209+
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
196210
let summary = "integer addition operation";
197211
let description = [{
198212
Performs N-bit addition on the operands. The operands are interpreted as
@@ -203,16 +217,23 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
203217

204218
The `addi` operation takes two operands and returns one result, each of
205219
these is required to be the same type. This type may be an integer scalar type,
206-
a vector whose element type is integer, or a tensor of integers. It has no
207-
standard attributes.
220+
a vector whose element type is integer, or a tensor of integers.
221+
222+
This op supports `nuw`/`nsw` overflow flags which stands stand for
223+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
224+
`nsw` flags are present, and an unsigned/signed overflow occurs
225+
(respectively), the result is poison.
208226

209227
Example:
210228

211229
```mlir
212230
// Scalar addition.
213231
%a = arith.addi %b, %c : i64
214232

215-
// SIMD vector element-wise addition, e.g. for Intel SSE.
233+
// Scalar addition with overflow flags.
234+
%a = arith.addi %b, %c overflow<nsw, nuw> : i64
235+
236+
// SIMD vector element-wise addition.
216237
%f = arith.addi %g, %h : vector<4xi32>
217238

218239
// Tensor element-wise addition.
@@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
278299
// SubIOp
279300
//===----------------------------------------------------------------------===//
280301

281-
def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
302+
def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
282303
let summary = [{
283304
Integer subtraction operation.
284305
}];
285306
let description = [{
286-
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
287-
bitvectors. The result is represented by a bitvector containing the mathematical
288-
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
289-
integers use a two's complement representation, this operation is applicable on
307+
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
308+
bitvectors. The result is represented by a bitvector containing the mathematical
309+
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
310+
integers use a two's complement representation, this operation is applicable on
290311
both signed and unsigned integer operands.
291312

292313
The `subi` operation takes two operands and returns one result, each of
293-
these is required to be the same type. This type may be an integer scalar type,
294-
a vector whose element type is integer, or a tensor of integers. It has no
295-
standard attributes.
314+
these is required to be the same type. This type may be an integer scalar type,
315+
a vector whose element type is integer, or a tensor of integers.
316+
317+
This op supports `nuw`/`nsw` overflow flags which stands stand for
318+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
319+
`nsw` flags are present, and an unsigned/signed overflow occurs
320+
(respectively), the result is poison.
321+
322+
Example:
323+
324+
```mlir
325+
// Scalar subtraction.
326+
%a = arith.subi %b, %c : i64
327+
328+
// Scalar subtraction with overflow flags.
329+
%a = arith.subi %b, %c overflow<nsw, nuw> : i64
330+
331+
// SIMD vector element-wise subtraction.
332+
%f = arith.subi %g, %h : vector<4xi32>
333+
334+
// Tensor element-wise subtraction.
335+
%x = arith.subi %y, %z : tensor<4x?xi8>
336+
```
296337
}];
297338
let hasFolder = 1;
298339
let hasCanonicalizer = 1;
@@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
302343
// MulIOp
303344
//===----------------------------------------------------------------------===//
304345

305-
def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
346+
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
306347
let summary = [{
307348
Integer multiplication operation.
308349
}];
309350
let description = [{
310-
Performs N-bit multiplication on the operands. The operands are interpreted as
311-
unsigned bitvectors. The result is represented by a bitvector containing the
312-
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
313-
Because `arith` integers use a two's complement representation, this operation is
351+
Performs N-bit multiplication on the operands. The operands are interpreted as
352+
unsigned bitvectors. The result is represented by a bitvector containing the
353+
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
354+
Because `arith` integers use a two's complement representation, this operation is
314355
applicable on both signed and unsigned integer operands.
315356

316357
The `muli` operation takes two operands and returns one result, each of
317-
these is required to be the same type. This type may be an integer scalar type,
318-
a vector whose element type is integer, or a tensor of integers. It has no
319-
standard attributes.
358+
these is required to be the same type. This type may be an integer scalar type,
359+
a vector whose element type is integer, or a tensor of integers.
360+
361+
This op supports `nuw`/`nsw` overflow flags which stands stand for
362+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
363+
`nsw` flags are present, and an unsigned/signed overflow occurs
364+
(respectively), the result is poison.
365+
366+
Example:
367+
368+
```mlir
369+
// Scalar multiplication.
370+
%a = arith.muli %b, %c : i64
371+
372+
// Scalar multiplication with overflow flags.
373+
%a = arith.muli %b, %c overflow<nsw, nuw> : i64
374+
375+
// SIMD vector element-wise multiplication.
376+
%f = arith.muli %g, %h : vector<4xi32>
377+
378+
// Tensor element-wise multiplication.
379+
%x = arith.muli %y, %z : tensor<4x?xi8>
380+
```
320381
}];
321382
let hasFolder = 1;
322383
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
4949
];
5050
}
5151

52+
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
53+
let description = [{
54+
Access to op integer overflow flags.
55+
}];
56+
57+
let cppNamespace = "::mlir::arith";
58+
59+
let methods = [
60+
InterfaceMethod<
61+
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
62+
/*returnType=*/ "IntegerOverflowFlagsAttr",
63+
/*methodName=*/ "getOverflowAttr",
64+
/*args=*/ (ins),
65+
/*methodBody=*/ [{}],
66+
/*defaultImpl=*/ [{
67+
auto op = cast<ConcreteOp>(this->getOperation());
68+
return op.getOverflowFlagsAttr();
69+
}]
70+
>,
71+
InterfaceMethod<
72+
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
73+
/*returnType=*/ "bool",
74+
/*methodName=*/ "hasNoUnsignedWrap",
75+
/*args=*/ (ins),
76+
/*methodBody=*/ [{}],
77+
/*defaultImpl=*/ [{
78+
auto op = cast<ConcreteOp>(this->getOperation());
79+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
80+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
81+
}]
82+
>,
83+
InterfaceMethod<
84+
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
85+
/*returnType=*/ "bool",
86+
/*methodName=*/ "hasNoSignedWrap",
87+
/*args=*/ (ins),
88+
/*methodBody=*/ [{}],
89+
/*defaultImpl=*/ [{
90+
auto op = cast<ConcreteOp>(this->getOperation());
91+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
92+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
93+
}]
94+
>,
95+
StaticInterfaceMethod<
96+
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
97+
for the operation}],
98+
/*returnType=*/ "StringRef",
99+
/*methodName=*/ "getIntegerOverflowAttrName",
100+
/*args=*/ (ins),
101+
/*methodBody=*/ [{}],
102+
/*defaultImpl=*/ [{
103+
return "overflowFlags";
104+
}]
105+
>
106+
];
107+
}
108+
52109
#endif // ARITH_OPS_INTERFACES

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
using namespace mlir;
1212

13-
// Map arithmetic fastmath enum values to LLVMIR enum values.
1413
LLVM::FastmathFlags
1514
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
1615
LLVM::FastmathFlags llvmFMF{};
@@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
2221
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
2322
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
2423
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
25-
for (auto fmfMap : flags) {
26-
if (bitEnumContainsAny(arithFMF, fmfMap.first))
27-
llvmFMF = llvmFMF | fmfMap.second;
24+
for (auto [arithFlag, llvmFlag] : flags) {
25+
if (bitEnumContainsAny(arithFMF, arithFlag))
26+
llvmFMF = llvmFMF | llvmFlag;
2827
}
2928
return llvmFMF;
3029
}
3130

32-
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
3331
LLVM::FastmathFlagsAttr
3432
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
3533
arith::FastMathFlags arithFMF = fmfAttr.getValue();
3634
return LLVM::FastmathFlagsAttr::get(
3735
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
3836
}
37+
38+
LLVM::IntegerOverflowFlags mlir::arith::convertArithOveflowFlagsToLLVM(
39+
arith::IntegerOverflowFlags arithFlags) {
40+
LLVM::IntegerOverflowFlags llvmFlags{};
41+
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
42+
flags[] = {
43+
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
44+
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
45+
for (auto [arithFlag, llvmFlag] : flags) {
46+
if (bitEnumContainsAny(arithFlags, arithFlag))
47+
llvmFlags = llvmFlags | llvmFlag;
48+
}
49+
return llvmFlags;
50+
}
51+
52+
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOveflowAttrToLLVM(
53+
arith::IntegerOverflowFlagsAttr flagsAttr) {
54+
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
55+
return LLVM::IntegerOverflowFlagsAttr::get(
56+
flagsAttr.getContext(), convertArithOveflowFlagsToLLVM(arithFlags));
57+
}

0 commit comments

Comments
 (0)