Skip to content

Commit 8934b10

Browse files
jpienaaryiwu0b11
andauthored
[mlir][arith] Add overflow flags support to arith ops (#78376)
Add overflow flags support to the following ops: * `arith.addi` * `arith.subi` * `arith.muli` Example of new syntax: ``` %res = arith.addi %arg1, %arg2 overflow<nsw> : i64 ``` Similar to existing LLVM dialect syntax ``` %res = llvm.add %arg1, %arg2 overflow<nsw> : i64 ``` Tablegen canonicalization patterns updated to always drop flags, proper support with tests will be added later. Updated LLVMIR translation as part of this commit as it currenly written in a way that it will crash when new attributes added to arith ops otherwise. Also lower `arith` overflow flags to corresponding SPIR-V op decorations Discussion https://discourse.llvm.org/t/rfc-integer-overflow-flags-support-in-arith-dialect/76025 This effectively rolls forward #77211, #77700 and #77714 while adding a test to ensure the Python usage is not broken. More follow up needed but unrelated to the core change here. The changes here are minimal and just correspond to "textual namespacing" ODS side, no C++ or Python changes were needed. --------- --------- Co-authored-by: Ivan Butygin <[email protected]>, Yi Wu <[email protected]>
1 parent 9745c13 commit 8934b10

File tree

15 files changed

+443
-77
lines changed

15 files changed

+443
-77
lines changed

mlir/include/mlir/Conversion/ArithCommon/AttrToLLVMConverter.h

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,24 @@
1818

1919
namespace mlir {
2020
namespace arith {
21-
// Map arithmetic fastmath enum values to LLVMIR enum values.
21+
/// Maps arithmetic fastmath enum values to LLVM enum values.
2222
LLVM::FastmathFlags
2323
convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF);
2424

25-
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
25+
/// Creates an LLVM fastmath attribute from a given arithmetic fastmath
26+
/// attribute.
2627
LLVM::FastmathFlagsAttr
2728
convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr);
2829

30+
/// Maps arithmetic overflow enum values to LLVM enum values.
31+
LLVM::IntegerOverflowFlags
32+
convertArithOverflowFlagsToLLVM(arith::IntegerOverflowFlags arithFlags);
33+
34+
/// Creates an LLVM overflow attribute from a given arithmetic overflow
35+
/// attribute.
36+
LLVM::IntegerOverflowFlagsAttr
37+
convertArithOverflowAttrToLLVM(arith::IntegerOverflowFlagsAttr flagsAttr);
38+
2939
// Attribute converter that populates a NamedAttrList by removing the fastmath
3040
// attribute from the source operation attributes, and replacing it with an
3141
// equivalent LLVM fastmath attribute.
@@ -36,19 +46,46 @@ class AttrConvertFastMathToLLVM {
3646
// Copy the source attributes.
3747
convertedAttr = NamedAttrList{srcOp->getAttrs()};
3848
// Get the name of the arith fastmath attribute.
39-
llvm::StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
49+
StringRef arithFMFAttrName = SourceOp::getFastMathAttrName();
4050
// Remove the source fastmath attribute.
41-
auto arithFMFAttr = dyn_cast_or_null<arith::FastMathFlagsAttr>(
51+
auto arithFMFAttr = dyn_cast_if_present<arith::FastMathFlagsAttr>(
4252
convertedAttr.erase(arithFMFAttrName));
4353
if (arithFMFAttr) {
44-
llvm::StringRef targetAttrName = TargetOp::getFastmathAttrName();
54+
StringRef targetAttrName = TargetOp::getFastmathAttrName();
4555
convertedAttr.set(targetAttrName,
4656
convertArithFastMathAttrToLLVM(arithFMFAttr));
4757
}
4858
}
4959

5060
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
5161

62+
private:
63+
NamedAttrList convertedAttr;
64+
};
65+
66+
// Attribute converter that populates a NamedAttrList by removing the overflow
67+
// attribute from the source operation attributes, and replacing it with an
68+
// equivalent LLVM overflow attribute.
69+
template <typename SourceOp, typename TargetOp>
70+
class AttrConvertOverflowToLLVM {
71+
public:
72+
AttrConvertOverflowToLLVM(SourceOp srcOp) {
73+
// Copy the source attributes.
74+
convertedAttr = NamedAttrList{srcOp->getAttrs()};
75+
// Get the name of the arith overflow attribute.
76+
StringRef arithAttrName = SourceOp::getIntegerOverflowAttrName();
77+
// Remove the source overflow attribute.
78+
auto arithAttr = dyn_cast_if_present<arith::IntegerOverflowFlagsAttr>(
79+
convertedAttr.erase(arithAttrName));
80+
if (arithAttr) {
81+
StringRef targetAttrName = TargetOp::getIntegerOverflowAttrName();
82+
convertedAttr.set(targetAttrName,
83+
convertArithOverflowAttrToLLVM(arithAttr));
84+
}
85+
}
86+
87+
ArrayRef<NamedAttribute> getAttrs() const { return convertedAttr.getAttrs(); }
88+
5289
private:
5390
NamedAttrList convertedAttr;
5491
};

mlir/include/mlir/Dialect/Arith/IR/ArithBase.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,4 +133,27 @@ def Arith_FastMathAttr :
133133
let assemblyFormat = "`<` $value `>`";
134134
}
135135

136+
//===----------------------------------------------------------------------===//
137+
// Arith_IntegerOverflowFlags
138+
//===----------------------------------------------------------------------===//
139+
140+
def Arith_IOFnone : I32BitEnumAttrCaseNone<"none">;
141+
def Arith_IOFnsw : I32BitEnumAttrCaseBit<"nsw", 0>;
142+
def Arith_IOFnuw : I32BitEnumAttrCaseBit<"nuw", 1>;
143+
144+
def Arith_IntegerOverflowFlags : I32BitEnumAttr<
145+
"IntegerOverflowFlags",
146+
"Integer overflow arith flags",
147+
[Arith_IOFnone, Arith_IOFnsw, Arith_IOFnuw]> {
148+
let separator = ", ";
149+
let cppNamespace = "::mlir::arith";
150+
let genSpecializedAttr = 0;
151+
let printBitEnumPrimaryGroups = 1;
152+
}
153+
154+
def Arith_IntegerOverflowAttr :
155+
EnumAttr<Arith_Dialect, Arith_IntegerOverflowFlags, "overflow"> {
156+
let assemblyFormat = "`<` $value `>`";
157+
}
158+
136159
#endif // ARITH_BASE

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 81 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,20 @@ class Arith_CompareOpOfAnyRank<string mnemonic, list<Trait> traits = []> :
137137
let results = (outs BoolLikeOfAnyRank:$result);
138138
}
139139

140+
class Arith_IntBinaryOpWithOverflowFlags<string mnemonic, list<Trait> traits = []> :
141+
Arith_BinaryOp<mnemonic, traits #
142+
[Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>,
143+
DeclareOpInterfaceMethods<ArithIntegerOverflowFlagsInterface>]>,
144+
Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs,
145+
DefaultValuedAttr<
146+
Arith_IntegerOverflowAttr,
147+
"::mlir::arith::IntegerOverflowFlags::none">:$overflowFlags)>,
148+
Results<(outs SignlessIntegerLike:$result)> {
149+
150+
let assemblyFormat = [{ $lhs `,` $rhs (`overflow` `` $overflowFlags^)?
151+
attr-dict `:` type($result) }];
152+
}
153+
140154
//===----------------------------------------------------------------------===//
141155
// ConstantOp
142156
//===----------------------------------------------------------------------===//
@@ -192,7 +206,7 @@ def Arith_ConstantOp : Op<Arith_Dialect, "constant",
192206
// AddIOp
193207
//===----------------------------------------------------------------------===//
194208

195-
def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
209+
def Arith_AddIOp : Arith_IntBinaryOpWithOverflowFlags<"addi", [Commutative]> {
196210
let summary = "integer addition operation";
197211
let description = [{
198212
Performs N-bit addition on the operands. The operands are interpreted as
@@ -203,16 +217,23 @@ def Arith_AddIOp : Arith_TotalIntBinaryOp<"addi", [Commutative]> {
203217

204218
The `addi` operation takes two operands and returns one result, each of
205219
these is required to be the same type. This type may be an integer scalar type,
206-
a vector whose element type is integer, or a tensor of integers. It has no
207-
standard attributes.
220+
a vector whose element type is integer, or a tensor of integers.
221+
222+
This op supports `nuw`/`nsw` overflow flags which stands stand for
223+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
224+
`nsw` flags are present, and an unsigned/signed overflow occurs
225+
(respectively), the result is poison.
208226

209227
Example:
210228

211229
```mlir
212230
// Scalar addition.
213231
%a = arith.addi %b, %c : i64
214232

215-
// SIMD vector element-wise addition, e.g. for Intel SSE.
233+
// Scalar addition with overflow flags.
234+
%a = arith.addi %b, %c overflow<nsw, nuw> : i64
235+
236+
// SIMD vector element-wise addition.
216237
%f = arith.addi %g, %h : vector<4xi32>
217238

218239
// Tensor element-wise addition.
@@ -278,21 +299,41 @@ def Arith_AddUIExtendedOp : Arith_Op<"addui_extended", [Pure, Commutative,
278299
// SubIOp
279300
//===----------------------------------------------------------------------===//
280301

281-
def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
302+
def Arith_SubIOp : Arith_IntBinaryOpWithOverflowFlags<"subi"> {
282303
let summary = [{
283304
Integer subtraction operation.
284305
}];
285306
let description = [{
286-
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
287-
bitvectors. The result is represented by a bitvector containing the mathematical
288-
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
289-
integers use a two's complement representation, this operation is applicable on
307+
Performs N-bit subtraction on the operands. The operands are interpreted as unsigned
308+
bitvectors. The result is represented by a bitvector containing the mathematical
309+
value of the subtraction modulo 2^n, where `n` is the bitwidth. Because `arith`
310+
integers use a two's complement representation, this operation is applicable on
290311
both signed and unsigned integer operands.
291312

292313
The `subi` operation takes two operands and returns one result, each of
293-
these is required to be the same type. This type may be an integer scalar type,
294-
a vector whose element type is integer, or a tensor of integers. It has no
295-
standard attributes.
314+
these is required to be the same type. This type may be an integer scalar type,
315+
a vector whose element type is integer, or a tensor of integers.
316+
317+
This op supports `nuw`/`nsw` overflow flags which stands stand for
318+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
319+
`nsw` flags are present, and an unsigned/signed overflow occurs
320+
(respectively), the result is poison.
321+
322+
Example:
323+
324+
```mlir
325+
// Scalar subtraction.
326+
%a = arith.subi %b, %c : i64
327+
328+
// Scalar subtraction with overflow flags.
329+
%a = arith.subi %b, %c overflow<nsw, nuw> : i64
330+
331+
// SIMD vector element-wise subtraction.
332+
%f = arith.subi %g, %h : vector<4xi32>
333+
334+
// Tensor element-wise subtraction.
335+
%x = arith.subi %y, %z : tensor<4x?xi8>
336+
```
296337
}];
297338
let hasFolder = 1;
298339
let hasCanonicalizer = 1;
@@ -302,21 +343,41 @@ def Arith_SubIOp : Arith_TotalIntBinaryOp<"subi"> {
302343
// MulIOp
303344
//===----------------------------------------------------------------------===//
304345

305-
def Arith_MulIOp : Arith_TotalIntBinaryOp<"muli", [Commutative]> {
346+
def Arith_MulIOp : Arith_IntBinaryOpWithOverflowFlags<"muli", [Commutative]> {
306347
let summary = [{
307348
Integer multiplication operation.
308349
}];
309350
let description = [{
310-
Performs N-bit multiplication on the operands. The operands are interpreted as
311-
unsigned bitvectors. The result is represented by a bitvector containing the
312-
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
313-
Because `arith` integers use a two's complement representation, this operation is
351+
Performs N-bit multiplication on the operands. The operands are interpreted as
352+
unsigned bitvectors. The result is represented by a bitvector containing the
353+
mathematical value of the multiplication modulo 2^n, where `n` is the bitwidth.
354+
Because `arith` integers use a two's complement representation, this operation is
314355
applicable on both signed and unsigned integer operands.
315356

316357
The `muli` operation takes two operands and returns one result, each of
317-
these is required to be the same type. This type may be an integer scalar type,
318-
a vector whose element type is integer, or a tensor of integers. It has no
319-
standard attributes.
358+
these is required to be the same type. This type may be an integer scalar type,
359+
a vector whose element type is integer, or a tensor of integers.
360+
361+
This op supports `nuw`/`nsw` overflow flags which stands stand for
362+
"No Unsigned Wrap" and "No Signed Wrap", respectively. If the `nuw` and/or
363+
`nsw` flags are present, and an unsigned/signed overflow occurs
364+
(respectively), the result is poison.
365+
366+
Example:
367+
368+
```mlir
369+
// Scalar multiplication.
370+
%a = arith.muli %b, %c : i64
371+
372+
// Scalar multiplication with overflow flags.
373+
%a = arith.muli %b, %c overflow<nsw, nuw> : i64
374+
375+
// SIMD vector element-wise multiplication.
376+
%f = arith.muli %g, %h : vector<4xi32>
377+
378+
// Tensor element-wise multiplication.
379+
%x = arith.muli %y, %z : tensor<4x?xi8>
380+
```
320381
}];
321382
let hasFolder = 1;
322383
let hasCanonicalizer = 1;

mlir/include/mlir/Dialect/Arith/IR/ArithOpsInterfaces.td

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,4 +49,61 @@ def ArithFastMathInterface : OpInterface<"ArithFastMathInterface"> {
4949
];
5050
}
5151

52+
def ArithIntegerOverflowFlagsInterface : OpInterface<"ArithIntegerOverflowFlagsInterface"> {
53+
let description = [{
54+
Access to op integer overflow flags.
55+
}];
56+
57+
let cppNamespace = "::mlir::arith";
58+
59+
let methods = [
60+
InterfaceMethod<
61+
/*desc=*/ "Returns an IntegerOverflowFlagsAttr attribute for the operation",
62+
/*returnType=*/ "IntegerOverflowFlagsAttr",
63+
/*methodName=*/ "getOverflowAttr",
64+
/*args=*/ (ins),
65+
/*methodBody=*/ [{}],
66+
/*defaultImpl=*/ [{
67+
auto op = cast<ConcreteOp>(this->getOperation());
68+
return op.getOverflowFlagsAttr();
69+
}]
70+
>,
71+
InterfaceMethod<
72+
/*desc=*/ "Returns whether the operation has the No Unsigned Wrap keyword",
73+
/*returnType=*/ "bool",
74+
/*methodName=*/ "hasNoUnsignedWrap",
75+
/*args=*/ (ins),
76+
/*methodBody=*/ [{}],
77+
/*defaultImpl=*/ [{
78+
auto op = cast<ConcreteOp>(this->getOperation());
79+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
80+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nuw);
81+
}]
82+
>,
83+
InterfaceMethod<
84+
/*desc=*/ "Returns whether the operation has the No Signed Wrap keyword",
85+
/*returnType=*/ "bool",
86+
/*methodName=*/ "hasNoSignedWrap",
87+
/*args=*/ (ins),
88+
/*methodBody=*/ [{}],
89+
/*defaultImpl=*/ [{
90+
auto op = cast<ConcreteOp>(this->getOperation());
91+
IntegerOverflowFlags flags = op.getOverflowFlagsAttr().getValue();
92+
return bitEnumContainsAll(flags, IntegerOverflowFlags::nsw);
93+
}]
94+
>,
95+
StaticInterfaceMethod<
96+
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
97+
for the operation}],
98+
/*returnType=*/ "StringRef",
99+
/*methodName=*/ "getIntegerOverflowAttrName",
100+
/*args=*/ (ins),
101+
/*methodBody=*/ [{}],
102+
/*defaultImpl=*/ [{
103+
return "overflowFlags";
104+
}]
105+
>
106+
];
107+
}
108+
52109
#endif // ARITH_OPS_INTERFACES

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def IntegerOverflowFlagsInterface : OpInterface<"IntegerOverflowFlagsInterface">
9292
}]
9393
>,
9494
StaticInterfaceMethod<
95-
/*desc=*/ [{Returns the name of the IntegerOveflowFlagsAttr attribute
95+
/*desc=*/ [{Returns the name of the IntegerOverflowFlagsAttr attribute
9696
for the operation}],
9797
/*returnType=*/ "StringRef",
9898
/*methodName=*/ "getIntegerOverflowAttrName",

mlir/lib/Conversion/ArithCommon/AttrToLLVMConverter.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
using namespace mlir;
1212

13-
// Map arithmetic fastmath enum values to LLVMIR enum values.
1413
LLVM::FastmathFlags
1514
mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
1615
LLVM::FastmathFlags llvmFMF{};
@@ -22,17 +21,37 @@ mlir::arith::convertArithFastMathFlagsToLLVM(arith::FastMathFlags arithFMF) {
2221
{arith::FastMathFlags::contract, LLVM::FastmathFlags::contract},
2322
{arith::FastMathFlags::afn, LLVM::FastmathFlags::afn},
2423
{arith::FastMathFlags::reassoc, LLVM::FastmathFlags::reassoc}};
25-
for (auto fmfMap : flags) {
26-
if (bitEnumContainsAny(arithFMF, fmfMap.first))
27-
llvmFMF = llvmFMF | fmfMap.second;
24+
for (auto [arithFlag, llvmFlag] : flags) {
25+
if (bitEnumContainsAny(arithFMF, arithFlag))
26+
llvmFMF = llvmFMF | llvmFlag;
2827
}
2928
return llvmFMF;
3029
}
3130

32-
// Create an LLVM fastmath attribute from a given arithmetic fastmath attribute.
3331
LLVM::FastmathFlagsAttr
3432
mlir::arith::convertArithFastMathAttrToLLVM(arith::FastMathFlagsAttr fmfAttr) {
3533
arith::FastMathFlags arithFMF = fmfAttr.getValue();
3634
return LLVM::FastmathFlagsAttr::get(
3735
fmfAttr.getContext(), convertArithFastMathFlagsToLLVM(arithFMF));
3836
}
37+
38+
LLVM::IntegerOverflowFlags mlir::arith::convertArithOverflowFlagsToLLVM(
39+
arith::IntegerOverflowFlags arithFlags) {
40+
LLVM::IntegerOverflowFlags llvmFlags{};
41+
const std::pair<arith::IntegerOverflowFlags, LLVM::IntegerOverflowFlags>
42+
flags[] = {
43+
{arith::IntegerOverflowFlags::nsw, LLVM::IntegerOverflowFlags::nsw},
44+
{arith::IntegerOverflowFlags::nuw, LLVM::IntegerOverflowFlags::nuw}};
45+
for (auto [arithFlag, llvmFlag] : flags) {
46+
if (bitEnumContainsAny(arithFlags, arithFlag))
47+
llvmFlags = llvmFlags | llvmFlag;
48+
}
49+
return llvmFlags;
50+
}
51+
52+
LLVM::IntegerOverflowFlagsAttr mlir::arith::convertArithOverflowAttrToLLVM(
53+
arith::IntegerOverflowFlagsAttr flagsAttr) {
54+
arith::IntegerOverflowFlags arithFlags = flagsAttr.getValue();
55+
return LLVM::IntegerOverflowFlagsAttr::get(
56+
flagsAttr.getContext(), convertArithOverflowFlagsToLLVM(arithFlags));
57+
}

0 commit comments

Comments
 (0)